Compare commits
140 Commits
agent-tool
...
preview-ac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f0ffa0109 | ||
|
|
95d8409900 | ||
|
|
d27fdd96f2 | ||
|
|
c8e909afc6 | ||
|
|
abad6d9be9 | ||
|
|
9c50d19841 | ||
|
|
17b98d068a | ||
|
|
6386336eee | ||
|
|
c168fc335c | ||
|
|
b2df395918 | ||
|
|
2b431d3e9d | ||
|
|
4812c9094b | ||
|
|
fcef101227 | ||
|
|
7e25460708 | ||
|
|
9b37206147 | ||
|
|
756fcd0733 | ||
|
|
3fd37799b4 | ||
|
|
ab180855de | ||
|
|
2beefc8158 | ||
|
|
5092f0f18b | ||
|
|
3a212e72a4 | ||
|
|
4dc8ce8cf7 | ||
|
|
2cc5a0de26 | ||
|
|
bc665b2a76 | ||
|
|
17903a0999 | ||
|
|
5102c4c002 | ||
|
|
2139219832 | ||
|
|
9abeedf0c6 | ||
|
|
1d7c86bf0d | ||
|
|
17703310ae | ||
|
|
bbe8d6a654 | ||
|
|
bbc66748dd | ||
|
|
99df1190a9 | ||
|
|
0e477e7db9 | ||
|
|
0afb980f7b | ||
|
|
d360f77796 | ||
|
|
92b9bc599d | ||
|
|
ed367e1636 | ||
|
|
b41ffae161 | ||
|
|
ef33666701 | ||
|
|
cd86905ebe | ||
|
|
abb48b7711 | ||
|
|
8afac388bb | ||
|
|
53b36b328e | ||
|
|
ce93961fe0 | ||
|
|
e3c987e2fb | ||
|
|
4dc0551105 | ||
|
|
bf9e5b4f76 | ||
|
|
cfb8cae29c | ||
|
|
68e0105627 | ||
|
|
e98e6c7426 | ||
|
|
3a1bd38503 | ||
|
|
8a69d252f5 | ||
|
|
bf30beacc2 | ||
|
|
2a0be48875 | ||
|
|
1c4ba07b20 | ||
|
|
8a717abe0d | ||
|
|
f735c90c3f | ||
|
|
ddfeb202a3 | ||
|
|
9bd0828303 | ||
|
|
4dff47ae20 | ||
|
|
52eef3c35d | ||
|
|
f060918b57 | ||
|
|
609c528ceb | ||
|
|
6db974dd32 | ||
|
|
60ec55b179 | ||
|
|
bb7a5b13df | ||
|
|
1e47dfce79 | ||
|
|
3fdbc3090d | ||
|
|
f2b4004c00 | ||
|
|
ec5821f76d | ||
|
|
e22cae6459 | ||
|
|
21bafd7856 | ||
|
|
ee74edbbb1 | ||
|
|
d832b8e687 | ||
|
|
539f4f1576 | ||
|
|
9a325a23e5 | ||
|
|
ce31312268 | ||
|
|
d46890978a | ||
|
|
67615b968b | ||
|
|
053fafa90e | ||
|
|
d23024609f | ||
|
|
3961d87ae0 | ||
|
|
8b910e1cd9 | ||
|
|
12c645e154 | ||
|
|
cfb7a30724 | ||
|
|
7623fce4b4 | ||
|
|
7f5c874a38 | ||
|
|
8cc2ade21c | ||
|
|
c3177e6f5b | ||
|
|
c3570fbcf3 | ||
|
|
3aa313010f | ||
|
|
5f9c91d05a | ||
|
|
6692bd9f2b | ||
|
|
cc57bc7c96 | ||
|
|
c157b1c455 | ||
|
|
136e83e0b1 | ||
|
|
b28756ae3f | ||
|
|
65401d6d7b | ||
|
|
a5405fcbd7 | ||
|
|
4f9cadabf7 | ||
|
|
7443f89a2e | ||
|
|
9bee765d7f | ||
|
|
8c553ee9f0 | ||
|
|
3389327df5 | ||
|
|
f106dfca42 | ||
|
|
37fa437990 | ||
|
|
9be7bf72a4 | ||
|
|
357e38b471 | ||
|
|
ae37f3ca2e | ||
|
|
49003d8038 | ||
|
|
93862838bd | ||
|
|
c39adc5242 | ||
|
|
ebb39d9231 | ||
|
|
187f851613 | ||
|
|
a77db45865 | ||
|
|
6bb6be826d | ||
|
|
7d9a55d101 | ||
|
|
57d8397f53 | ||
|
|
17ecf94f6f | ||
|
|
d492939bed | ||
|
|
720dfee803 | ||
|
|
a98c648201 | ||
|
|
c147daae4a | ||
|
|
d3911e34de | ||
|
|
87f85f1863 | ||
|
|
1a4dab97db | ||
|
|
cd365b0cf5 | ||
|
|
58604fba86 | ||
|
|
b0609272c0 | ||
|
|
a17807d8b1 | ||
|
|
f81e65ae7c | ||
|
|
952fe34aaa | ||
|
|
f527df6fa1 | ||
|
|
b54bbebc03 | ||
|
|
8bb7a1f9e7 | ||
|
|
e70d8d4dfd | ||
|
|
ea5ce2a1a4 | ||
|
|
fd8eeb537d | ||
|
|
92f21ee39d |
11
.github/workflows/eval.yml
vendored
11
.github/workflows/eval.yml
vendored
@@ -2,7 +2,7 @@ name: Run Agent Eval
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 * * * *"
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
@@ -25,6 +25,15 @@ env:
|
||||
ZED_EVAL_TELEMETRY: 1
|
||||
|
||||
jobs:
|
||||
# This is a no-op job that we run to prevent GitHub from marking the workflow
|
||||
# as failed for PRs that don't have the `run-eval` label.
|
||||
noop:
|
||||
name: No-op
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: No-op
|
||||
run: echo "Nothing to do"
|
||||
|
||||
run_eval:
|
||||
timeout-minutes: 60
|
||||
name: Run Agent Eval
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -23,6 +23,7 @@
|
||||
/crates/theme/schemas/theme.json
|
||||
/crates/zed/resources/flatpak/flatpak-cargo-sources.json
|
||||
/dev.zed.Zed*.json
|
||||
/node_modules/
|
||||
/plugins/bin
|
||||
/script/node_modules
|
||||
/snap
|
||||
@@ -32,4 +33,5 @@ Packages
|
||||
xcuserdata/
|
||||
|
||||
# Don't commit any secrets to the repo.
|
||||
.env
|
||||
.env.secret.toml
|
||||
|
||||
5
.rules
5
.rules
@@ -119,3 +119,8 @@ GPUI has had some changes to its APIs. Always write code using the new APIs:
|
||||
* Use `App` references. This replaces `AppContext` which no longer exists and should NEVER be used.
|
||||
* Use `Context<T>` references. This replaces `ModelContext<T>` which no longer exists and should NEVER be used.
|
||||
* `Window` is now passed around explicitly. The new interface adds a `Window` reference parameter to some methods, and adds some new "*_in" methods for plumbing `Window`. The old types `WindowContext` and `ViewContext<T>` should NEVER be used.
|
||||
|
||||
|
||||
## General guidelines
|
||||
|
||||
- Use `./script/clippy` instead of `cargo clippy`
|
||||
|
||||
211
Cargo.lock
generated
211
Cargo.lock
generated
@@ -61,7 +61,6 @@ dependencies = [
|
||||
"buffer_diff",
|
||||
"chrono",
|
||||
"client",
|
||||
"clock",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
@@ -95,13 +94,15 @@ dependencies = [
|
||||
"parking_lot",
|
||||
"paths",
|
||||
"picker",
|
||||
"postage",
|
||||
"project",
|
||||
"prompt_library",
|
||||
"prompt_store",
|
||||
"proto",
|
||||
"rand 0.8.5",
|
||||
"ref-cast",
|
||||
"release_channel",
|
||||
"rope",
|
||||
"rules_library",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -446,6 +447,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"shlex",
|
||||
"smol",
|
||||
"tempfile",
|
||||
"util",
|
||||
@@ -498,11 +500,11 @@ dependencies = [
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"prompt_library",
|
||||
"prompt_store",
|
||||
"proto",
|
||||
"rand 0.8.5",
|
||||
"rope",
|
||||
"rules_library",
|
||||
"schemars",
|
||||
"search",
|
||||
"serde",
|
||||
@@ -735,7 +737,6 @@ dependencies = [
|
||||
"web_search",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"worktree",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
@@ -3158,6 +3159,7 @@ dependencies = [
|
||||
"go_to_line",
|
||||
"gpui",
|
||||
"language",
|
||||
"log",
|
||||
"menu",
|
||||
"picker",
|
||||
"postage",
|
||||
@@ -3201,18 +3203,23 @@ dependencies = [
|
||||
name = "component_preview"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
"db",
|
||||
"gpui",
|
||||
"languages",
|
||||
"log",
|
||||
"notifications",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"serde",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
]
|
||||
@@ -4546,6 +4553,12 @@ dependencies = [
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dotenv"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
|
||||
|
||||
[[package]]
|
||||
name = "dotenvy"
|
||||
version = "0.15.7"
|
||||
@@ -4967,7 +4980,8 @@ dependencies = [
|
||||
"client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"dirs 5.0.1",
|
||||
"dirs 4.0.0",
|
||||
"dotenv",
|
||||
"env_logger 0.11.8",
|
||||
"extension",
|
||||
"fs",
|
||||
@@ -4981,6 +4995,7 @@ dependencies = [
|
||||
"language_models",
|
||||
"languages",
|
||||
"node_runtime",
|
||||
"pathdiff",
|
||||
"paths",
|
||||
"project",
|
||||
"prompt_store",
|
||||
@@ -7100,6 +7115,7 @@ dependencies = [
|
||||
"editor",
|
||||
"file_icons",
|
||||
"gpui",
|
||||
"log",
|
||||
"project",
|
||||
"schemars",
|
||||
"serde",
|
||||
@@ -8234,7 +8250,7 @@ dependencies = [
|
||||
"prost 0.9.0",
|
||||
"prost-build 0.9.0",
|
||||
"prost-types 0.9.0",
|
||||
"reqwest 0.12.8",
|
||||
"reqwest 0.12.15",
|
||||
"serde",
|
||||
"workspace-hack",
|
||||
]
|
||||
@@ -10092,7 +10108,7 @@ name = "perplexity"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"zed_extension_api 0.4.0",
|
||||
"zed_extension_api 0.5.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -11083,32 +11099,6 @@ dependencies = [
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prompt_library"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"editor",
|
||||
"gpui",
|
||||
"language",
|
||||
"language_model",
|
||||
"log",
|
||||
"menu",
|
||||
"picker",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"rope",
|
||||
"serde",
|
||||
"settings",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prompt_store"
|
||||
version = "0.1.0"
|
||||
@@ -11129,6 +11119,7 @@ dependencies = [
|
||||
"paths",
|
||||
"rope",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"text",
|
||||
"util",
|
||||
"uuid",
|
||||
@@ -11742,6 +11733,26 @@ dependencies = [
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast"
|
||||
version = "1.0.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
|
||||
dependencies = [
|
||||
"ref-cast-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast-impl"
|
||||
version = "1.0.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "refineable"
|
||||
version = "0.1.0"
|
||||
@@ -12013,8 +12024,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.12.8"
|
||||
source = "git+https://github.com/zed-industries/reqwest.git?rev=fd110f6998da16bbca97b6dddda9be7827c50e29#fd110f6998da16bbca97b6dddda9be7827c50e29"
|
||||
version = "0.12.15"
|
||||
source = "git+https://github.com/zed-industries/reqwest.git?rev=951c770a32f1998d6e999cef3e59e0013e6c4415#951c770a32f1998d6e999cef3e59e0013e6c4415"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes 1.10.1",
|
||||
@@ -12049,13 +12060,14 @@ dependencies = [
|
||||
"tokio-rustls 0.26.2",
|
||||
"tokio-socks",
|
||||
"tokio-util",
|
||||
"tower 0.5.2",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"windows-registry 0.2.0",
|
||||
"windows-registry 0.4.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -12070,7 +12082,7 @@ dependencies = [
|
||||
"http_client_tls",
|
||||
"log",
|
||||
"regex",
|
||||
"reqwest 0.12.8",
|
||||
"reqwest 0.12.15",
|
||||
"serde",
|
||||
"smol",
|
||||
"tokio",
|
||||
@@ -12271,6 +12283,32 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rules_library"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"editor",
|
||||
"gpui",
|
||||
"language",
|
||||
"language_model",
|
||||
"log",
|
||||
"menu",
|
||||
"picker",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"rope",
|
||||
"serde",
|
||||
"settings",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "runtimelib"
|
||||
version = "0.25.0"
|
||||
@@ -14281,6 +14319,7 @@ dependencies = [
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"language",
|
||||
"menu",
|
||||
@@ -14290,6 +14329,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
@@ -15101,6 +15141,11 @@ version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
@@ -17187,13 +17232,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "windows-registry"
|
||||
version = "0.2.0"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
|
||||
checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3"
|
||||
dependencies = [
|
||||
"windows-result 0.2.0",
|
||||
"windows-strings 0.1.0",
|
||||
"windows-targets 0.52.6",
|
||||
"windows-result 0.3.2",
|
||||
"windows-strings 0.3.1",
|
||||
"windows-targets 0.53.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -17244,6 +17289,15 @@ dependencies = [
|
||||
"windows-targets 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-strings"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-strings"
|
||||
version = "0.4.0"
|
||||
@@ -17328,13 +17382,29 @@ dependencies = [
|
||||
"windows_aarch64_gnullvm 0.52.6",
|
||||
"windows_aarch64_msvc 0.52.6",
|
||||
"windows_i686_gnu 0.52.6",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_gnullvm 0.52.6",
|
||||
"windows_i686_msvc 0.52.6",
|
||||
"windows_x86_64_gnu 0.52.6",
|
||||
"windows_x86_64_gnullvm 0.52.6",
|
||||
"windows_x86_64_msvc 0.52.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.53.0",
|
||||
"windows_aarch64_msvc 0.53.0",
|
||||
"windows_i686_gnu 0.53.0",
|
||||
"windows_i686_gnullvm 0.53.0",
|
||||
"windows_i686_msvc 0.53.0",
|
||||
"windows_x86_64_gnu 0.53.0",
|
||||
"windows_x86_64_gnullvm 0.53.0",
|
||||
"windows_x86_64_msvc 0.53.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.2"
|
||||
@@ -17353,6 +17423,12 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.42.2"
|
||||
@@ -17371,6 +17447,12 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.42.2"
|
||||
@@ -17389,12 +17471,24 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.42.2"
|
||||
@@ -17413,6 +17507,12 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.42.2"
|
||||
@@ -17431,6 +17531,12 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.42.2"
|
||||
@@ -17449,6 +17555,12 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.42.2"
|
||||
@@ -17467,6 +17579,12 @@ version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.7.6"
|
||||
@@ -17977,6 +18095,7 @@ dependencies = [
|
||||
"subtle",
|
||||
"syn 1.0.109",
|
||||
"syn 2.0.100",
|
||||
"sync_wrapper 1.0.2",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"time-macros",
|
||||
@@ -17987,6 +18106,7 @@ dependencies = [
|
||||
"tokio-util",
|
||||
"toml_datetime",
|
||||
"toml_edit",
|
||||
"tower 0.5.2",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"tungstenite 0.26.2",
|
||||
@@ -18372,7 +18492,6 @@ dependencies = [
|
||||
"collab_ui",
|
||||
"collections",
|
||||
"command_palette",
|
||||
"command_palette_hooks",
|
||||
"component_preview",
|
||||
"copilot",
|
||||
"dap",
|
||||
@@ -18513,7 +18632,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_extension_api"
|
||||
version = "0.4.0"
|
||||
version = "0.5.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -18573,7 +18692,7 @@ dependencies = [
|
||||
name = "zed_test_extension"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"zed_extension_api 0.4.0",
|
||||
"zed_extension_api 0.5.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
46
Cargo.toml
46
Cargo.toml
@@ -39,9 +39,9 @@ members = [
|
||||
"crates/credentials_provider",
|
||||
"crates/dap",
|
||||
"crates/dap_adapters",
|
||||
"crates/db",
|
||||
"crates/debugger_tools",
|
||||
"crates/debugger_ui",
|
||||
"crates/db",
|
||||
"crates/deepseek",
|
||||
"crates/diagnostics",
|
||||
"crates/docs_preprocessor",
|
||||
@@ -109,7 +109,6 @@ members = [
|
||||
"crates/project",
|
||||
"crates/project_panel",
|
||||
"crates/project_symbols",
|
||||
"crates/prompt_library",
|
||||
"crates/prompt_store",
|
||||
"crates/proto",
|
||||
"crates/recent_projects",
|
||||
@@ -123,6 +122,7 @@ members = [
|
||||
"crates/rich_text",
|
||||
"crates/rope",
|
||||
"crates/rpc",
|
||||
"crates/rules_library",
|
||||
"crates/schema_generator",
|
||||
"crates/search",
|
||||
"crates/semantic_index",
|
||||
@@ -229,6 +229,7 @@ auto_update_ui = { path = "crates/auto_update_ui" }
|
||||
aws_http_client = { path = "crates/aws_http_client" }
|
||||
bedrock = { path = "crates/bedrock" }
|
||||
breadcrumbs = { path = "crates/breadcrumbs" }
|
||||
buffer_diff = { path = "crates/buffer_diff" }
|
||||
call = { path = "crates/call" }
|
||||
channel = { path = "crates/channel" }
|
||||
cli = { path = "crates/cli" }
|
||||
@@ -248,11 +249,10 @@ credentials_provider = { path = "crates/credentials_provider" }
|
||||
dap = { path = "crates/dap" }
|
||||
dap_adapters = { path = "crates/dap_adapters" }
|
||||
db = { path = "crates/db" }
|
||||
debugger_ui = { path = "crates/debugger_ui" }
|
||||
debugger_tools = { path = "crates/debugger_tools" }
|
||||
debugger_ui = { path = "crates/debugger_ui" }
|
||||
deepseek = { path = "crates/deepseek" }
|
||||
diagnostics = { path = "crates/diagnostics" }
|
||||
buffer_diff = { path = "crates/buffer_diff" }
|
||||
editor = { path = "crates/editor" }
|
||||
extension = { path = "crates/extension" }
|
||||
extension_host = { path = "crates/extension_host" }
|
||||
@@ -296,7 +296,6 @@ livekit_api = { path = "crates/livekit_api" }
|
||||
livekit_client = { path = "crates/livekit_client" }
|
||||
lmstudio = { path = "crates/lmstudio" }
|
||||
lsp = { path = "crates/lsp" }
|
||||
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" }
|
||||
markdown = { path = "crates/markdown" }
|
||||
markdown_preview = { path = "crates/markdown_preview" }
|
||||
media = { path = "crates/media" }
|
||||
@@ -310,8 +309,8 @@ ollama = { path = "crates/ollama" }
|
||||
open_ai = { path = "crates/open_ai" }
|
||||
outline = { path = "crates/outline" }
|
||||
outline_panel = { path = "crates/outline_panel" }
|
||||
paths = { path = "crates/paths" }
|
||||
panel = { path = "crates/panel" }
|
||||
paths = { path = "crates/paths" }
|
||||
picker = { path = "crates/picker" }
|
||||
plugin = { path = "crates/plugin" }
|
||||
plugin_macros = { path = "crates/plugin_macros" }
|
||||
@@ -319,7 +318,6 @@ prettier = { path = "crates/prettier" }
|
||||
project = { path = "crates/project" }
|
||||
project_panel = { path = "crates/project_panel" }
|
||||
project_symbols = { path = "crates/project_symbols" }
|
||||
prompt_library = { path = "crates/prompt_library" }
|
||||
prompt_store = { path = "crates/prompt_store" }
|
||||
proto = { path = "crates/proto" }
|
||||
recent_projects = { path = "crates/recent_projects" }
|
||||
@@ -332,6 +330,7 @@ reqwest_client = { path = "crates/reqwest_client" }
|
||||
rich_text = { path = "crates/rich_text" }
|
||||
rope = { path = "crates/rope" }
|
||||
rpc = { path = "crates/rpc" }
|
||||
rules_library = { path = "crates/rules_library" }
|
||||
search = { path = "crates/search" }
|
||||
semantic_index = { path = "crates/semantic_index" }
|
||||
semantic_version = { path = "crates/semantic_version" }
|
||||
@@ -418,7 +417,6 @@ bitflags = "2.6.0"
|
||||
blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
|
||||
blade-macros = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
|
||||
blade-util = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
|
||||
naga = { version = "23.1.0", features = ["wgsl-in"] }
|
||||
blake3 = "1.5.3"
|
||||
bytes = "1.0"
|
||||
cargo_metadata = "0.19"
|
||||
@@ -428,15 +426,16 @@ circular-buffer = "1.0"
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
cocoa = "0.26"
|
||||
cocoa-foundation = "0.2.0"
|
||||
core-video = { version = "0.4.3", features = ["metal"] }
|
||||
convert_case = "0.8.0"
|
||||
core-foundation = "0.10.0"
|
||||
core-foundation-sys = "0.8.6"
|
||||
core-video = { version = "0.4.3", features = ["metal"] }
|
||||
ctor = "0.4.0"
|
||||
dashmap = "6.0"
|
||||
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "be69a016ba710191b9fdded28c8b042af4b617f7" }
|
||||
dashmap = "6.0"
|
||||
derive_more = "0.99.17"
|
||||
dirs = "4.0"
|
||||
dotenv = "0.15.0"
|
||||
ec4rs = "1.1"
|
||||
emojis = "0.6.1"
|
||||
env_logger = "0.11"
|
||||
@@ -453,8 +452,8 @@ heck = "0.5"
|
||||
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
|
||||
hex = "0.4.3"
|
||||
html5ever = "0.27.0"
|
||||
hyper = "0.14"
|
||||
http = "1.1"
|
||||
hyper = "0.14"
|
||||
ignore = "0.4.22"
|
||||
image = "0.25.1"
|
||||
imara-diff = "0.1.8"
|
||||
@@ -470,24 +469,27 @@ libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
|
||||
linkify = "0.10.0"
|
||||
linkme = "0.3.31"
|
||||
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
|
||||
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" }
|
||||
markup5ever_rcdom = "0.3.0"
|
||||
metal = "0.29"
|
||||
mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
|
||||
naga = { version = "23.1.0", features = ["wgsl-in"] }
|
||||
nanoid = "0.4"
|
||||
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
|
||||
nix = "0.29"
|
||||
num-format = "0.4.4"
|
||||
objc = "0.2"
|
||||
open = "5.0.0"
|
||||
num-format = "0.4.4"
|
||||
ordered-float = "2.1.1"
|
||||
palette = { version = "0.7.5", default-features = false, features = ["std"] }
|
||||
parking_lot = "0.12.1"
|
||||
partial-json-fixer = "0.5.3"
|
||||
pathdiff = "0.2"
|
||||
pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
postage = { version = "0.5", features = ["futures-traits"] }
|
||||
@@ -501,9 +503,10 @@ pulldown-cmark = { version = "0.12.0", default-features = false }
|
||||
quote = "1.0.9"
|
||||
rand = "0.8.5"
|
||||
rayon = "1.8"
|
||||
ref-cast = "1.0.24"
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "fd110f6998da16bbca97b6dddda9be7827c50e29", default-features = false, features = [
|
||||
reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c770a32f1998d6e999cef3e59e0013e6c4415", default-features = false, features = [
|
||||
"charset",
|
||||
"http2",
|
||||
"macos-system-configuration",
|
||||
@@ -515,8 +518,8 @@ rsa = "0.9.6"
|
||||
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
] }
|
||||
rustc-demangle = "0.1.23"
|
||||
rust-embed = { version = "8.4", features = ["include-exclude"] }
|
||||
rustc-demangle = "0.1.23"
|
||||
rustc-hash = "2.1.0"
|
||||
rustls = { version = "0.23.26" }
|
||||
rustls-platform-verifier = "0.5.0"
|
||||
@@ -558,15 +561,16 @@ time = { version = "0.3", features = [
|
||||
"formatting",
|
||||
] }
|
||||
tiny_http = "0.8"
|
||||
toml = "0.8"
|
||||
tokio = { version = "1" }
|
||||
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
|
||||
toml = "0.8"
|
||||
tower-http = "0.4.4"
|
||||
tree-sitter = { version = "0.25.3", features = ["wasm"] }
|
||||
tree-sitter-bash = "0.23"
|
||||
tree-sitter-c = "0.23"
|
||||
tree-sitter-cpp = "0.23"
|
||||
tree-sitter-css = "0.23"
|
||||
tree-sitter-diff = "0.1.0"
|
||||
tree-sitter-elixir = "0.3"
|
||||
tree-sitter-embedded-template = "0.23.0"
|
||||
tree-sitter-gitcommit = { git = "https://github.com/zed-industries/tree-sitter-git-commit", rev = "88309716a69dd13ab83443721ba6e0b491d37ee9" }
|
||||
@@ -574,7 +578,6 @@ tree-sitter-go = "0.23"
|
||||
tree-sitter-go-mod = { git = "https://github.com/camdencheek/tree-sitter-go-mod", rev = "6efb59652d30e0e9cd5f3b3a669afd6f1a926d3c", package = "tree-sitter-gomod" }
|
||||
tree-sitter-gowork = { git = "https://github.com/zed-industries/tree-sitter-go-work", rev = "acb0617bf7f4fda02c6217676cc64acb89536dc7" }
|
||||
tree-sitter-heex = { git = "https://github.com/zed-industries/tree-sitter-heex", rev = "1dd45142fbb05562e35b2040c6129c9bca346592" }
|
||||
tree-sitter-diff = "0.1.0"
|
||||
tree-sitter-html = "0.23"
|
||||
tree-sitter-jsdoc = "0.23"
|
||||
tree-sitter-json = "0.24"
|
||||
@@ -586,15 +589,15 @@ tree-sitter-rust = "0.24"
|
||||
tree-sitter-typescript = "0.23"
|
||||
tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" }
|
||||
unicase = "2.6"
|
||||
unindent = "0.2.0"
|
||||
unicode-segmentation = "1.10"
|
||||
unicode-script = "0.5.7"
|
||||
unicode-segmentation = "1.10"
|
||||
unindent = "0.2.0"
|
||||
url = "2.2"
|
||||
urlencoding = "2.1.2"
|
||||
uuid = { version = "1.1.2", features = ["v4", "v5", "v7", "serde"] }
|
||||
walkdir = "2.3"
|
||||
wasmparser = "0.221"
|
||||
wasm-encoder = "0.221"
|
||||
wasmparser = "0.221"
|
||||
wasmtime = { version = "29", default-features = false, features = [
|
||||
"async",
|
||||
"demangle",
|
||||
@@ -608,7 +611,6 @@ wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.7.1"
|
||||
zstd = "0.11"
|
||||
metal = "0.29"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
git = "https://github.com/zed-industries/async-stripe"
|
||||
|
||||
@@ -212,7 +212,7 @@
|
||||
"ctrl-shift-g": "search::SelectPreviousMatch",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-k h": "assistant::DeployHistory",
|
||||
"ctrl-k l": "assistant::OpenPromptLibrary",
|
||||
"ctrl-k l": "assistant::OpenRulesLibrary",
|
||||
"new": "assistant::NewChat",
|
||||
"ctrl-t": "assistant::NewChat",
|
||||
"ctrl-n": "assistant::NewChat"
|
||||
@@ -241,7 +241,7 @@
|
||||
"ctrl-alt-n": "agent::NewTextThread",
|
||||
"ctrl-shift-h": "agent::OpenHistory",
|
||||
"ctrl-alt-c": "agent::OpenConfiguration",
|
||||
"ctrl-alt-p": "assistant::OpenPromptLibrary",
|
||||
"ctrl-alt-p": "assistant::OpenRulesLibrary",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -308,9 +308,9 @@
|
||||
{
|
||||
"context": "PromptLibrary",
|
||||
"bindings": {
|
||||
"new": "prompt_library::NewPrompt",
|
||||
"ctrl-n": "prompt_library::NewPrompt",
|
||||
"ctrl-shift-s": "prompt_library::ToggleDefaultPrompt"
|
||||
"new": "rules_library::NewRule",
|
||||
"ctrl-n": "rules_library::NewRule",
|
||||
"ctrl-shift-s": "rules_library::ToggleDefaultRule"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -675,7 +675,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
"shift-enter": "editor::ExpandExcerpts",
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
"context": "PromptLibrary",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-n": "prompt_library::NewPrompt",
|
||||
"cmd-shift-s": "prompt_library::ToggleDefaultPrompt",
|
||||
"cmd-n": "rules_library::NewRule",
|
||||
"cmd-shift-s": "rules_library::ToggleDefaultRule",
|
||||
"cmd-w": "workspace::CloseWindow"
|
||||
}
|
||||
},
|
||||
@@ -257,7 +257,7 @@
|
||||
"cmd-shift-g": "search::SelectPreviousMatch",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-k h": "assistant::DeployHistory",
|
||||
"cmd-k l": "assistant::OpenPromptLibrary",
|
||||
"cmd-k l": "assistant::OpenRulesLibrary",
|
||||
"cmd-t": "assistant::NewChat",
|
||||
"cmd-n": "assistant::NewChat"
|
||||
}
|
||||
@@ -286,7 +286,7 @@
|
||||
"cmd-alt-n": "agent::NewTextThread",
|
||||
"cmd-shift-h": "agent::OpenHistory",
|
||||
"cmd-alt-c": "agent::OpenConfiguration",
|
||||
"cmd-alt-p": "assistant::OpenPromptLibrary",
|
||||
"cmd-alt-p": "assistant::OpenRulesLibrary",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -738,7 +738,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
|
||||
@@ -50,6 +50,12 @@
|
||||
"] -": "vim::NextLesserIndent",
|
||||
"] +": "vim::NextGreaterIndent",
|
||||
"] =": "vim::NextSameIndent",
|
||||
"] b": "pane::ActivateNextItem",
|
||||
"[ b": "pane::ActivatePreviousItem",
|
||||
"] shift-b": "pane::ActivateLastItem",
|
||||
"[ shift-b": ["pane::ActivateItem", 0],
|
||||
"] space": "vim::InsertEmptyLineBelow",
|
||||
"[ space": "vim::InsertEmptyLineAbove",
|
||||
// Word motions
|
||||
"w": "vim::NextWordStart",
|
||||
"e": "vim::NextWordEnd",
|
||||
@@ -108,7 +114,11 @@
|
||||
"ctrl-e": "vim::LineDown",
|
||||
"ctrl-y": "vim::LineUp",
|
||||
// "g" commands
|
||||
"g r": "vim::PushReplaceWithRegister",
|
||||
"g shift-r": "vim::PushReplaceWithRegister",
|
||||
"g r n": "editor::Rename",
|
||||
"g r r": "editor::FindAllReferences",
|
||||
"g r i": "editor::GoToImplementation",
|
||||
"g r a": "editor::ToggleCodeActions",
|
||||
"g g": "vim::StartOfDocument",
|
||||
"g h": "editor::Hover",
|
||||
"g t": "pane::ActivateNextItem",
|
||||
@@ -127,6 +137,7 @@
|
||||
"g <": ["editor::SelectPrevious", { "replace_newest": true }],
|
||||
"g a": "editor::SelectAllMatches",
|
||||
"g s": "outline::Toggle",
|
||||
"g shift-o": "outline::Toggle",
|
||||
"g shift-s": "project_symbols::Toggle",
|
||||
"g .": "editor::ToggleCodeActions", // zed specific
|
||||
"g shift-a": "editor::FindAllReferences", // zed specific
|
||||
@@ -305,7 +316,7 @@
|
||||
"!": "vim::ShellCommand",
|
||||
"i": ["vim::PushObject", { "around": false }],
|
||||
"a": ["vim::PushObject", { "around": true }],
|
||||
"g r": ["vim::Paste", { "preserve_clipboard": true }],
|
||||
"g shift-r": ["vim::Paste", { "preserve_clipboard": true }],
|
||||
"g c": "vim::ToggleComments",
|
||||
"g q": "vim::Rewrap",
|
||||
"g ?": "vim::ConvertToRot13",
|
||||
@@ -339,7 +350,8 @@
|
||||
"ctrl-shift-q": ["vim::PushLiteral", {}],
|
||||
"ctrl-r": "vim::PushRegister",
|
||||
"insert": "vim::ToggleReplace",
|
||||
"ctrl-o": "vim::TemporaryNormal"
|
||||
"ctrl-o": "vim::TemporaryNormal",
|
||||
"ctrl-s": "editor::ShowSignatureHelp"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -630,9 +642,10 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "vim_operator == gr",
|
||||
"context": "vim_operator == gR",
|
||||
"bindings": {
|
||||
"r": "vim::CurrentLine"
|
||||
"r": "vim::CurrentLine",
|
||||
"shift-r": "vim::CurrentLine"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -693,7 +706,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "GitPanel || ProjectPanel || CollabPanel || OutlinePanel || ChatPanel || VimControl || EmptyPane || SharedScreen || MarkdownPreview || KeyContextView",
|
||||
"context": "GitPanel || ProjectPanel || CollabPanel || OutlinePanel || ChatPanel || VimControl || EmptyPane || SharedScreen || MarkdownPreview || KeyContextView || DebugPanel",
|
||||
"bindings": {
|
||||
// window related commands (ctrl-w X)
|
||||
"ctrl-w": null,
|
||||
|
||||
@@ -27,13 +27,28 @@ If appropriate, use tool calls to explore the current project, which contains th
|
||||
- `{{root_name}}`
|
||||
{{/each}}
|
||||
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
|
||||
{{# if (has_tool 'grep') }}
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
{{! TODO: Only mention tools if they are enabled }}
|
||||
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
||||
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
|
||||
{{/if}}
|
||||
|
||||
## Code Block Formatting
|
||||
|
||||
Whenever you mention a code block, you MUST use ONLY use the following format when the code in the block comes from a file
|
||||
in the project:
|
||||
|
||||
```path/to/Something.blah#L123-456
|
||||
(code goes here)
|
||||
```
|
||||
|
||||
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
|
||||
is a path in the project. (If this code block does not come from a file in the project, then you may instead use
|
||||
the normal markdown style of three backticks followed by language name. However, you MUST use this format if
|
||||
the code in the block comes from a file in the project.)
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
|
||||
@@ -167,7 +167,23 @@
|
||||
// Default: not set, defaults to "bar"
|
||||
"cursor_shape": null,
|
||||
// Determines when the mouse cursor should be hidden in an editor or input box.
|
||||
//
|
||||
// 1. Never hide the mouse cursor:
|
||||
// "never"
|
||||
// 2. Hide only when typing:
|
||||
// "on_typing"
|
||||
// 3. Hide on both typing and cursor movement:
|
||||
// "on_typing_and_movement"
|
||||
"hide_mouse": "on_typing_and_movement",
|
||||
// Determines how snippets are sorted relative to other completion items.
|
||||
//
|
||||
// 1. Place snippets at the top of the completion list:
|
||||
// "top"
|
||||
// 2. Place snippets normally without any preference:
|
||||
// "inline"
|
||||
// 3. Place snippets at the bottom of the completion list:
|
||||
// "bottom"
|
||||
"snippet_sort_order": "inline",
|
||||
// How to highlight the current line in the editor.
|
||||
//
|
||||
// 1. Don't highlight the current line:
|
||||
@@ -210,7 +226,7 @@
|
||||
// Hide the values of in variables from visual display in private files
|
||||
"redact_private_values": false,
|
||||
// The default number of lines to expand excerpts in the multibuffer by.
|
||||
"expand_excerpt_lines": 3,
|
||||
"expand_excerpt_lines": 5,
|
||||
// Globs to match against file paths to determine if a file is private.
|
||||
"private_files": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/secrets.yml"],
|
||||
// Whether to use additional LSP queries to format (and amend) the code after
|
||||
@@ -585,6 +601,13 @@
|
||||
//
|
||||
// Default: main
|
||||
"fallback_branch_name": "main",
|
||||
|
||||
// Whether to sort entries in the panel by path
|
||||
// or by status (the default).
|
||||
//
|
||||
// Default: false
|
||||
"sort_by_path": false,
|
||||
|
||||
"scrollbar": {
|
||||
// When to show the scrollbar in the git panel.
|
||||
//
|
||||
|
||||
@@ -1,2 +1,7 @@
|
||||
allow-private-module-inception = true
|
||||
avoid-breaking-exported-api = false
|
||||
ignore-interior-mutability = [
|
||||
# Suppresses clippy::mutable_key_type, which is a false positive as the Eq
|
||||
# and Hash impls do not use fields with interior mutability.
|
||||
"agent::context::AgentContextKey"
|
||||
]
|
||||
|
||||
@@ -28,7 +28,6 @@ async-watch.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
@@ -61,10 +60,12 @@ ordered-float.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
picker.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
prompt_library.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
ref-cast.workspace = true
|
||||
release_channel.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -597,6 +597,10 @@ impl Item for AgentDiff {
|
||||
editor.added_to_workspace(workspace, window, cx)
|
||||
});
|
||||
}
|
||||
|
||||
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
|
||||
"Agent Diff".into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AgentDiff {
|
||||
@@ -947,6 +951,7 @@ mod tests {
|
||||
ThemeSettings::register(cx);
|
||||
ContextServerSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
language_model::init_settings(cx);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
@@ -962,12 +967,14 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let prompt_store = None;
|
||||
let thread_store = cx
|
||||
.update(|cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
prompt_store,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
||||
@@ -34,14 +34,15 @@ use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use settings::Settings as _;
|
||||
use thread::ThreadId;
|
||||
pub use thread::{MessageSegment, ThreadId};
|
||||
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
|
||||
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
|
||||
pub use crate::context::{ContextLoadResult, LoadedContext};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::ThreadStore;
|
||||
pub use crate::thread_store::{SharedProjectContext, ThreadStore};
|
||||
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
|
||||
|
||||
actions!(
|
||||
|
||||
@@ -272,7 +272,7 @@ impl PickerDelegate for ToolPickerDelegate {
|
||||
.get(id.as_ref())
|
||||
.and_then(|preset| preset.tools.get(&tool.name))
|
||||
.copied()
|
||||
.unwrap_or(false),
|
||||
.unwrap_or(self.profile.enable_all_context_servers),
|
||||
};
|
||||
|
||||
Some(
|
||||
|
||||
@@ -2,6 +2,8 @@ use assistant_settings::AssistantSettings;
|
||||
use fs::Fs;
|
||||
use gpui::{Entity, FocusHandle, SharedString};
|
||||
|
||||
use crate::Thread;
|
||||
use language_model::{ConfiguredModel, LanguageModelRegistry};
|
||||
use language_model_selector::{
|
||||
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
|
||||
};
|
||||
@@ -9,7 +11,11 @@ use settings::update_settings_file;
|
||||
use std::sync::Arc;
|
||||
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
|
||||
pub use language_model_selector::ModelType;
|
||||
#[derive(Clone)]
|
||||
pub enum ModelType {
|
||||
Default(Entity<Thread>),
|
||||
InlineAssistant,
|
||||
}
|
||||
|
||||
pub struct AssistantModelSelector {
|
||||
selector: Entity<LanguageModelSelector>,
|
||||
@@ -24,18 +30,39 @@ impl AssistantModelSelector {
|
||||
focus_handle: FocusHandle,
|
||||
model_type: ModelType,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
Self {
|
||||
selector: cx.new(|cx| {
|
||||
selector: cx.new(move |cx| {
|
||||
let fs = fs.clone();
|
||||
LanguageModelSelector::new(
|
||||
{
|
||||
let model_type = model_type.clone();
|
||||
move |cx| match &model_type {
|
||||
ModelType::Default(thread) => thread.read(cx).configured_model(),
|
||||
ModelType::InlineAssistant => {
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
}
|
||||
}
|
||||
},
|
||||
move |model, cx| {
|
||||
let provider = model.provider_id().0.to_string();
|
||||
let model_id = model.id().0.to_string();
|
||||
|
||||
match model_type {
|
||||
ModelType::Default => {
|
||||
match &model_type {
|
||||
ModelType::Default(thread) => {
|
||||
thread.update(cx, |thread, cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(provider) = registry.provider(&model.provider_id())
|
||||
{
|
||||
thread.set_configured_model(
|
||||
Some(ConfiguredModel {
|
||||
provider,
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
});
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
cx,
|
||||
@@ -58,7 +85,6 @@ impl AssistantModelSelector {
|
||||
}
|
||||
}
|
||||
},
|
||||
model_type,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -5,8 +5,9 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_context_editor::{
|
||||
AssistantPanelDelegate, ConfigurationError, ContextEditor, SlashCommandCompletionProvider,
|
||||
humanize_token_count, make_lsp_adapter_delegate, render_remaining_tokens,
|
||||
AssistantContext, AssistantPanelDelegate, ConfigurationError, ContextEditor, ContextEvent,
|
||||
SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate,
|
||||
render_remaining_tokens,
|
||||
};
|
||||
use assistant_settings::{AssistantDockPosition, AssistantSettings};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
@@ -24,9 +25,9 @@ use language::LanguageRegistry;
|
||||
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
|
||||
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
|
||||
use proto::Plan;
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use time::UtcOffset;
|
||||
use ui::{
|
||||
@@ -36,7 +37,7 @@ use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
use workspace::dock::{DockPosition, Panel, PanelEvent};
|
||||
use zed_actions::agent::OpenConfiguration;
|
||||
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
|
||||
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
|
||||
|
||||
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
|
||||
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
|
||||
@@ -79,11 +80,11 @@ pub fn init(cx: &mut App) {
|
||||
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_prompt_library(action, window, cx)
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
})
|
||||
@@ -116,6 +117,8 @@ enum ActiveView {
|
||||
},
|
||||
PromptEditor {
|
||||
context_editor: Entity<ContextEditor>,
|
||||
title_editor: Entity<Editor>,
|
||||
_subscriptions: Vec<gpui::Subscription>,
|
||||
},
|
||||
History,
|
||||
Configuration,
|
||||
@@ -176,6 +179,83 @@ impl ActiveView {
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prompt_editor(
|
||||
context_editor: Entity<ContextEditor>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let title = context_editor.read(cx).title(cx).to_string();
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::single_line(window, cx);
|
||||
editor.set_text(title, window, cx);
|
||||
editor
|
||||
});
|
||||
|
||||
// This is a workaround for `editor.set_text` emitting a `BufferEdited` event, which would
|
||||
// cause a custom summary to be set. The presence of this custom summary would cause
|
||||
// summarization to not happen.
|
||||
let mut suppress_first_edit = true;
|
||||
|
||||
let subscriptions = vec![
|
||||
window.subscribe(&editor, cx, {
|
||||
{
|
||||
let context_editor = context_editor.clone();
|
||||
move |editor, event, window, cx| match event {
|
||||
EditorEvent::BufferEdited => {
|
||||
if suppress_first_edit {
|
||||
suppress_first_edit = false;
|
||||
return;
|
||||
}
|
||||
let new_summary = editor.read(cx).text(cx);
|
||||
|
||||
context_editor.update(cx, |context_editor, cx| {
|
||||
context_editor
|
||||
.context()
|
||||
.update(cx, |assistant_context, cx| {
|
||||
assistant_context.set_custom_summary(new_summary, cx);
|
||||
})
|
||||
})
|
||||
}
|
||||
EditorEvent::Blurred => {
|
||||
if editor.read(cx).text(cx).is_empty() {
|
||||
let summary = context_editor
|
||||
.read(cx)
|
||||
.context()
|
||||
.read(cx)
|
||||
.summary_or_default();
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(summary, window, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}),
|
||||
window.subscribe(&context_editor.read(cx).context().clone(), cx, {
|
||||
let editor = editor.clone();
|
||||
move |assistant_context, event, window, cx| match event {
|
||||
ContextEvent::SummaryGenerated => {
|
||||
let summary = assistant_context.read(cx).summary_or_default();
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(summary, window, cx);
|
||||
})
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}),
|
||||
];
|
||||
|
||||
Self::PromptEditor {
|
||||
context_editor,
|
||||
title_editor: editor,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AssistantPanel {
|
||||
@@ -188,7 +268,9 @@ pub struct AssistantPanel {
|
||||
thread: Entity<ActiveThread>,
|
||||
message_editor: Entity<MessageEditor>,
|
||||
_active_thread_subscriptions: Vec<Subscription>,
|
||||
_default_model_subscription: Subscription,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
configuration: Option<Entity<AssistantConfiguration>>,
|
||||
configuration_subscription: Option<Subscription>,
|
||||
local_timezone: UtcOffset,
|
||||
@@ -205,14 +287,25 @@ impl AssistantPanel {
|
||||
pub fn load(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: AsyncWindowContext,
|
||||
mut cx: AsyncWindowContext,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = match prompt_store {
|
||||
Ok(prompt_store) => prompt_store.await.ok(),
|
||||
Err(_) => None,
|
||||
};
|
||||
let tools = cx.new(|_| ToolWorkingSet::default())?;
|
||||
let thread_store = workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
ThreadStore::load(
|
||||
project,
|
||||
tools.clone(),
|
||||
prompt_builder.clone(),
|
||||
prompt_store.clone(),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
@@ -230,7 +323,16 @@ impl AssistantPanel {
|
||||
.await?;
|
||||
|
||||
workspace.update_in(cx, |workspace, window, cx| {
|
||||
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
|
||||
cx.new(|cx| {
|
||||
Self::new(
|
||||
workspace,
|
||||
thread_store,
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -239,6 +341,7 @@ impl AssistantPanel {
|
||||
workspace: &Workspace,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -262,6 +365,7 @@ impl AssistantPanel {
|
||||
fs.clone(),
|
||||
workspace.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
prompt_store.clone(),
|
||||
thread_store.downgrade(),
|
||||
thread.clone(),
|
||||
window,
|
||||
@@ -293,7 +397,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
thread_store.clone(),
|
||||
language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -306,6 +409,20 @@ impl AssistantPanel {
|
||||
}
|
||||
});
|
||||
|
||||
let _default_model_subscription = cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
|this, _, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::DefaultModelChanged => {
|
||||
this.thread
|
||||
.read(cx)
|
||||
.thread()
|
||||
.clone()
|
||||
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
);
|
||||
|
||||
Self {
|
||||
active_view,
|
||||
workspace,
|
||||
@@ -321,7 +438,9 @@ impl AssistantPanel {
|
||||
active_thread_subscription,
|
||||
message_editor_subscription,
|
||||
],
|
||||
_default_model_subscription,
|
||||
context_store,
|
||||
prompt_store,
|
||||
configuration: None,
|
||||
configuration_subscription: None,
|
||||
local_timezone: UtcOffset::from_whole_seconds(
|
||||
@@ -355,6 +474,10 @@ impl AssistantPanel {
|
||||
self.local_timezone
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
|
||||
&self.thread_store
|
||||
}
|
||||
@@ -411,7 +534,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
self.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -430,6 +552,7 @@ impl AssistantPanel {
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -475,22 +598,20 @@ impl AssistantPanel {
|
||||
});
|
||||
|
||||
self.set_active_view(
|
||||
ActiveView::PromptEditor {
|
||||
context_editor: context_editor.clone(),
|
||||
},
|
||||
ActiveView::prompt_editor(context_editor.clone(), window, cx),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
context_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn deploy_prompt_library(
|
||||
fn deploy_rules_library(
|
||||
&mut self,
|
||||
action: &OpenPromptLibrary,
|
||||
action: &OpenRulesLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_prompt_library(
|
||||
open_rules_library(
|
||||
self.language_registry.clone(),
|
||||
Box::new(PromptLibraryInlineAssist::new(self.workspace.clone())),
|
||||
Arc::new(|| {
|
||||
@@ -500,9 +621,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action.prompt_to_select.map(|uuid| PromptId::User {
|
||||
uuid: UserPromptId(uuid),
|
||||
}),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -551,10 +672,9 @@ impl AssistantPanel {
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
this.set_active_view(
|
||||
ActiveView::PromptEditor {
|
||||
context_editor: editor,
|
||||
},
|
||||
ActiveView::prompt_editor(editor.clone(), window, cx),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -598,7 +718,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
this.thread_store.clone(),
|
||||
this.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
this.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -617,6 +736,7 @@ impl AssistantPanel {
|
||||
this.fs.clone(),
|
||||
this.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
this.prompt_store.clone(),
|
||||
this.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -794,7 +914,7 @@ impl AssistantPanel {
|
||||
|
||||
pub(crate) fn active_context_editor(&self) -> Option<Entity<ContextEditor>> {
|
||||
match &self.active_view {
|
||||
ActiveView::PromptEditor { context_editor } => Some(context_editor.clone()),
|
||||
ActiveView::PromptEditor { context_editor, .. } => Some(context_editor.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -837,7 +957,7 @@ impl Focusable for AssistantPanel {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
|
||||
ActiveView::History => self.history.focus_handle(cx),
|
||||
ActiveView::PromptEditor { context_editor } => context_editor.focus_handle(cx),
|
||||
ActiveView::PromptEditor { context_editor, .. } => context_editor.focus_handle(cx),
|
||||
ActiveView::Configuration => {
|
||||
if let Some(configuration) = self.configuration.as_ref() {
|
||||
configuration.focus_handle(cx)
|
||||
@@ -961,9 +1081,34 @@ impl AssistantPanel {
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
ActiveView::PromptEditor { context_editor } => {
|
||||
let title = SharedString::from(context_editor.read(cx).title(cx).to_string());
|
||||
Label::new(title).ml_2().truncate().into_any_element()
|
||||
ActiveView::PromptEditor {
|
||||
title_editor,
|
||||
context_editor,
|
||||
..
|
||||
} => {
|
||||
let context_editor = context_editor.read(cx);
|
||||
let summary = context_editor.context().read(cx).summary();
|
||||
|
||||
match summary {
|
||||
None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone())
|
||||
.truncate()
|
||||
.ml_2()
|
||||
.into_any_element(),
|
||||
Some(summary) => {
|
||||
if summary.done {
|
||||
div()
|
||||
.ml_2()
|
||||
.w_full()
|
||||
.child(title_editor.clone())
|
||||
.into_any_element()
|
||||
} else {
|
||||
Label::new(LOADING_SUMMARY_PLACEHOLDER)
|
||||
.ml_2()
|
||||
.truncate()
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ActiveView::History => Label::new("History").truncate().into_any_element(),
|
||||
ActiveView::Configuration => Label::new("Settings").truncate().into_any_element(),
|
||||
@@ -1120,7 +1265,7 @@ impl AssistantPanel {
|
||||
"New Text Thread",
|
||||
NewTextThread.boxed_clone(),
|
||||
)
|
||||
.action("Prompt Library", Box::new(OpenPromptLibrary::default()))
|
||||
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.separator()
|
||||
.header("MCPs")
|
||||
@@ -1145,12 +1290,13 @@ impl AssistantPanel {
|
||||
let is_generating = thread.is_generating();
|
||||
let message_editor = self.message_editor.read(cx);
|
||||
|
||||
let conversation_token_usage = thread.total_token_usage(cx);
|
||||
let conversation_token_usage = thread.total_token_usage()?;
|
||||
|
||||
let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
|
||||
self.thread.read(cx).editing_message_id()
|
||||
{
|
||||
let combined = thread
|
||||
.token_usage_up_to_message(editing_message_id, cx)
|
||||
.token_usage_up_to_message(editing_message_id)
|
||||
.add(unsent_tokens);
|
||||
|
||||
(combined, unsent_tokens > 0)
|
||||
@@ -1236,7 +1382,7 @@ impl AssistantPanel {
|
||||
|
||||
Some(token_count)
|
||||
}
|
||||
ActiveView::PromptEditor { context_editor } => {
|
||||
ActiveView::PromptEditor { context_editor, .. } => {
|
||||
let element = render_remaining_tokens(context_editor, cx)?;
|
||||
|
||||
Some(element.into_any_element())
|
||||
@@ -1833,7 +1979,7 @@ impl Render for AssistantPanel {
|
||||
this.open_configuration(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(Self::open_active_thread_as_markdown))
|
||||
.on_action(cx.listener(Self::deploy_prompt_library))
|
||||
.on_action(cx.listener(Self::deploy_rules_library))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.on_action(cx.listener(Self::go_back))
|
||||
.child(self.render_toolbar(window, cx))
|
||||
@@ -1844,7 +1990,9 @@ impl Render for AssistantPanel {
|
||||
.child(h_flex().child(self.message_editor.clone()))
|
||||
.children(self.render_last_error(cx)),
|
||||
ActiveView::History => parent.child(self.history.clone()),
|
||||
ActiveView::PromptEditor { context_editor } => parent.child(context_editor.clone()),
|
||||
ActiveView::PromptEditor { context_editor, .. } => {
|
||||
parent.child(context_editor.clone())
|
||||
}
|
||||
ActiveView::Configuration => parent.children(self.configuration.clone()),
|
||||
})
|
||||
}
|
||||
@@ -1860,13 +2008,13 @@ impl PromptLibraryInlineAssist {
|
||||
}
|
||||
}
|
||||
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
_initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
let Some(project) = self
|
||||
@@ -1876,11 +2024,14 @@ impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let prompt_store = None;
|
||||
let thread_store = None;
|
||||
assistant.assist(
|
||||
&prompt_editor,
|
||||
self.workspace.clone(),
|
||||
project,
|
||||
None,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -1959,8 +2110,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
// being updated.
|
||||
cx.defer_in(window, move |panel, window, cx| {
|
||||
if panel.has_active_thread() {
|
||||
panel.thread.update(cx, |thread, cx| {
|
||||
thread.context_store().update(cx, |store, cx| {
|
||||
panel.message_editor.update(cx, |message_editor, cx| {
|
||||
message_editor.context_store().update(cx, |store, cx| {
|
||||
let buffer = buffer.read(cx);
|
||||
let selection_ranges = selection_ranges
|
||||
.into_iter()
|
||||
@@ -1977,9 +2128,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (buffer, range) in selection_ranges {
|
||||
store
|
||||
.add_selection(buffer, range, cx)
|
||||
.detach_and_log_err(cx);
|
||||
store.add_selection(buffer, range, cx);
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context::ContextLoadResult;
|
||||
use crate::inline_prompt_editor::CodegenStatus;
|
||||
use crate::{context::load_context, context_store::ContextStore};
|
||||
use anyhow::Result;
|
||||
use client::telemetry::Telemetry;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
|
||||
};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
@@ -14,7 +16,9 @@ use language_model::{
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Rope;
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
@@ -39,6 +43,8 @@ pub struct BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
pub is_insertion: bool,
|
||||
@@ -50,6 +56,8 @@ impl BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -60,6 +68,8 @@ impl BufferCodegen {
|
||||
range.clone(),
|
||||
false,
|
||||
Some(context_store.clone()),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
Some(telemetry.clone()),
|
||||
builder.clone(),
|
||||
cx,
|
||||
@@ -75,6 +85,8 @@ impl BufferCodegen {
|
||||
range,
|
||||
initial_transaction_id,
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
builder,
|
||||
};
|
||||
@@ -153,6 +165,8 @@ impl BufferCodegen {
|
||||
self.range.clone(),
|
||||
false,
|
||||
Some(self.context_store.clone()),
|
||||
self.project.clone(),
|
||||
self.prompt_store.clone(),
|
||||
Some(self.telemetry.clone()),
|
||||
self.builder.clone(),
|
||||
cx,
|
||||
@@ -229,13 +243,14 @@ pub struct CodegenAlternative {
|
||||
generation: Task<()>,
|
||||
diff: Diff,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
builder: Arc<PromptBuilder>,
|
||||
active: bool,
|
||||
edits: Vec<(Range<Anchor>, String)>,
|
||||
line_operations: Vec<LineOperation>,
|
||||
request: Option<LanguageModelRequest>,
|
||||
elapsed_time: Option<f64>,
|
||||
completion: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
@@ -249,6 +264,8 @@ impl CodegenAlternative {
|
||||
range: Range<Anchor>,
|
||||
active: bool,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -290,6 +307,8 @@ impl CodegenAlternative {
|
||||
generation: Task::ready(()),
|
||||
diff: Diff::default(),
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
builder,
|
||||
@@ -297,7 +316,6 @@ impl CodegenAlternative {
|
||||
edits: Vec::new(),
|
||||
line_operations: Vec::new(),
|
||||
range,
|
||||
request: None,
|
||||
elapsed_time: None,
|
||||
completion: None,
|
||||
}
|
||||
@@ -366,16 +384,18 @@ impl CodegenAlternative {
|
||||
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.build_request(user_prompt, cx)?;
|
||||
self.request = Some(request.clone());
|
||||
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
|
||||
.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
|
||||
fn build_request(
|
||||
&self,
|
||||
user_prompt: String,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(self.range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
@@ -408,30 +428,45 @@ impl CodegenAlternative {
|
||||
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
let context_task = self.context_store.as_ref().map(|context_store| {
|
||||
if let Some(project) = self.project.upgrade() {
|
||||
let context = context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
load_context(context, &project, &self.prompt_store, cx)
|
||||
} else {
|
||||
Task::ready(ContextLoadResult::default())
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(context_store) = &self.context_store {
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
Ok(cx.spawn(async move |_cx| {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
if let Some(context_task) = context_task {
|
||||
context_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
}
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
})
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
mode: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn handle_stream(
|
||||
@@ -469,7 +504,7 @@ impl CodegenAlternative {
|
||||
}
|
||||
}
|
||||
|
||||
let http_client = cx.http_client().clone();
|
||||
let http_client = cx.http_client();
|
||||
let telemetry = self.telemetry.clone();
|
||||
let language_name = {
|
||||
let multibuffer = self.buffer.read(cx);
|
||||
@@ -508,7 +543,9 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks = StripInvalidSpans::new(
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
@@ -1034,6 +1071,7 @@ impl Diff {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use fs::FakeFs;
|
||||
use futures::{
|
||||
Stream,
|
||||
stream::{self},
|
||||
@@ -1076,12 +1114,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1140,12 +1182,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1207,12 +1253,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1274,12 +1324,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1329,12 +1383,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
false,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,8 +10,11 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
pub use completion_provider::ContextPickerCompletionProvider;
|
||||
use editor::display_map::{Crease, FoldId};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use fetch_context_picker::FetchContextPicker;
|
||||
use file_context_picker::FileContextPicker;
|
||||
use file_context_picker::render_file_context_entry;
|
||||
use gpui::{
|
||||
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
|
||||
@@ -20,10 +23,10 @@ use gpui::{
|
||||
use language::Buffer;
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use project::{Entry, ProjectPath};
|
||||
use prompt_store::UserPromptId;
|
||||
use rules_context_picker::RulesContextEntry;
|
||||
use prompt_store::{PromptStore, UserPromptId};
|
||||
use rules_context_picker::{RulesContextEntry, RulesContextPicker};
|
||||
use symbol_context_picker::SymbolContextPicker;
|
||||
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
|
||||
use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
|
||||
use ui::{
|
||||
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
|
||||
};
|
||||
@@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
|
||||
use crate::AssistantPanel;
|
||||
use crate::context::RULES_ICON;
|
||||
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
|
||||
use crate::context_picker::fetch_context_picker::FetchContextPicker;
|
||||
use crate::context_picker::file_context_picker::FileContextPicker;
|
||||
use crate::context_picker::rules_context_picker::RulesContextPicker;
|
||||
use crate::context_picker::thread_context_picker::ThreadContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::ThreadId;
|
||||
use crate::thread_store::ThreadStore;
|
||||
@@ -166,6 +164,7 @@ pub(super) struct ContextPicker {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -193,6 +192,13 @@ impl ContextPicker {
|
||||
)
|
||||
.collect::<Vec<Subscription>>();
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
ContextPicker {
|
||||
mode: ContextPickerState::Default(ContextMenu::build(
|
||||
window,
|
||||
@@ -202,6 +208,7 @@ impl ContextPicker {
|
||||
workspace,
|
||||
context_store,
|
||||
thread_store,
|
||||
prompt_store,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
@@ -226,7 +233,12 @@ impl ContextPicker {
|
||||
.workspace
|
||||
.upgrade()
|
||||
.map(|workspace| {
|
||||
available_context_picker_entries(&self.thread_store, &workspace, cx)
|
||||
available_context_picker_entries(
|
||||
&self.prompt_store,
|
||||
&self.thread_store,
|
||||
&workspace,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
@@ -304,10 +316,10 @@ impl ContextPicker {
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Rules => {
|
||||
if let Some(thread_store) = self.thread_store.as_ref() {
|
||||
if let Some(prompt_store) = self.prompt_store.as_ref() {
|
||||
self.mode = ContextPickerState::Rules(cx.new(|cx| {
|
||||
RulesContextPicker::new(
|
||||
thread_store.clone(),
|
||||
prompt_store.clone(),
|
||||
context_picker.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
@@ -376,7 +388,7 @@ impl ContextPicker {
|
||||
ContextMenuItem::custom_entry(
|
||||
move |_window, cx| {
|
||||
render_file_context_entry(
|
||||
ElementId::NamedInteger("ctx-recent".into(), ix),
|
||||
ElementId::named_usize("ctx-recent", ix),
|
||||
worktree_id,
|
||||
&path,
|
||||
&path_prefix,
|
||||
@@ -526,6 +538,7 @@ enum RecentEntry {
|
||||
}
|
||||
|
||||
fn available_context_picker_entries(
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
thread_store: &Option<WeakEntity<ThreadStore>>,
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
@@ -550,6 +563,9 @@ fn available_context_picker_entries(
|
||||
|
||||
if thread_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
|
||||
}
|
||||
|
||||
if prompt_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
|
||||
}
|
||||
|
||||
@@ -585,22 +601,21 @@ fn recent_context_picker_entries(
|
||||
}),
|
||||
);
|
||||
|
||||
let mut current_threads = context_store.read(cx).thread_ids();
|
||||
let current_threads = context_store.read(cx).thread_ids();
|
||||
|
||||
if let Some(active_thread) = workspace
|
||||
let active_thread_id = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|panel| panel.read(cx).active_thread(cx))
|
||||
{
|
||||
current_threads.insert(active_thread.read(cx).id().clone());
|
||||
}
|
||||
.map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
|
||||
|
||||
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
|
||||
recent.extend(
|
||||
thread_store
|
||||
.read(cx)
|
||||
.threads()
|
||||
.reverse_chronological_threads()
|
||||
.into_iter()
|
||||
.filter(|thread| !current_threads.contains(&thread.id))
|
||||
.filter(|thread| {
|
||||
Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
|
||||
})
|
||||
.take(2)
|
||||
.map(|thread| {
|
||||
RecentEntry::Thread(ThreadContextEntry {
|
||||
@@ -622,9 +637,7 @@ fn add_selections_as_context(
|
||||
let selection_ranges = selection_ranges(workspace, cx);
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in selection_ranges {
|
||||
context_store
|
||||
.add_selection(buffer, range, cx)
|
||||
.detach_and_log_err(cx);
|
||||
context_store.add_selection(buffer, range, cx);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,22 +15,21 @@ use itertools::Itertools;
|
||||
use language::{Buffer, CodeLabel, HighlightId};
|
||||
use lsp::CompletionContext;
|
||||
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
|
||||
use prompt_store::PromptId;
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Point;
|
||||
use text::{Anchor, OffsetRangeExt, ToPoint};
|
||||
use ui::prelude::*;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_picker::file_context_picker::search_files;
|
||||
use crate::context_picker::symbol_context_picker::search_symbols;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
use super::fetch_context_picker::fetch_url_content;
|
||||
use super::file_context_picker::FileMatch;
|
||||
use super::file_context_picker::{FileMatch, search_files};
|
||||
use super::rules_context_picker::{RulesContextEntry, search_rules};
|
||||
use super::symbol_context_picker::SymbolMatch;
|
||||
use super::symbol_context_picker::search_symbols;
|
||||
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
|
||||
use super::{
|
||||
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
|
||||
@@ -38,8 +37,8 @@ use super::{
|
||||
};
|
||||
|
||||
pub(crate) enum Match {
|
||||
Symbol(SymbolMatch),
|
||||
File(FileMatch),
|
||||
Symbol(SymbolMatch),
|
||||
Thread(ThreadMatch),
|
||||
Fetch(SharedString),
|
||||
Rules(RulesContextEntry),
|
||||
@@ -69,6 +68,7 @@ fn search(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
recent_entries: Vec<RecentEntry>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
workspace: Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
@@ -85,6 +85,7 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Symbol) => {
|
||||
let search_symbols_task =
|
||||
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
@@ -96,6 +97,7 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Thread) => {
|
||||
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
|
||||
let search_threads_task =
|
||||
@@ -111,6 +113,7 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Fetch) => {
|
||||
if !query.is_empty() {
|
||||
Task::ready(vec![Match::Fetch(query.into())])
|
||||
@@ -118,10 +121,11 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Rules) => {
|
||||
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
let search_rules_task =
|
||||
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
|
||||
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_rules_task
|
||||
.await
|
||||
@@ -133,6 +137,7 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
if query.is_empty() {
|
||||
let mut matches = recent_entries
|
||||
@@ -163,7 +168,7 @@ fn search(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
matches.extend(
|
||||
available_context_picker_entries(&thread_store, &workspace, cx)
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
|
||||
.into_iter()
|
||||
.map(|mode| {
|
||||
Match::Entry(EntryMatch {
|
||||
@@ -180,7 +185,8 @@ fn search(
|
||||
let search_files_task =
|
||||
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
|
||||
let entries = available_context_picker_entries(&thread_store, &workspace, cx);
|
||||
let entries =
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
|
||||
let entry_candidates = entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider {
|
||||
move |_, _: &mut Window, cx: &mut App| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in &selections {
|
||||
context_store
|
||||
.add_selection(buffer.clone(), range.clone(), cx)
|
||||
.detach_and_log_err(cx)
|
||||
context_store.add_selection(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider {
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
context_store: Entity<ContextStore>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
) -> Completion {
|
||||
let new_text = MentionLink::for_rules(&rules);
|
||||
let new_text_len = new_text.len();
|
||||
@@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider {
|
||||
new_text_len,
|
||||
editor.clone(),
|
||||
move |cx| {
|
||||
let prompt_uuid = rules.prompt_id;
|
||||
let prompt_id = PromptId::User { uuid: prompt_uuid };
|
||||
let context_store = context_store.clone();
|
||||
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
|
||||
log::error!("Can't add user rules as prompt store is missing.");
|
||||
return;
|
||||
};
|
||||
let prompt_store = prompt_store.read(cx);
|
||||
let Some(metadata) = prompt_store.metadata(prompt_id) else {
|
||||
return;
|
||||
};
|
||||
let Some(title) = metadata.title else {
|
||||
return;
|
||||
};
|
||||
let text_task = prompt_store.load(prompt_id, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let text = text_task.await?;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(prompt_uuid, title, text, false, cx)
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
let user_prompt_id = rules.prompt_id;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(user_prompt_id, false, cx);
|
||||
});
|
||||
},
|
||||
)),
|
||||
}
|
||||
@@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider {
|
||||
let url_to_fetch = url_to_fetch.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
if context_store.update(cx, |context_store, _| {
|
||||
context_store.includes_url(&url_to_fetch).is_some()
|
||||
context_store.includes_url(&url_to_fetch)
|
||||
})? {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider {
|
||||
move |cx| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
let task = if is_directory {
|
||||
context_store.add_directory(project_path.clone(), false, cx)
|
||||
Task::ready(context_store.add_directory(&project_path, false, cx))
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path.clone(), false, cx)
|
||||
};
|
||||
@@ -720,7 +708,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
|
||||
let thread_store = self.thread_store.clone();
|
||||
let editor = self.editor.clone();
|
||||
let http_client = workspace.read(cx).client().http_client().clone();
|
||||
let http_client = workspace.read(cx).client().http_client();
|
||||
|
||||
let MentionCompletion { mode, argument, .. } = state;
|
||||
let query = argument.unwrap_or_else(|| "".to_string());
|
||||
@@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
);
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
let search_task = search(
|
||||
mode,
|
||||
query,
|
||||
Arc::<AtomicBool>::default(),
|
||||
recent_entries,
|
||||
prompt_store,
|
||||
thread_store.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
@@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
|
||||
symbol,
|
||||
excerpt_id,
|
||||
@@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
workspace.clone(),
|
||||
cx,
|
||||
),
|
||||
|
||||
Match::Thread(ThreadMatch {
|
||||
thread, is_recent, ..
|
||||
}) => {
|
||||
@@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
thread_store,
|
||||
))
|
||||
}
|
||||
Match::Rules(user_rules) => {
|
||||
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
|
||||
Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
thread_store,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Rules(user_rules) => Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
)),
|
||||
|
||||
Match::Fetch(url) => Some(Self::completion_for_fetch(
|
||||
source_range.clone(),
|
||||
url,
|
||||
@@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
context_store.clone(),
|
||||
http_client.clone(),
|
||||
)),
|
||||
|
||||
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
|
||||
entry,
|
||||
excerpt_id,
|
||||
@@ -1048,6 +1045,10 @@ mod tests {
|
||||
fn include_in_nav_history() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
|
||||
"Test".into()
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<()> for AtMentionEditor {}
|
||||
|
||||
@@ -193,7 +193,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
|
||||
return;
|
||||
};
|
||||
|
||||
let http_client = workspace.read(cx).client().http_client().clone();
|
||||
let http_client = workspace.read(cx).client().http_client();
|
||||
let url = self.url.clone();
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let text = cx
|
||||
@@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let added = self.context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store.read(cx).includes_url(&self.url).is_some()
|
||||
context_store.read(cx).includes_url(&self.url)
|
||||
});
|
||||
|
||||
Some(
|
||||
|
||||
@@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
if is_directory {
|
||||
context_store.add_directory(project_path, true, cx)
|
||||
Task::ready(context_store.add_directory(&project_path, true, cx))
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path, true, cx)
|
||||
context_store.add_file_from_path(project_path.clone(), true, cx)
|
||||
}
|
||||
})
|
||||
.ok()
|
||||
@@ -169,7 +169,7 @@ impl PickerDelegate for FileContextPickerDelegate {
|
||||
.inset(true)
|
||||
.toggle_state(selected)
|
||||
.child(render_file_context_entry(
|
||||
ElementId::NamedInteger("file-ctx-picker".into(), ix),
|
||||
ElementId::named_usize("file-ctx-picker", ix),
|
||||
WorktreeId::from_usize(mat.worktree_id),
|
||||
&mat.path,
|
||||
&mat.path_prefix,
|
||||
@@ -325,11 +325,11 @@ pub fn render_file_context_entry(
|
||||
path: path.clone(),
|
||||
};
|
||||
if is_directory {
|
||||
context_store.read(cx).includes_directory(&project_path)
|
||||
} else {
|
||||
context_store
|
||||
.read(cx)
|
||||
.will_include_file_path(&project_path, cx)
|
||||
.path_included_in_directory(&project_path, cx)
|
||||
} else {
|
||||
context_store.read(cx).file_path_included(&project_path, cx)
|
||||
}
|
||||
});
|
||||
|
||||
@@ -357,7 +357,7 @@ pub fn render_file_context_entry(
|
||||
})),
|
||||
)
|
||||
.when_some(added, |el, added| match added {
|
||||
FileInclusion::Direct(_) => el.child(
|
||||
FileInclusion::Direct => el.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_end()
|
||||
@@ -369,9 +369,8 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Added").size(LabelSize::Small)),
|
||||
),
|
||||
FileInclusion::InDirectory(directory_project_path) => {
|
||||
// TODO: Consider using worktree full_path to include worktree name.
|
||||
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
|
||||
FileInclusion::InDirectory { full_path } => {
|
||||
let directory_full_path = full_path.to_string_lossy().into_owned();
|
||||
|
||||
el.child(
|
||||
h_flex()
|
||||
@@ -385,7 +384,7 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Included").size(LabelSize::Small)),
|
||||
)
|
||||
.tooltip(Tooltip::text(format!("in {directory_path}")))
|
||||
.tooltip(Tooltip::text(format!("in {directory_full_path}")))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use prompt_store::{PromptId, UserPromptId};
|
||||
use prompt_store::{PromptId, PromptStore, UserPromptId};
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::{self, ContextStore};
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
pub struct RulesContextPicker {
|
||||
picker: Entity<Picker<RulesContextPickerDelegate>>,
|
||||
@@ -18,13 +17,13 @@ pub struct RulesContextPicker {
|
||||
|
||||
impl RulesContextPicker {
|
||||
pub fn new(
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
|
||||
let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
RulesContextPicker { picker }
|
||||
@@ -50,7 +49,7 @@ pub struct RulesContextEntry {
|
||||
}
|
||||
|
||||
pub struct RulesContextPickerDelegate {
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
matches: Vec<RulesContextEntry>,
|
||||
@@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate {
|
||||
|
||||
impl RulesContextPickerDelegate {
|
||||
pub fn new(
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
) -> Self {
|
||||
RulesContextPickerDelegate {
|
||||
thread_store,
|
||||
prompt_store,
|
||||
context_picker,
|
||||
context_store,
|
||||
matches: Vec::new(),
|
||||
@@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
let Some(thread_store) = self.thread_store.upgrade() else {
|
||||
return Task::ready(());
|
||||
};
|
||||
|
||||
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
|
||||
let search_task = search_rules(
|
||||
query,
|
||||
Arc::new(AtomicBool::default()),
|
||||
&self.prompt_store,
|
||||
cx,
|
||||
);
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let matches = search_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
@@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(thread_store) = self.thread_store.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let prompt_id = entry.prompt_id;
|
||||
|
||||
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
|
||||
thread_store.load_rules(prompt_id, cx)
|
||||
});
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (metadata, text) = load_rules_task.await?;
|
||||
let Some(title) = metadata.title else {
|
||||
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.delegate
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(prompt_id, title, text, true, cx)
|
||||
})
|
||||
.ok();
|
||||
self.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(entry.prompt_id, true, cx)
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
.log_err();
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
@@ -179,11 +159,10 @@ pub fn render_thread_context_entry(
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |ctx_store| {
|
||||
ctx_store
|
||||
let added = context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store
|
||||
.read(cx)
|
||||
.includes_user_rules(&user_rules.prompt_id)
|
||||
.is_some()
|
||||
.includes_user_rules(user_rules.prompt_id)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -218,12 +197,9 @@ pub fn render_thread_context_entry(
|
||||
pub(crate) fn search_rules(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
prompt_store: &Entity<PromptStore>,
|
||||
cx: &mut App,
|
||||
) -> Task<Vec<RulesContextEntry>> {
|
||||
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
|
||||
return Task::ready(vec![]);
|
||||
};
|
||||
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_task
|
||||
|
||||
@@ -10,7 +10,6 @@ use gpui::{
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use project::{DocumentSymbol, Symbol};
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
@@ -172,10 +171,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
|
||||
let mat = &self.matches[ix];
|
||||
|
||||
Some(ListItem::new(ix).inset(true).toggle_state(selected).child(
|
||||
render_symbol_context_entry(
|
||||
ElementId::NamedInteger("symbol-ctx-picker".into(), ix),
|
||||
mat,
|
||||
),
|
||||
render_symbol_context_entry(ElementId::named_usize("symbol-ctx-picker", ix), mat),
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -228,18 +224,16 @@ pub(crate) fn add_symbol(
|
||||
)
|
||||
})?;
|
||||
|
||||
context_store
|
||||
.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
context_store.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -353,38 +347,13 @@ fn compute_symbol_entries(
|
||||
context_store: &ContextStore,
|
||||
cx: &App,
|
||||
) -> Vec<SymbolEntry> {
|
||||
let mut symbol_entries = Vec::with_capacity(symbols.len());
|
||||
for SymbolMatch { symbol, .. } in symbols {
|
||||
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
|
||||
let is_included = if let Some(symbols_for_path) = symbols_for_path {
|
||||
let mut is_included = false;
|
||||
for included_symbol_id in symbols_for_path {
|
||||
if included_symbol_id.name.as_ref() == symbol.name.as_str() {
|
||||
if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let included_symbol_range =
|
||||
included_symbol_id.range.to_point_utf16(&snapshot);
|
||||
|
||||
if included_symbol_range.start == symbol.range.start.0
|
||||
&& included_symbol_range.end == symbol.range.end.0
|
||||
{
|
||||
is_included = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
is_included
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
symbol_entries.push(SymbolEntry {
|
||||
symbols
|
||||
.into_iter()
|
||||
.map(|SymbolMatch { symbol, .. }| SymbolEntry {
|
||||
is_included: context_store.includes_symbol(&symbol, cx),
|
||||
symbol,
|
||||
is_included,
|
||||
})
|
||||
}
|
||||
symbol_entries
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {
|
||||
|
||||
@@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |ctx_store| {
|
||||
ctx_store.read(cx).includes_thread(&thread.id).is_some()
|
||||
ctx_store.read(cx).includes_thread(&thread.id)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -219,7 +219,7 @@ pub(crate) fn search_threads(
|
||||
) -> Task<Vec<ThreadMatch>> {
|
||||
let threads = thread_store
|
||||
.read(cx)
|
||||
.threads()
|
||||
.reverse_chronological_threads()
|
||||
.into_iter()
|
||||
.map(|thread| ThreadContextEntry {
|
||||
id: thread.id,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,9 +12,9 @@ use itertools::Itertools;
|
||||
use language::Buffer;
|
||||
use project::ProjectItem;
|
||||
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context::{ContextId, ContextKind};
|
||||
use crate::context::{AgentContextHandle, ContextKind};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::Thread;
|
||||
@@ -32,6 +32,7 @@ pub struct ContextStrip {
|
||||
focus_handle: FocusHandle,
|
||||
suggest_context_kind: SuggestContextKind,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
focused_index: Option<usize>,
|
||||
children_bounds: Option<Vec<Bounds<Pixels>>>,
|
||||
@@ -73,12 +74,33 @@ impl ContextStrip {
|
||||
focus_handle,
|
||||
suggest_context_kind,
|
||||
workspace,
|
||||
thread_store,
|
||||
_subscriptions: subscriptions,
|
||||
focused_index: None,
|
||||
children_bounds: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().read(cx);
|
||||
let prompt_store = self
|
||||
.thread_store
|
||||
.as_ref()
|
||||
.and_then(|thread_store| thread_store.upgrade())
|
||||
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
|
||||
self.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.flat_map(|context| {
|
||||
AddedContext::new_pending(context.clone(), prompt_store, project, cx)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
|
||||
match self.suggest_context_kind {
|
||||
SuggestContextKind::File => self.suggested_file(cx),
|
||||
@@ -93,22 +115,19 @@ impl ContextStrip {
|
||||
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
|
||||
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
|
||||
let active_buffer = active_buffer_entity.read(cx);
|
||||
|
||||
let project_path = active_buffer.project_path(cx)?;
|
||||
|
||||
if self
|
||||
.context_store
|
||||
.read(cx)
|
||||
.will_include_buffer(active_buffer.remote_id(), &project_path)
|
||||
.file_path_included(&project_path, cx)
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let file_name = active_buffer.file()?.file_name(cx);
|
||||
|
||||
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
|
||||
|
||||
Some(SuggestedContext::File {
|
||||
name: file_name.to_string_lossy().into_owned().into(),
|
||||
buffer: active_buffer_entity.downgrade(),
|
||||
@@ -135,7 +154,6 @@ impl ContextStrip {
|
||||
.context_store
|
||||
.read(cx)
|
||||
.includes_thread(active_thread.id())
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
@@ -272,12 +290,12 @@ impl ContextStrip {
|
||||
best.map(|(index, _, _)| index)
|
||||
}
|
||||
|
||||
fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
|
||||
fn open_context(&mut self, context: &AgentContextHandle, window: &mut Window, cx: &mut App) {
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
|
||||
crate::active_thread::open_context(context, workspace, window, cx);
|
||||
}
|
||||
|
||||
fn remove_focused_context(
|
||||
@@ -287,17 +305,17 @@ impl ContextStrip {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(index) = self.focused_index {
|
||||
let mut is_empty = false;
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let Some(context) = added_contexts.get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.context_store.update(cx, |this, cx| {
|
||||
if let Some(item) = this.context().get(index) {
|
||||
this.remove_context(item.id(), cx);
|
||||
}
|
||||
|
||||
is_empty = this.context().is_empty();
|
||||
this.remove_context(&context.handle, cx);
|
||||
});
|
||||
|
||||
if is_empty {
|
||||
let is_now_empty = added_contexts.len() == 1;
|
||||
if is_now_empty {
|
||||
cx.emit(ContextStripEvent::BlurredEmpty);
|
||||
} else {
|
||||
self.focused_index = Some(index.saturating_sub(1));
|
||||
@@ -306,49 +324,28 @@ impl ContextStrip {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
|
||||
fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
|
||||
// We only suggest one item after the actual context
|
||||
self.focused_index == Some(context.len())
|
||||
self.focused_index == Some(added_contexts.len())
|
||||
}
|
||||
|
||||
fn accept_suggested_context(
|
||||
&mut self,
|
||||
_: &AcceptSuggestedContext,
|
||||
window: &mut Window,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(suggested) = self.suggested_context(cx) {
|
||||
let context_store = self.context_store.read(cx);
|
||||
|
||||
if self.is_suggested_focused(context_store.context()) {
|
||||
self.add_suggested_context(&suggested, window, cx);
|
||||
if self.is_suggested_focused(&self.added_contexts(cx)) {
|
||||
self.add_suggested_context(&suggested, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_suggested_context(
|
||||
&mut self,
|
||||
suggested: &SuggestedContext,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let task = self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.accept_suggested_context(&suggested, cx)
|
||||
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
|
||||
self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_suggested_context(&suggested, cx)
|
||||
});
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
match task.await.notify_async_err(cx) {
|
||||
None => {}
|
||||
Some(()) => {
|
||||
if let Some(this) = this.upgrade() {
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
@@ -361,17 +358,10 @@ impl Focusable for ContextStrip {
|
||||
|
||||
impl Render for ContextStrip {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let context_store = self.context_store.read(cx);
|
||||
let context = context_store.context();
|
||||
let context_picker = self.context_picker.clone();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
let suggested_context = self.suggested_context(cx);
|
||||
|
||||
let added_contexts = context
|
||||
.iter()
|
||||
.map(|c| AddedContext::new(c, cx))
|
||||
.collect::<Vec<_>>();
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let dupe_names = added_contexts
|
||||
.iter()
|
||||
.map(|c| c.name.clone())
|
||||
@@ -380,6 +370,14 @@ impl Render for ContextStrip {
|
||||
.filter(|(a, b)| a == b)
|
||||
.map(|(a, _)| a)
|
||||
.collect::<HashSet<SharedString>>();
|
||||
let no_added_context = added_contexts.is_empty();
|
||||
|
||||
let suggested_context = self.suggested_context(cx).map(|suggested_context| {
|
||||
(
|
||||
suggested_context,
|
||||
self.is_suggested_focused(&added_contexts),
|
||||
)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
.flex_wrap()
|
||||
@@ -436,7 +434,7 @@ impl Render for ContextStrip {
|
||||
})
|
||||
.with_handle(self.context_picker_menu_handle.clone()),
|
||||
)
|
||||
.when(context.is_empty() && suggested_context.is_none(), {
|
||||
.when(no_added_context && suggested_context.is_none(), {
|
||||
|parent| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
@@ -466,16 +464,17 @@ impl Render for ContextStrip {
|
||||
.enumerate()
|
||||
.map(|(i, added_context)| {
|
||||
let name = added_context.name.clone();
|
||||
let id = added_context.id;
|
||||
let context = added_context.handle.clone();
|
||||
ContextPill::added(
|
||||
added_context,
|
||||
dupe_names.contains(&name),
|
||||
self.focused_index == Some(i),
|
||||
Some({
|
||||
let context = context.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
Rc::new(cx.listener(move |_this, _event, _window, cx| {
|
||||
context_store.update(cx, |this, cx| {
|
||||
this.remove_context(id, cx);
|
||||
this.remove_context(&context, cx);
|
||||
});
|
||||
cx.notify();
|
||||
}))
|
||||
@@ -484,7 +483,7 @@ impl Render for ContextStrip {
|
||||
.on_click({
|
||||
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
|
||||
if event.down.click_count > 1 {
|
||||
this.open_context(id, window, cx);
|
||||
this.open_context(&context, window, cx);
|
||||
} else {
|
||||
this.focused_index = Some(i);
|
||||
}
|
||||
@@ -493,22 +492,22 @@ impl Render for ContextStrip {
|
||||
})
|
||||
}),
|
||||
)
|
||||
.when_some(suggested_context, |el, suggested| {
|
||||
.when_some(suggested_context, |el, (suggested, focused)| {
|
||||
el.child(
|
||||
ContextPill::suggested(
|
||||
suggested.name().clone(),
|
||||
suggested.icon_path(),
|
||||
suggested.kind(),
|
||||
self.is_suggested_focused(&context),
|
||||
focused,
|
||||
)
|
||||
.on_click(Rc::new(cx.listener(
|
||||
move |this, _event, window, cx| {
|
||||
this.add_suggested_context(&suggested, window, cx);
|
||||
move |this, _event, _window, cx| {
|
||||
this.add_suggested_context(&suggested, cx);
|
||||
},
|
||||
))),
|
||||
)
|
||||
})
|
||||
.when(!context.is_empty(), {
|
||||
.when(!no_added_context, {
|
||||
move |parent| {
|
||||
parent.child(
|
||||
IconButton::new("remove-all-context", IconName::Eraser)
|
||||
@@ -534,6 +533,7 @@ impl Render for ContextStrip {
|
||||
)
|
||||
}
|
||||
})
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,10 @@ impl HistoryStore {
|
||||
return history_entries;
|
||||
}
|
||||
|
||||
for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
|
||||
for thread in self
|
||||
.thread_store
|
||||
.update(cx, |this, _cx| this.reverse_chronological_threads())
|
||||
{
|
||||
history_entries.push(HistoryEntry::Thread(thread));
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ use project::LspAction;
|
||||
use project::Project;
|
||||
use project::{CodeAction, ProjectTransaction};
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
|
||||
@@ -245,9 +246,13 @@ impl InlineAssistant {
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx))
|
||||
};
|
||||
|
||||
let thread_store = workspace
|
||||
let assistant_panel = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
|
||||
.map(|assistant_panel| assistant_panel.read(cx));
|
||||
let prompt_store = assistant_panel
|
||||
.and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
|
||||
let thread_store =
|
||||
assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
|
||||
|
||||
let handle_assist =
|
||||
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
|
||||
@@ -257,6 +262,7 @@ impl InlineAssistant {
|
||||
&active_editor,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -269,6 +275,7 @@ impl InlineAssistant {
|
||||
&active_terminal,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -323,6 +330,7 @@ impl InlineAssistant {
|
||||
editor: &Entity<Editor>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -437,6 +445,8 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
None,
|
||||
context_store.clone(),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -525,6 +535,7 @@ impl InlineAssistant {
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
focus: bool,
|
||||
workspace: Entity<Workspace>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -543,7 +554,7 @@ impl InlineAssistant {
|
||||
}
|
||||
|
||||
let project = workspace.read(cx).project().downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
|
||||
|
||||
let codegen = cx.new(|cx| {
|
||||
BufferCodegen::new(
|
||||
@@ -551,6 +562,8 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
initial_transaction_id,
|
||||
context_store.clone(),
|
||||
project,
|
||||
prompt_store,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
let editor = self.editor.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
window.spawn(cx, async move |cx| {
|
||||
let workspace = workspace.upgrade().context("workspace was released")?;
|
||||
let editor = editor.upgrade().context("editor was released")?;
|
||||
@@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
})?
|
||||
.context("invalid range")?;
|
||||
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
|
||||
let assist_id = assistant.suggest_assist(
|
||||
&editor,
|
||||
@@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
None,
|
||||
true,
|
||||
workspace,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::assistant_model_selector::AssistantModelSelector;
|
||||
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
|
||||
use crate::buffer_codegen::BufferCodegen;
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
@@ -13,14 +13,14 @@ use editor::{
|
||||
Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer,
|
||||
actions::{MoveDown, MoveUp},
|
||||
};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedPro};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use language_model_selector::{ModelType, ToggleModelSelector};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use parking_lot::Mutex;
|
||||
use settings::Settings;
|
||||
use std::cmp;
|
||||
@@ -132,7 +132,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -931,7 +931,7 @@ impl PromptEditor<BufferCodegen> {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedPro>()
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::assistant_model_selector::ModelType;
|
||||
use crate::context::{AssistantContext, format_context_as_string};
|
||||
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
|
||||
use crate::context::{ContextLoadResult, load_context};
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use crate::ui::AnimatedLabel;
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::HashSet;
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
@@ -11,27 +12,31 @@ use editor::{
|
||||
ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode,
|
||||
EditorStyle, MultiBuffer,
|
||||
};
|
||||
use feature_flags::{FeatureFlagAppExt, NewBillingFeatureFlag};
|
||||
use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt as _, future};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
|
||||
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
|
||||
};
|
||||
use language::{Buffer, Language};
|
||||
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage};
|
||||
use language_model::{ConfiguredModel, LanguageModelRequestMessage, MessageContent};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use multi_buffer;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Disclosure, KeyBinding, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
use zed_llm_client::CompletionMode;
|
||||
|
||||
use crate::assistant_model_selector::AssistantModelSelector;
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
|
||||
use crate::context_store::{ContextStore, refresh_context_store_text};
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::profile_selector::ProfileSelector;
|
||||
use crate::thread::{Thread, TokenUsageRatio};
|
||||
@@ -45,17 +50,18 @@ pub struct MessageEditor {
|
||||
thread: Entity<Thread>,
|
||||
incompatible_tools_state: Entity<IncompatibleToolsState>,
|
||||
editor: Entity<Editor>,
|
||||
#[allow(dead_code)]
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
model_selector: Entity<AssistantModelSelector>,
|
||||
last_loaded_context: Option<ContextLoadResult>,
|
||||
load_context_task: Option<Shared<Task<()>>>,
|
||||
profile_selector: Entity<ProfileSelector>,
|
||||
edits_expanded: bool,
|
||||
editor_is_expanded: bool,
|
||||
waiting_for_summaries_to_send: bool,
|
||||
last_estimated_token_count: Option<usize>,
|
||||
update_token_count_task: Option<Task<anyhow::Result<()>>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
@@ -68,6 +74,7 @@ impl MessageEditor {
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
thread: Entity<Thread>,
|
||||
window: &mut Window,
|
||||
@@ -135,16 +142,26 @@ impl MessageEditor {
|
||||
let subscriptions = vec![
|
||||
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
|
||||
cx.subscribe(&editor, |this, _, event, cx| match event {
|
||||
EditorEvent::BufferEdited => {
|
||||
this.message_or_context_changed(true, cx);
|
||||
}
|
||||
EditorEvent::BufferEdited => this.handle_message_changed(cx),
|
||||
_ => {}
|
||||
}),
|
||||
cx.observe(&context_store, |this, _, cx| {
|
||||
this.message_or_context_changed(false, cx);
|
||||
// When context changes, reload it for token counting.
|
||||
let _ = this.reload_context(cx);
|
||||
}),
|
||||
];
|
||||
|
||||
let model_selector = cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
fs.clone(),
|
||||
model_selector_menu_handle,
|
||||
editor.focus_handle(cx),
|
||||
ModelType::Default(thread.clone()),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
editor: editor.clone(),
|
||||
project: thread.read(cx).project().clone(),
|
||||
@@ -152,21 +169,14 @@ impl MessageEditor {
|
||||
incompatible_tools_state: incompatible_tools.clone(),
|
||||
workspace,
|
||||
context_store,
|
||||
prompt_store,
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
model_selector: cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
fs.clone(),
|
||||
model_selector_menu_handle,
|
||||
editor.focus_handle(cx),
|
||||
ModelType::Default,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
load_context_task: None,
|
||||
last_loaded_context: None,
|
||||
model_selector,
|
||||
edits_expanded: false,
|
||||
editor_is_expanded: false,
|
||||
waiting_for_summaries_to_send: false,
|
||||
profile_selector: cx
|
||||
.new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)),
|
||||
last_estimated_token_count: None,
|
||||
@@ -175,6 +185,10 @@ impl MessageEditor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_store(&self) -> &Entity<ContextStore> {
|
||||
&self.context_store
|
||||
}
|
||||
|
||||
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.notify();
|
||||
}
|
||||
@@ -214,6 +228,7 @@ impl MessageEditor {
|
||||
) {
|
||||
self.context_picker_menu_handle.toggle(window, cx);
|
||||
}
|
||||
|
||||
pub fn remove_all_context(
|
||||
&mut self,
|
||||
_: &RemoveAllContext,
|
||||
@@ -229,6 +244,10 @@ impl MessageEditor {
|
||||
return;
|
||||
}
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_editing(cx);
|
||||
});
|
||||
|
||||
if self.thread.read(cx).is_generating() {
|
||||
self.stop_current_and_send_new_message(window, cx);
|
||||
return;
|
||||
@@ -244,15 +263,11 @@ impl MessageEditor {
|
||||
self.editor.read(cx).text(cx).trim().is_empty()
|
||||
}
|
||||
|
||||
fn is_model_selected(&self, cx: &App) -> bool {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.is_some()
|
||||
}
|
||||
|
||||
fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
|
||||
let Some(ConfiguredModel { model, provider }) = self
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -270,68 +285,22 @@ impl MessageEditor {
|
||||
self.last_estimated_token_count.take();
|
||||
cx.emit(MessageEditorEvent::EstimatedTokenCount);
|
||||
|
||||
let refresh_task =
|
||||
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
|
||||
let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
|
||||
|
||||
let thread = self.thread.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
|
||||
let context_task = self.reload_context(cx);
|
||||
let window_handle = window.window_handle();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let checkpoint = checkpoint.await.ok();
|
||||
refresh_task.await;
|
||||
wait_for_images.await;
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
|
||||
let loaded_context = loaded_context.unwrap_or_default();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let context = context_store.read(cx).context().clone();
|
||||
thread.insert_user_message(user_message, context, checkpoint, cx);
|
||||
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
let excerpt_ids = context_store
|
||||
.context()
|
||||
.iter()
|
||||
.filter(|ctx| {
|
||||
matches!(
|
||||
ctx,
|
||||
AssistantContext::Selection(_) | AssistantContext::Image(_)
|
||||
)
|
||||
})
|
||||
.map(|ctx| ctx.id())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for id in excerpt_ids {
|
||||
context_store.remove_context(id, cx);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(wait_for_summaries) = context_store
|
||||
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
|
||||
.log_err()
|
||||
{
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = true;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
|
||||
wait_for_summaries.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = false;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
// Send to model after summaries are done
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
@@ -343,6 +312,10 @@ impl MessageEditor {
|
||||
}
|
||||
|
||||
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_editing(cx);
|
||||
});
|
||||
|
||||
let cancelled = self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_last_completion(Some(window.window_handle()), cx)
|
||||
});
|
||||
@@ -402,7 +375,7 @@ impl MessageEditor {
|
||||
|
||||
self.context_store.update(cx, |store, cx| {
|
||||
for image in images {
|
||||
store.add_image(Arc::new(image), cx);
|
||||
store.add_image_instance(Arc::new(image), cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -426,6 +399,36 @@ impl MessageEditor {
|
||||
}
|
||||
}
|
||||
|
||||
fn render_max_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||
if !cx.has_flag::<NewBillingFeatureFlag>() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let thread = self.thread.read(cx);
|
||||
let model = thread.configured_model();
|
||||
if !model?.model.supports_max_mode() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let active_completion_mode = thread.completion_mode();
|
||||
|
||||
Some(
|
||||
IconButton::new("max-mode", IconName::SquarePlus)
|
||||
.icon_size(IconSize::Small)
|
||||
.toggle_state(active_completion_mode == Some(CompletionMode::Max))
|
||||
.on_click(cx.listener(move |this, _event, _window, cx| {
|
||||
this.thread.update(cx, |thread, _cx| {
|
||||
thread.set_completion_mode(match active_completion_mode {
|
||||
Some(CompletionMode::Max) => Some(CompletionMode::Normal),
|
||||
Some(CompletionMode::Normal) | None => Some(CompletionMode::Max),
|
||||
});
|
||||
});
|
||||
}))
|
||||
.tooltip(Tooltip::text("Max Mode"))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_editor(
|
||||
&self,
|
||||
font_size: Rems,
|
||||
@@ -434,24 +437,21 @@ impl MessageEditor {
|
||||
cx: &mut Context<Self>,
|
||||
) -> Div {
|
||||
let thread = self.thread.read(cx);
|
||||
let model = thread.configured_model();
|
||||
|
||||
let editor_bg_color = cx.theme().colors().editor_background;
|
||||
let is_generating = thread.is_generating();
|
||||
let focus_handle = self.editor.focus_handle(cx);
|
||||
|
||||
let is_model_selected = self.is_model_selected(cx);
|
||||
let is_model_selected = model.is_some();
|
||||
let is_editor_empty = self.is_editor_empty(cx);
|
||||
|
||||
let model = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.model.clone());
|
||||
|
||||
let incompatible_tools = model
|
||||
.as_ref()
|
||||
.map(|model| {
|
||||
self.incompatible_tools_state.update(cx, |state, cx| {
|
||||
state
|
||||
.incompatible_tools(model, cx)
|
||||
.incompatible_tools(&model.model, cx)
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
@@ -495,32 +495,34 @@ impl MessageEditor {
|
||||
.items_start()
|
||||
.justify_between()
|
||||
.child(self.context_strip.clone())
|
||||
.child(
|
||||
IconButton::new("toggle-height", expand_icon)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
let expand_label = if is_editor_expanded {
|
||||
"Minimize Message Editor".to_string()
|
||||
} else {
|
||||
"Expand Message Editor".to_string()
|
||||
};
|
||||
.when(focus_handle.is_focused(window), |this| {
|
||||
this.child(
|
||||
IconButton::new("toggle-height", expand_icon)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
let expand_label = if is_editor_expanded {
|
||||
"Minimize Message Editor".to_string()
|
||||
} else {
|
||||
"Expand Message Editor".to_string()
|
||||
};
|
||||
|
||||
Tooltip::for_action_in(
|
||||
expand_label,
|
||||
&ExpandMessageEditor,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click(cx.listener(|_, _, window, cx| {
|
||||
window.dispatch_action(Box::new(ExpandMessageEditor), cx);
|
||||
})),
|
||||
),
|
||||
Tooltip::for_action_in(
|
||||
expand_label,
|
||||
&ExpandMessageEditor,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click(cx.listener(|_, _, window, cx| {
|
||||
window.dispatch_action(Box::new(ExpandMessageEditor), cx);
|
||||
})),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
@@ -586,6 +588,7 @@ impl MessageEditor {
|
||||
}),
|
||||
)
|
||||
})
|
||||
.children(self.render_max_mode_toggle(cx))
|
||||
.child(self.model_selector.clone())
|
||||
.map({
|
||||
let focus_handle = focus_handle.clone();
|
||||
@@ -638,31 +641,31 @@ impl MessageEditor {
|
||||
})
|
||||
.when(!is_editor_empty, |parent| {
|
||||
parent.child(
|
||||
IconButton::new("send-message", IconName::Send)
|
||||
.icon_color(Color::Accent)
|
||||
.style(ButtonStyle::Filled)
|
||||
.disabled(
|
||||
!is_model_selected
|
||||
|| self
|
||||
.waiting_for_summaries_to_send,
|
||||
)
|
||||
.on_click({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |_event, window, cx| {
|
||||
focus_handle.dispatch_action(
|
||||
&Chat, window, cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::for_action(
|
||||
"Stop and Send New Message",
|
||||
&Chat,
|
||||
window,
|
||||
cx,
|
||||
IconButton::new(
|
||||
"send-message",
|
||||
IconName::Send,
|
||||
)
|
||||
}),
|
||||
)
|
||||
.icon_color(Color::Accent)
|
||||
.style(ButtonStyle::Filled)
|
||||
.disabled(!is_model_selected)
|
||||
.on_click({
|
||||
let focus_handle =
|
||||
focus_handle.clone();
|
||||
move |_event, window, cx| {
|
||||
focus_handle.dispatch_action(
|
||||
&Chat, window, cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::for_action(
|
||||
"Stop and Send New Message",
|
||||
&Chat,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
)
|
||||
})
|
||||
} else {
|
||||
parent.child(
|
||||
@@ -670,10 +673,7 @@ impl MessageEditor {
|
||||
.icon_color(Color::Accent)
|
||||
.style(ButtonStyle::Filled)
|
||||
.disabled(
|
||||
is_editor_empty
|
||||
|| !is_model_selected
|
||||
|| self
|
||||
.waiting_for_summaries_to_send,
|
||||
is_editor_empty || !is_model_selected,
|
||||
)
|
||||
.on_click({
|
||||
let focus_handle = focus_handle.clone();
|
||||
@@ -724,9 +724,12 @@ impl MessageEditor {
|
||||
let border_color = cx.theme().colors().border;
|
||||
let active_color = cx.theme().colors().element_selected;
|
||||
let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3));
|
||||
|
||||
let is_edit_changes_expanded = self.edits_expanded;
|
||||
let is_generating = self.thread.read(cx).is_generating();
|
||||
|
||||
v_flex()
|
||||
.mt_1()
|
||||
.mx_2()
|
||||
.bg(bg_edit_files_disclosure)
|
||||
.border_1()
|
||||
@@ -761,25 +764,44 @@ impl MessageEditor {
|
||||
cx.notify();
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Label::new("Edits")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Label::new("•").size(LabelSize::XSmall).color(Color::Muted))
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"{} {}",
|
||||
changed_buffers.len(),
|
||||
if changed_buffers.len() == 1 {
|
||||
"file"
|
||||
} else {
|
||||
"files"
|
||||
}
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
.map(|this| {
|
||||
if is_generating {
|
||||
this.child(
|
||||
AnimatedLabel::new(format!(
|
||||
"Editing {} {}",
|
||||
changed_buffers.len(),
|
||||
if changed_buffers.len() == 1 {
|
||||
"file"
|
||||
} else {
|
||||
"files"
|
||||
}
|
||||
))
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
Label::new("Edits")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new("•").size(LabelSize::XSmall).color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"{} {}",
|
||||
changed_buffers.len(),
|
||||
if changed_buffers.len() == 1 {
|
||||
"file"
|
||||
} else {
|
||||
"files"
|
||||
}
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
}
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Button::new("review", "Review Changes")
|
||||
@@ -869,7 +891,7 @@ impl MessageEditor {
|
||||
.justify_between()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.hover(|style| style.bg(hover_color))
|
||||
.when(index + 1 < changed_buffers.len(), |parent| {
|
||||
.when(index < changed_buffers.len() - 1, |parent| {
|
||||
parent.border_color(border_color).border_b_1()
|
||||
})
|
||||
.child(
|
||||
@@ -885,9 +907,9 @@ impl MessageEditor {
|
||||
.gap_0p5()
|
||||
.children(name_label)
|
||||
.children(parent_label),
|
||||
) // TODO: show lines changed
|
||||
.child(Label::new("+").color(Color::Created))
|
||||
.child(Label::new("-").color(Color::Deleted)),
|
||||
), // TODO: Implement line diff
|
||||
// .child(Label::new("+").color(Color::Created))
|
||||
// .child(Label::new("-").color(Color::Deleted)),
|
||||
)
|
||||
.child(
|
||||
div().visible_on_hover("edited-code").child(
|
||||
@@ -1015,18 +1037,49 @@ impl MessageEditor {
|
||||
self.update_token_count_task.is_some()
|
||||
}
|
||||
|
||||
fn reload_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
|
||||
let load_task = cx.spawn(async move |this, cx| {
|
||||
let Ok(load_task) = this.update(cx, |this, cx| {
|
||||
let new_context = this.context_store.read_with(cx, |context_store, cx| {
|
||||
context_store.new_context_for_thread(this.thread.read(cx))
|
||||
});
|
||||
load_context(new_context, &this.project, &this.prompt_store, cx)
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
let result = load_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
this.last_loaded_context = Some(result);
|
||||
this.load_context_task = None;
|
||||
this.message_or_context_changed(false, cx);
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
// Replace existing load task, if any, causing it to be cancelled.
|
||||
let load_task = load_task.shared();
|
||||
self.load_context_task = Some(load_task.clone());
|
||||
cx.spawn(async move |this, cx| {
|
||||
load_task.await;
|
||||
this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
|
||||
self.message_or_context_changed(true, cx);
|
||||
}
|
||||
|
||||
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Changed);
|
||||
self.update_token_count_task.take();
|
||||
|
||||
let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
let Some(model) = self.thread.read(cx).configured_model() else {
|
||||
self.last_estimated_token_count.take();
|
||||
return;
|
||||
};
|
||||
|
||||
let context_store = self.context_store.clone();
|
||||
let editor = self.editor.clone();
|
||||
let thread = self.thread.clone();
|
||||
|
||||
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
|
||||
if debounce {
|
||||
@@ -1035,33 +1088,46 @@ impl MessageEditor {
|
||||
.await;
|
||||
}
|
||||
|
||||
let token_count = if let Some(task) = cx.update(|cx| {
|
||||
let context = context_store.read(cx).context().iter();
|
||||
let new_context = thread.read(cx).filter_new_context(context);
|
||||
let context_text =
|
||||
format_context_as_string(new_context, cx).unwrap_or(String::new());
|
||||
let token_count = if let Some(task) = this.update(cx, |this, cx| {
|
||||
let loaded_context = this
|
||||
.last_loaded_context
|
||||
.as_ref()
|
||||
.map(|context_load_result| &context_load_result.loaded_context);
|
||||
let message_text = editor.read(cx).text(cx);
|
||||
|
||||
let content = context_text + &message_text;
|
||||
|
||||
if content.is_empty() {
|
||||
if message_text.is_empty()
|
||||
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
if let Some(loaded_context) = loaded_context {
|
||||
loaded_context.add_to_request_message(&mut request_message);
|
||||
}
|
||||
|
||||
if !message_text.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message_text));
|
||||
}
|
||||
|
||||
let request = language_model::LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![content.into()],
|
||||
cache: false,
|
||||
}],
|
||||
mode: None,
|
||||
messages: vec![request_message],
|
||||
tools: vec![],
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
};
|
||||
|
||||
Some(default_model.model.count_tokens(request, cx))
|
||||
Some(model.model.count_tokens(request, cx))
|
||||
})? {
|
||||
task.await?
|
||||
} else {
|
||||
@@ -1093,8 +1159,11 @@ impl Focusable for MessageEditor {
|
||||
impl Render for MessageEditor {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let thread = self.thread.read(cx);
|
||||
let total_token_usage = thread.total_token_usage(cx);
|
||||
let token_usage_ratio = total_token_usage.ratio();
|
||||
let token_usage_ratio = thread
|
||||
.total_token_usage()
|
||||
.map_or(TokenUsageRatio::Normal, |total_token_usage| {
|
||||
total_token_usage.ratio()
|
||||
});
|
||||
|
||||
let action_log = self.thread.read(cx).action_log();
|
||||
let changed_buffers = action_log.read(cx).changed_buffers(cx);
|
||||
@@ -1104,41 +1173,6 @@ impl Render for MessageEditor {
|
||||
|
||||
v_flex()
|
||||
.size_full()
|
||||
.when(self.waiting_for_summaries_to_send, |parent| {
|
||||
parent.child(
|
||||
h_flex().py_3().w_full().justify_center().child(
|
||||
h_flex()
|
||||
.flex_none()
|
||||
.px_2()
|
||||
.py_2()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.rounded_lg()
|
||||
.shadow_md()
|
||||
.gap_1()
|
||||
.child(
|
||||
Icon::new(IconName::ArrowCircle)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted)
|
||||
.with_animation(
|
||||
"arrow-circle",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|icon, delta| {
|
||||
icon.transform(gpui::Transformation::rotate(
|
||||
gpui::percentage(delta),
|
||||
))
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
Label::new("Summarizing context…")
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
),
|
||||
)
|
||||
})
|
||||
.when(changed_buffers.len() > 0, |parent| {
|
||||
parent.child(self.render_changed_buffers(&changed_buffers, window, cx))
|
||||
})
|
||||
|
||||
@@ -32,7 +32,7 @@ impl TerminalCodegen {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
|
||||
pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
@@ -45,6 +45,7 @@ impl TerminalCodegen {
|
||||
self.status = CodegenStatus::Pending;
|
||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||
self.generation = cx.spawn(async move |this, cx| {
|
||||
let prompt = prompt_task.await;
|
||||
let model_telemetry_id = model.telemetry_id();
|
||||
let model_provider_id = model.provider_id();
|
||||
let response = model.stream_completion_text(prompt, &cx).await;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context::load_context;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::inline_prompt_editor::{
|
||||
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
|
||||
@@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
|
||||
use collections::{HashMap, VecDeque};
|
||||
use editor::{MultiBuffer, actions::SelectAll};
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
|
||||
use language::Buffer;
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
Role, report_assistant_event,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::{PromptBuilder, PromptStore};
|
||||
use std::sync::Arc;
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::TerminalView;
|
||||
@@ -69,6 +69,7 @@ impl TerminalInlineAssistant {
|
||||
terminal_view: &Entity<TerminalView>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -109,6 +110,7 @@ impl TerminalInlineAssistant {
|
||||
prompt_editor,
|
||||
workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -196,11 +198,11 @@ impl TerminalInlineAssistant {
|
||||
.log_err();
|
||||
|
||||
let codegen = assist.codegen.clone();
|
||||
let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
return;
|
||||
};
|
||||
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
|
||||
}
|
||||
|
||||
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
|
||||
@@ -217,7 +219,7 @@ impl TerminalInlineAssistant {
|
||||
&self,
|
||||
assist_id: TerminalInlineAssistId,
|
||||
cx: &mut App,
|
||||
) -> Result<LanguageModelRequest> {
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
let assist = self.assists.get(&assist_id).context("invalid assist")?;
|
||||
|
||||
let shell = std::env::var("SHELL").ok();
|
||||
@@ -246,28 +248,41 @@ impl TerminalInlineAssistant {
|
||||
&latest_output,
|
||||
)?;
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
let contexts = assist
|
||||
.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let context_load_task = assist.workspace.update(cx, |workspace, cx| {
|
||||
let project = workspace.project();
|
||||
load_context(contexts, project, &assist.prompt_store, cx)
|
||||
})?;
|
||||
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
assist.context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
Ok(cx.background_spawn(async move {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
context_load_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
})
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
mode: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn finish_assist(
|
||||
@@ -380,6 +395,7 @@ struct TerminalInlineAssist {
|
||||
codegen: Entity<TerminalCodegen>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -390,6 +406,7 @@ impl TerminalInlineAssist {
|
||||
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
@@ -400,6 +417,7 @@ impl TerminalInlineAssist {
|
||||
codegen: codegen.clone(),
|
||||
workspace: workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
_subscriptions: vec![
|
||||
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
|
||||
TerminalInlineAssistant::update_global(cx, |this, cx| {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,6 @@ use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
|
||||
use fs::Fs;
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use futures::future::{self, BoxFuture, Shared};
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
@@ -22,10 +21,10 @@ use gpui::{
|
||||
use heed::Database;
|
||||
use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::{Project, Worktree};
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
|
||||
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
|
||||
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
|
||||
UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
@@ -83,11 +82,10 @@ impl ThreadStore {
|
||||
project: Entity<Project>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
let (thread_store, ready_rx) = cx.update(|cx| {
|
||||
let mut option_ready_rx = None;
|
||||
let thread_store = cx.new(|cx| {
|
||||
@@ -208,15 +206,15 @@ impl ThreadStore {
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<()> {
|
||||
let project = self.project.read(cx);
|
||||
let worktree_tasks = project
|
||||
let worktrees = self
|
||||
.project
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.collect::<Vec<_>>();
|
||||
let worktree_tasks = worktrees
|
||||
.into_iter()
|
||||
.map(|worktree| {
|
||||
Self::load_worktree_info_for_system_prompt(
|
||||
project.fs().clone(),
|
||||
worktree.read(cx),
|
||||
cx,
|
||||
)
|
||||
Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let default_user_rules_task = match prompt_store {
|
||||
@@ -277,13 +275,13 @@ impl ThreadStore {
|
||||
}
|
||||
|
||||
fn load_worktree_info_for_system_prompt(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
worktree: Entity<Worktree>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
|
||||
let root_name = worktree.root_name().into();
|
||||
let root_name = worktree.read(cx).root_name().into();
|
||||
|
||||
let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
|
||||
let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
|
||||
let Some(rules_task) = rules_task else {
|
||||
return Task::ready((
|
||||
WorktreeContext {
|
||||
@@ -313,33 +311,44 @@ impl ThreadStore {
|
||||
}
|
||||
|
||||
fn load_worktree_rules_file(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree: &Worktree,
|
||||
cx: &App,
|
||||
worktree: Entity<Worktree>,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Option<Task<Result<RulesFileContext>>> {
|
||||
let worktree_ref = worktree.read(cx);
|
||||
let worktree_id = worktree_ref.id();
|
||||
let selected_rules_file = RULES_FILE_NAMES
|
||||
.into_iter()
|
||||
.filter_map(|name| {
|
||||
worktree
|
||||
worktree_ref
|
||||
.entry_for_path(name)
|
||||
.filter(|entry| entry.is_file())
|
||||
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
|
||||
.map(|entry| entry.path.clone())
|
||||
})
|
||||
.next();
|
||||
|
||||
// Note that Cline supports `.clinerules` being a directory, but that is not currently
|
||||
// supported. This doesn't seem to occur often in GitHub repositories.
|
||||
selected_rules_file.map(|(path_in_worktree, abs_path)| {
|
||||
let fs = fs.clone();
|
||||
selected_rules_file.map(|path_in_worktree| {
|
||||
let project_path = ProjectPath {
|
||||
worktree_id,
|
||||
path: path_in_worktree.clone(),
|
||||
};
|
||||
let buffer_task =
|
||||
project.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
||||
let rope_task = cx.spawn(async move |cx| {
|
||||
buffer_task.await?.read_with(cx, |buffer, cx| {
|
||||
let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
|
||||
anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
|
||||
})?
|
||||
});
|
||||
// Build a string from the rope on a background thread.
|
||||
cx.background_spawn(async move {
|
||||
let abs_path = abs_path?;
|
||||
let text = fs.load(&abs_path).await.with_context(|| {
|
||||
format!("Failed to load assistant rules file {:?}", abs_path)
|
||||
})?;
|
||||
let (project_entry_id, rope) = rope_task.await?;
|
||||
anyhow::Ok(RulesFileContext {
|
||||
path_in_worktree,
|
||||
abs_path: abs_path.into(),
|
||||
text: text.trim().to_string(),
|
||||
text: rope.to_string().trim().to_string(),
|
||||
project_entry_id: project_entry_id.to_usize(),
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -349,25 +358,8 @@ impl ThreadStore {
|
||||
self.context_server_manager.clone()
|
||||
}
|
||||
|
||||
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
|
||||
self.prompt_store.clone()
|
||||
}
|
||||
|
||||
pub fn load_rules(
|
||||
&self,
|
||||
prompt_id: UserPromptId,
|
||||
cx: &App,
|
||||
) -> Task<Result<(PromptMetadata, String)>> {
|
||||
let prompt_id = PromptId::User { uuid: prompt_id };
|
||||
let Some(prompt_store) = self.prompt_store.as_ref() else {
|
||||
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
|
||||
};
|
||||
let prompt_store = prompt_store.read(cx);
|
||||
let Some(metadata) = prompt_store.metadata(prompt_id) else {
|
||||
return Task::ready(Err(anyhow!("User rules not found in library.")));
|
||||
};
|
||||
let text_task = prompt_store.load(prompt_id, cx);
|
||||
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
|
||||
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> Entity<ToolWorkingSet> {
|
||||
@@ -379,16 +371,12 @@ impl ThreadStore {
|
||||
self.threads.len()
|
||||
}
|
||||
|
||||
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
|
||||
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
|
||||
threads
|
||||
}
|
||||
|
||||
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
|
||||
self.threads().into_iter().take(limit).collect()
|
||||
}
|
||||
|
||||
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
|
||||
cx.new(|cx| {
|
||||
Thread::new(
|
||||
@@ -516,6 +504,22 @@ impl ThreadStore {
|
||||
);
|
||||
});
|
||||
}
|
||||
// Enable all the tools from all context servers, but disable the ones that are explicitly disabled
|
||||
for (context_server_id, preset) in &profile.context_servers {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.disable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.clone().into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
for (context_server_id, preset) in &profile.context_servers {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
@@ -636,15 +640,28 @@ pub struct SerializedThread {
|
||||
pub detailed_summary_state: DetailedSummaryState,
|
||||
#[serde(default)]
|
||||
pub exceeded_window_error: Option<ExceededWindowError>,
|
||||
#[serde(default)]
|
||||
pub model: Option<SerializedLanguageModel>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SerializedLanguageModel {
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl SerializedThread {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
pub const VERSION: &'static str = "0.2.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
SerializedThreadV0_1_0::VERSION => {
|
||||
let saved_thread =
|
||||
serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
|
||||
Ok(saved_thread.upgrade())
|
||||
}
|
||||
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
|
||||
saved_thread_json,
|
||||
)?),
|
||||
@@ -666,6 +683,38 @@ impl SerializedThread {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SerializedThreadV0_1_0(
|
||||
// The structure did not change, so we are reusing the latest SerializedThread.
|
||||
// When making the next version, make sure this points to SerializedThreadV0_2_0
|
||||
SerializedThread,
|
||||
);
|
||||
|
||||
impl SerializedThreadV0_1_0 {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
|
||||
pub fn upgrade(self) -> SerializedThread {
|
||||
debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
|
||||
|
||||
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
|
||||
|
||||
for message in self.0.messages {
|
||||
if message.role == Role::User && !message.tool_results.is_empty() {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
|
||||
last_message.tool_results = message.tool_results;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
SerializedThread { messages, ..self.0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SerializedMessage {
|
||||
pub id: MessageId,
|
||||
@@ -733,6 +782,7 @@ impl LegacySerializedThread {
|
||||
request_token_usage: Vec::new(),
|
||||
detailed_summary_state: DetailedSummaryState::default(),
|
||||
exceeded_window_error: None,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use futures::FutureExt as _;
|
||||
use futures::future::Shared;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
|
||||
ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
|
||||
};
|
||||
use ui::IconName;
|
||||
@@ -30,7 +30,6 @@ pub struct ToolUse {
|
||||
pub struct ToolUseState {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||
@@ -42,7 +41,6 @@ impl ToolUseState {
|
||||
Self {
|
||||
tools,
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
tool_result_cards: HashMap::default(),
|
||||
@@ -56,7 +54,6 @@ impl ToolUseState {
|
||||
pub fn from_serialized_messages(
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
messages: &[SerializedMessage],
|
||||
mut filter_by_tool_name: impl FnMut(&str) -> bool,
|
||||
) -> Self {
|
||||
let mut this = Self::new(tools);
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
@@ -68,7 +65,6 @@ impl ToolUseState {
|
||||
let tool_uses = message
|
||||
.tool_uses
|
||||
.iter()
|
||||
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
|
||||
.map(|tool_use| LanguageModelToolUse {
|
||||
id: tool_use.id.clone(),
|
||||
name: tool_use.name.clone().into(),
|
||||
@@ -86,14 +82,6 @@ impl ToolUseState {
|
||||
|
||||
this.tool_uses_by_assistant_message
|
||||
.insert(message.id, tool_uses);
|
||||
}
|
||||
}
|
||||
Role::User => {
|
||||
if !message.tool_results.is_empty() {
|
||||
let tool_uses_by_user_message = this
|
||||
.tool_uses_by_user_message
|
||||
.entry(message.id)
|
||||
.or_default();
|
||||
|
||||
for tool_result in &message.tool_results {
|
||||
let tool_use_id = tool_result.tool_use_id.clone();
|
||||
@@ -102,11 +90,6 @@ impl ToolUseState {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !(filter_by_tool_name)(tool_use.as_ref()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tool_uses_by_user_message.push(tool_use_id.clone());
|
||||
this.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
@@ -119,7 +102,7 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::System => {}
|
||||
Role::System | Role::User => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,20 +212,26 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||
let empty = Vec::new();
|
||||
pub fn tool_results_for_message(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Vec<&LanguageModelToolResult> {
|
||||
let Some(tool_uses) = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
.unwrap_or(&empty)
|
||||
tool_uses
|
||||
.iter()
|
||||
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
|
||||
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
.map_or(false, |results| !results.is_empty())
|
||||
}
|
||||
|
||||
@@ -294,14 +283,6 @@ impl ToolUseState {
|
||||
self.tool_use_metadata_by_id
|
||||
.insert(tool_use.id.clone(), metadata);
|
||||
|
||||
// The tool use is being requested by the Assistant, so we want to
|
||||
// attach the tool results to the next user message.
|
||||
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
|
||||
self.tool_uses_by_user_message
|
||||
.entry(next_user_message_id)
|
||||
.or_default()
|
||||
.push(tool_use.id.clone());
|
||||
|
||||
PendingToolUseStatus::Idle
|
||||
} else {
|
||||
PendingToolUseStatus::InputStillStreaming
|
||||
@@ -372,7 +353,7 @@ impl ToolUseState {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
output: Result<String>,
|
||||
cx: &App,
|
||||
configured_model: Option<&ConfiguredModel>,
|
||||
) -> Option<PendingToolUse> {
|
||||
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
|
||||
|
||||
@@ -392,13 +373,10 @@ impl ToolUseState {
|
||||
|
||||
match output {
|
||||
Ok(tool_result) => {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
|
||||
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
|
||||
|
||||
// Protect from clearly large output
|
||||
let tool_output_limit = model_registry
|
||||
.default_model()
|
||||
let tool_output_limit = configured_model
|
||||
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
@@ -450,7 +428,6 @@ impl ToolUseState {
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
dbg!(&self.tool_uses_by_assistant_message, &message_id);
|
||||
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
|
||||
for tool_use in tool_uses {
|
||||
if self.tool_results.contains_key(&tool_use.id) {
|
||||
@@ -468,32 +445,49 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attach_tool_results(
|
||||
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.contains_key(&assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_results_message(
|
||||
&self,
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
dbg!(&self.tool_uses_by_user_message, &message_id);
|
||||
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
|
||||
for tool_use_id in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
||||
request_message.content.push(MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
assistant_message_id: MessageId,
|
||||
) -> Option<LanguageModelRequestMessage> {
|
||||
let tool_uses = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)?;
|
||||
|
||||
if tool_uses.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
for tool_use in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
));
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Some(request_message)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,23 @@
|
||||
use std::sync::Arc;
|
||||
use std::{rc::Rc, time::Duration};
|
||||
use std::{ops::Range, path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
|
||||
use file_icons::FileIcons;
|
||||
use futures::FutureExt;
|
||||
use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
|
||||
use gpui::{ClickEvent, Task};
|
||||
use futures::FutureExt as _;
|
||||
use gpui::{
|
||||
Animation, AnimationExt as _, AnyView, ClickEvent, Entity, Image, MouseButton, Task,
|
||||
pulsating_between,
|
||||
};
|
||||
use language_model::LanguageModelImage;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Point;
|
||||
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
|
||||
|
||||
use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
|
||||
use crate::context::{
|
||||
AgentContext, AgentContextHandle, ContextId, ContextKind, DirectoryContext,
|
||||
DirectoryContextHandle, FetchedUrlContext, FileContext, FileContextHandle, ImageContext,
|
||||
ImageStatus, RulesContext, RulesContextHandle, SelectionContext, SelectionContextHandle,
|
||||
SymbolContext, SymbolContextHandle, ThreadContext, ThreadContextHandle,
|
||||
};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub enum ContextPill {
|
||||
@@ -73,9 +82,7 @@ impl ContextPill {
|
||||
|
||||
pub fn id(&self) -> ElementId {
|
||||
match self {
|
||||
Self::Added { context, .. } => {
|
||||
ElementId::NamedInteger("context-pill".into(), context.id.0)
|
||||
}
|
||||
Self::Added { context, .. } => context.handle.element_id("context-pill".into()),
|
||||
Self::Suggested { .. } => "suggested-context-pill".into(),
|
||||
}
|
||||
}
|
||||
@@ -168,16 +175,11 @@ impl RenderOnce for ContextPill {
|
||||
.map(|element| match &context.status {
|
||||
ContextStatus::Ready => element
|
||||
.when_some(
|
||||
context.render_preview.as_ref(),
|
||||
|element, render_preview| {
|
||||
element.hoverable_tooltip({
|
||||
let render_preview = render_preview.clone();
|
||||
move |_, cx| {
|
||||
cx.new(|_| ContextPillPreview {
|
||||
render_preview: render_preview.clone(),
|
||||
})
|
||||
.into()
|
||||
}
|
||||
context.render_hover.as_ref(),
|
||||
|element, render_hover| {
|
||||
let render_hover = render_hover.clone();
|
||||
element.hoverable_tooltip(move |window, cx| {
|
||||
render_hover(window, cx)
|
||||
})
|
||||
},
|
||||
)
|
||||
@@ -199,14 +201,17 @@ impl RenderOnce for ContextPill {
|
||||
)
|
||||
.when_some(on_remove.as_ref(), |element, on_remove| {
|
||||
element.child(
|
||||
IconButton::new(("remove", context.id.0), IconName::Close)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.tooltip(Tooltip::text("Remove Context"))
|
||||
.on_click({
|
||||
let on_remove = on_remove.clone();
|
||||
move |event, window, cx| on_remove(event, window, cx)
|
||||
}),
|
||||
IconButton::new(
|
||||
context.handle.element_id("remove".into()),
|
||||
IconName::Close,
|
||||
)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.tooltip(Tooltip::text("Remove Context"))
|
||||
.on_click({
|
||||
let on_remove = on_remove.clone();
|
||||
move |event, window, cx| on_remove(event, window, cx)
|
||||
}),
|
||||
)
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
@@ -264,216 +269,441 @@ pub enum ContextStatus {
|
||||
|
||||
#[derive(RegisterComponent)]
|
||||
pub struct AddedContext {
|
||||
pub id: ContextId,
|
||||
pub handle: AgentContextHandle,
|
||||
pub kind: ContextKind,
|
||||
pub name: SharedString,
|
||||
pub parent: Option<SharedString>,
|
||||
pub tooltip: Option<SharedString>,
|
||||
pub icon_path: Option<SharedString>,
|
||||
pub status: ContextStatus,
|
||||
pub render_preview: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement + 'static>>,
|
||||
pub render_hover: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView + 'static>>,
|
||||
}
|
||||
|
||||
impl AddedContext {
|
||||
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
|
||||
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
|
||||
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
|
||||
///
|
||||
/// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
|
||||
pub fn new_pending(
|
||||
handle: AgentContextHandle,
|
||||
prompt_store: Option<&Entity<PromptStore>>,
|
||||
project: &Project,
|
||||
cx: &App,
|
||||
) -> Option<AddedContext> {
|
||||
match handle {
|
||||
AgentContextHandle::File(handle) => Self::pending_file(handle, cx),
|
||||
AgentContextHandle::Directory(handle) => Self::pending_directory(handle, project, cx),
|
||||
AgentContextHandle::Symbol(handle) => Self::pending_symbol(handle, cx),
|
||||
AgentContextHandle::Selection(handle) => Self::pending_selection(handle, cx),
|
||||
AgentContextHandle::FetchedUrl(handle) => Some(Self::fetched_url(handle)),
|
||||
AgentContextHandle::Thread(handle) => Some(Self::pending_thread(handle, cx)),
|
||||
AgentContextHandle::Rules(handle) => Self::pending_rules(handle, prompt_store, cx),
|
||||
AgentContextHandle::Image(handle) => Some(Self::image(handle)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_attached(context: &AgentContext, cx: &App) -> AddedContext {
|
||||
match context {
|
||||
AssistantContext::File(file_context) => {
|
||||
let full_path = file_context.context_buffer.full_path(cx);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned().into())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
let parent = full_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
id: file_context.id,
|
||||
kind: ContextKind::File,
|
||||
name,
|
||||
parent,
|
||||
tooltip: Some(full_path_string),
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
}
|
||||
}
|
||||
AgentContext::File(context) => Self::attached_file(context, cx),
|
||||
AgentContext::Directory(context) => Self::attached_directory(context),
|
||||
AgentContext::Symbol(context) => Self::attached_symbol(context, cx),
|
||||
AgentContext::Selection(context) => Self::attached_selection(context, cx),
|
||||
AgentContext::FetchedUrl(context) => Self::fetched_url(context.clone()),
|
||||
AgentContext::Thread(context) => Self::attached_thread(context),
|
||||
AgentContext::Rules(context) => Self::attached_rules(context),
|
||||
AgentContext::Image(context) => Self::image(context.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
AssistantContext::Directory(directory_context) => {
|
||||
let worktree = directory_context.worktree.read(cx);
|
||||
// If the directory no longer exists, use its last known path.
|
||||
let full_path = worktree
|
||||
.entry_for_id(directory_context.entry_id)
|
||||
.map_or_else(
|
||||
|| directory_context.last_path.clone(),
|
||||
|entry| worktree.full_path(&entry.path).into(),
|
||||
);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned().into())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
let parent = full_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
id: directory_context.id,
|
||||
kind: ContextKind::Directory,
|
||||
name,
|
||||
parent,
|
||||
tooltip: Some(full_path_string),
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
}
|
||||
}
|
||||
fn pending_file(handle: FileContextHandle, cx: &App) -> Option<AddedContext> {
|
||||
let full_path = handle.buffer.read(cx).file()?.full_path(cx);
|
||||
Some(Self::file(handle, &full_path, cx))
|
||||
}
|
||||
|
||||
AssistantContext::Symbol(symbol_context) => AddedContext {
|
||||
id: symbol_context.id,
|
||||
kind: ContextKind::Symbol,
|
||||
name: symbol_context.context_symbol.id.name.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
fn attached_file(context: &FileContext, cx: &App) -> AddedContext {
|
||||
Self::file(context.handle.clone(), &context.full_path, cx)
|
||||
}
|
||||
|
||||
fn file(handle: FileContextHandle, full_path: &Path, cx: &App) -> AddedContext {
|
||||
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned().into())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
let parent = full_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
kind: ContextKind::File,
|
||||
name,
|
||||
parent,
|
||||
tooltip: Some(full_path_string),
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: None,
|
||||
handle: AgentContextHandle::File(handle),
|
||||
}
|
||||
}
|
||||
|
||||
fn pending_directory(
|
||||
handle: DirectoryContextHandle,
|
||||
project: &Project,
|
||||
cx: &App,
|
||||
) -> Option<AddedContext> {
|
||||
let worktree = project.worktree_for_entry(handle.entry_id, cx)?.read(cx);
|
||||
let entry = worktree.entry_for_id(handle.entry_id)?;
|
||||
let full_path = worktree.full_path(&entry.path);
|
||||
Some(Self::directory(handle, &full_path))
|
||||
}
|
||||
|
||||
fn attached_directory(context: &DirectoryContext) -> AddedContext {
|
||||
Self::directory(context.handle.clone(), &context.full_path)
|
||||
}
|
||||
|
||||
fn directory(handle: DirectoryContextHandle, full_path: &Path) -> AddedContext {
|
||||
let full_path_string: SharedString = full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned().into())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
let parent = full_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
kind: ContextKind::Directory,
|
||||
name,
|
||||
parent,
|
||||
tooltip: Some(full_path_string),
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: None,
|
||||
handle: AgentContextHandle::Directory(handle),
|
||||
}
|
||||
}
|
||||
|
||||
fn pending_symbol(handle: SymbolContextHandle, cx: &App) -> Option<AddedContext> {
|
||||
let excerpt =
|
||||
ContextFileExcerpt::new(&handle.full_path(cx)?, handle.enclosing_line_range(cx), cx);
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Symbol,
|
||||
name: handle.symbol.clone(),
|
||||
parent: Some(excerpt.file_name_and_range.clone()),
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: {
|
||||
let handle = handle.clone();
|
||||
Some(Rc::new(move |_, cx| {
|
||||
excerpt.hover_view(handle.text(cx), cx).into()
|
||||
}))
|
||||
},
|
||||
handle: AgentContextHandle::Symbol(handle),
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::Selection(selection_context) => {
|
||||
let full_path = selection_context.context_buffer.full_path(cx);
|
||||
let mut full_path_string = full_path.to_string_lossy().into_owned();
|
||||
let mut name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
|
||||
let line_range_text = format!(
|
||||
" ({}-{})",
|
||||
selection_context.line_range.start.row + 1,
|
||||
selection_context.line_range.end.row + 1
|
||||
);
|
||||
|
||||
full_path_string.push_str(&line_range_text);
|
||||
name.push_str(&line_range_text);
|
||||
|
||||
let parent = full_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
|
||||
AddedContext {
|
||||
id: selection_context.id,
|
||||
kind: ContextKind::Selection,
|
||||
name: name.into(),
|
||||
parent,
|
||||
tooltip: None,
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: Some(Rc::new({
|
||||
let content = selection_context.context_buffer.text.clone();
|
||||
move |_, cx| {
|
||||
div()
|
||||
.id("context-pill-selection-preview")
|
||||
.overflow_scroll()
|
||||
.max_w_128()
|
||||
.max_h_96()
|
||||
.child(Label::new(content.clone()).buffer_font(cx))
|
||||
.into_any_element()
|
||||
}
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
|
||||
id: fetched_url_context.id,
|
||||
kind: ContextKind::FetchedUrl,
|
||||
name: fetched_url_context.url.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
fn attached_symbol(context: &SymbolContext, cx: &App) -> AddedContext {
|
||||
let excerpt = ContextFileExcerpt::new(&context.full_path, context.line_range.clone(), cx);
|
||||
AddedContext {
|
||||
kind: ContextKind::Symbol,
|
||||
name: context.handle.symbol.clone(),
|
||||
parent: Some(excerpt.file_name_and_range.clone()),
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: {
|
||||
let text = context.text.clone();
|
||||
Some(Rc::new(move |_, cx| {
|
||||
excerpt.hover_view(text.clone(), cx).into()
|
||||
}))
|
||||
},
|
||||
handle: AgentContextHandle::Symbol(context.handle.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
AssistantContext::Thread(thread_context) => AddedContext {
|
||||
id: thread_context.id,
|
||||
kind: ContextKind::Thread,
|
||||
name: thread_context.summary(cx),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: if thread_context
|
||||
.thread
|
||||
.read(cx)
|
||||
.is_generating_detailed_summary()
|
||||
{
|
||||
ContextStatus::Loading {
|
||||
message: "Summarizing…".into(),
|
||||
}
|
||||
} else {
|
||||
ContextStatus::Ready
|
||||
fn pending_selection(handle: SelectionContextHandle, cx: &App) -> Option<AddedContext> {
|
||||
let excerpt = ContextFileExcerpt::new(&handle.full_path(cx)?, handle.line_range(cx), cx);
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Selection,
|
||||
name: excerpt.file_name_and_range.clone(),
|
||||
parent: excerpt.parent_name.clone(),
|
||||
tooltip: None,
|
||||
icon_path: excerpt.icon_path.clone(),
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: {
|
||||
let handle = handle.clone();
|
||||
Some(Rc::new(move |_, cx| {
|
||||
excerpt.hover_view(handle.text(cx), cx).into()
|
||||
}))
|
||||
},
|
||||
handle: AgentContextHandle::Selection(handle),
|
||||
})
|
||||
}
|
||||
|
||||
fn attached_selection(context: &SelectionContext, cx: &App) -> AddedContext {
|
||||
let excerpt = ContextFileExcerpt::new(&context.full_path, context.line_range.clone(), cx);
|
||||
AddedContext {
|
||||
kind: ContextKind::Selection,
|
||||
name: excerpt.file_name_and_range.clone(),
|
||||
parent: excerpt.parent_name.clone(),
|
||||
tooltip: None,
|
||||
icon_path: excerpt.icon_path.clone(),
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: {
|
||||
let text = context.text.clone();
|
||||
Some(Rc::new(move |_, cx| {
|
||||
excerpt.hover_view(text.clone(), cx).into()
|
||||
}))
|
||||
},
|
||||
handle: AgentContextHandle::Selection(context.handle.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn fetched_url(context: FetchedUrlContext) -> AddedContext {
|
||||
AddedContext {
|
||||
kind: ContextKind::FetchedUrl,
|
||||
name: context.url.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: None,
|
||||
handle: AgentContextHandle::FetchedUrl(context),
|
||||
}
|
||||
}
|
||||
|
||||
fn pending_thread(handle: ThreadContextHandle, cx: &App) -> AddedContext {
|
||||
AddedContext {
|
||||
kind: ContextKind::Thread,
|
||||
name: handle.title(cx),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: if handle.thread.read(cx).is_generating_detailed_summary() {
|
||||
ContextStatus::Loading {
|
||||
message: "Summarizing…".into(),
|
||||
}
|
||||
} else {
|
||||
ContextStatus::Ready
|
||||
},
|
||||
render_hover: {
|
||||
let thread = handle.thread.clone();
|
||||
Some(Rc::new(move |_, cx| {
|
||||
let text = thread.read(cx).latest_detailed_summary_or_text();
|
||||
ContextPillHover::new_text(text.clone(), cx).into()
|
||||
}))
|
||||
},
|
||||
handle: AgentContextHandle::Thread(handle),
|
||||
}
|
||||
}
|
||||
|
||||
fn attached_thread(context: &ThreadContext) -> AddedContext {
|
||||
AddedContext {
|
||||
kind: ContextKind::Thread,
|
||||
name: context.title.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: {
|
||||
let text = context.text.clone();
|
||||
Some(Rc::new(move |_, cx| {
|
||||
ContextPillHover::new_text(text.clone(), cx).into()
|
||||
}))
|
||||
},
|
||||
handle: AgentContextHandle::Thread(context.handle.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn pending_rules(
|
||||
handle: RulesContextHandle,
|
||||
prompt_store: Option<&Entity<PromptStore>>,
|
||||
cx: &App,
|
||||
) -> Option<AddedContext> {
|
||||
let title = prompt_store
|
||||
.as_ref()?
|
||||
.read(cx)
|
||||
.metadata(handle.prompt_id.into())?
|
||||
.title
|
||||
.unwrap_or_else(|| "Unnamed Rule".into());
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Rules,
|
||||
name: title.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: None,
|
||||
handle: AgentContextHandle::Rules(handle),
|
||||
})
|
||||
}
|
||||
|
||||
fn attached_rules(context: &RulesContext) -> AddedContext {
|
||||
let title = context
|
||||
.title
|
||||
.clone()
|
||||
.unwrap_or_else(|| "Unnamed Rule".into());
|
||||
AddedContext {
|
||||
kind: ContextKind::Rules,
|
||||
name: title,
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_hover: {
|
||||
let text = context.text.clone();
|
||||
Some(Rc::new(move |_, cx| {
|
||||
ContextPillHover::new_text(text.clone(), cx).into()
|
||||
}))
|
||||
},
|
||||
handle: AgentContextHandle::Rules(context.handle.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn image(context: ImageContext) -> AddedContext {
|
||||
AddedContext {
|
||||
kind: ContextKind::Image,
|
||||
name: "Image".into(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: match context.status() {
|
||||
ImageStatus::Loading => ContextStatus::Loading {
|
||||
message: "Loading…".into(),
|
||||
},
|
||||
render_preview: None,
|
||||
},
|
||||
|
||||
AssistantContext::Rules(user_rules_context) => AddedContext {
|
||||
id: user_rules_context.id,
|
||||
kind: ContextKind::Rules,
|
||||
name: user_rules_context.title.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
},
|
||||
|
||||
AssistantContext::Image(image_context) => AddedContext {
|
||||
id: image_context.id,
|
||||
kind: ContextKind::Image,
|
||||
name: "Image".into(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: if image_context.is_loading() {
|
||||
ContextStatus::Loading {
|
||||
message: "Loading…".into(),
|
||||
}
|
||||
} else if image_context.is_error() {
|
||||
ContextStatus::Error {
|
||||
message: "Failed to load image".into(),
|
||||
}
|
||||
} else {
|
||||
ContextStatus::Ready
|
||||
ImageStatus::Error => ContextStatus::Error {
|
||||
message: "Failed to load image".into(),
|
||||
},
|
||||
render_preview: Some(Rc::new({
|
||||
let image = image_context.original_image.clone();
|
||||
move |_, _| {
|
||||
ImageStatus::Ready => ContextStatus::Ready,
|
||||
},
|
||||
render_hover: Some(Rc::new({
|
||||
let image = context.original_image.clone();
|
||||
move |_, cx| {
|
||||
let image = image.clone();
|
||||
ContextPillHover::new(cx, move |_, _| {
|
||||
gpui::img(image.clone())
|
||||
.max_w_96()
|
||||
.max_h_96()
|
||||
.into_any_element()
|
||||
}
|
||||
})),
|
||||
},
|
||||
})
|
||||
.into()
|
||||
}
|
||||
})),
|
||||
handle: AgentContextHandle::Image(context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextPillPreview {
|
||||
render_preview: Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>,
|
||||
#[derive(Debug, Clone)]
|
||||
struct ContextFileExcerpt {
|
||||
pub file_name_and_range: SharedString,
|
||||
pub full_path_and_range: SharedString,
|
||||
pub parent_name: Option<SharedString>,
|
||||
pub icon_path: Option<SharedString>,
|
||||
}
|
||||
|
||||
impl Render for ContextPillPreview {
|
||||
impl ContextFileExcerpt {
|
||||
pub fn new(full_path: &Path, line_range: Range<Point>, cx: &App) -> Self {
|
||||
let full_path_string = full_path.to_string_lossy().into_owned();
|
||||
let file_name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
|
||||
let line_range_text = format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
|
||||
let mut full_path_and_range = full_path_string;
|
||||
full_path_and_range.push_str(&line_range_text);
|
||||
let mut file_name_and_range = file_name;
|
||||
file_name_and_range.push_str(&line_range_text);
|
||||
|
||||
let parent_name = full_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
|
||||
let icon_path = FileIcons::get_icon(&full_path, cx);
|
||||
|
||||
ContextFileExcerpt {
|
||||
file_name_and_range: file_name_and_range.into(),
|
||||
full_path_and_range: full_path_and_range.into(),
|
||||
parent_name,
|
||||
icon_path,
|
||||
}
|
||||
}
|
||||
|
||||
fn hover_view(&self, text: SharedString, cx: &mut App) -> Entity<ContextPillHover> {
|
||||
let icon_path = self.icon_path.clone();
|
||||
let full_path_and_range = self.full_path_and_range.clone();
|
||||
ContextPillHover::new(cx, move |_, cx| {
|
||||
v_flex()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
.w_full()
|
||||
.max_w_full()
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
.children(
|
||||
icon_path
|
||||
.clone()
|
||||
.map(Icon::from_path)
|
||||
.map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
|
||||
)
|
||||
.child(
|
||||
// TODO: make this truncate on the left.
|
||||
Label::new(full_path_and_range.clone())
|
||||
.size(LabelSize::Small)
|
||||
.ml_1(),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("context-pill-hover-contents")
|
||||
.overflow_scroll()
|
||||
.max_w_128()
|
||||
.max_h_96()
|
||||
.child(Label::new(text.clone()).buffer_font(cx)),
|
||||
)
|
||||
.into_any_element()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextPillHover {
|
||||
render_hover: Box<dyn Fn(&mut Window, &mut App) -> AnyElement>,
|
||||
}
|
||||
|
||||
impl ContextPillHover {
|
||||
fn new(
|
||||
cx: &mut App,
|
||||
render_hover: impl Fn(&mut Window, &mut App) -> AnyElement + 'static,
|
||||
) -> Entity<Self> {
|
||||
cx.new(|_| Self {
|
||||
render_hover: Box::new(render_hover),
|
||||
})
|
||||
}
|
||||
|
||||
fn new_text(content: SharedString, cx: &mut App) -> Entity<Self> {
|
||||
Self::new(cx, move |_, _| {
|
||||
div()
|
||||
.id("context-pill-hover-contents")
|
||||
.overflow_scroll()
|
||||
.max_w_128()
|
||||
.max_h_96()
|
||||
.child(content.clone())
|
||||
.into_any_element()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ContextPillHover {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
tooltip_container(window, cx, move |this, window, cx| {
|
||||
this.occlude()
|
||||
.on_mouse_move(|_, _, cx| cx.stop_propagation())
|
||||
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
|
||||
.child((self.render_preview)(window, cx))
|
||||
.child((self.render_hover)(window, cx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -488,45 +718,40 @@ impl Component for AddedContext {
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
|
||||
let mut next_context_id = ContextId::zero();
|
||||
let image_ready = (
|
||||
"Ready",
|
||||
AddedContext::new(
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(0),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
|
||||
}),
|
||||
cx,
|
||||
),
|
||||
AddedContext::image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
project_path: None,
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
|
||||
}),
|
||||
);
|
||||
|
||||
let image_loading = (
|
||||
"Loading",
|
||||
AddedContext::new(
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(1),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: cx
|
||||
.background_spawn(async move {
|
||||
smol::Timer::after(Duration::from_secs(60 * 5)).await;
|
||||
Some(LanguageModelImage::empty())
|
||||
})
|
||||
.shared(),
|
||||
}),
|
||||
cx,
|
||||
),
|
||||
AddedContext::image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
project_path: None,
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: cx
|
||||
.background_spawn(async move {
|
||||
smol::Timer::after(Duration::from_secs(60 * 5)).await;
|
||||
Some(LanguageModelImage::empty())
|
||||
})
|
||||
.shared(),
|
||||
}),
|
||||
);
|
||||
|
||||
let image_error = (
|
||||
"Error",
|
||||
AddedContext::new(
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(2),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(None).shared(),
|
||||
}),
|
||||
cx,
|
||||
),
|
||||
AddedContext::image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
project_path: None,
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(None).shared(),
|
||||
}),
|
||||
);
|
||||
|
||||
Some(
|
||||
|
||||
@@ -98,6 +98,10 @@ impl RenderOnce for UsageBanner {
|
||||
}
|
||||
|
||||
impl Component for UsageBanner {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"AgentUsageBanner"
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
mod supported_countries;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -11,8 +9,6 @@ use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
|
||||
@@ -1,225 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Returns whether the given country code is supported by Anthropic.
|
||||
///
|
||||
/// <https://www.anthropic.com/supported-countries>
|
||||
pub fn is_supported_country(country_code: &str) -> bool {
|
||||
SUPPORTED_COUNTRIES.contains(&country_code)
|
||||
}
|
||||
|
||||
/// The list of country codes supported by Anthropic.
|
||||
///
|
||||
/// https://www.anthropic.com/supported-countries
|
||||
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
vec![
|
||||
"AL", // Albania
|
||||
"DZ", // Algeria
|
||||
"AS", // American Samoa (US)
|
||||
"AD", // Andorra
|
||||
"AO", // Angola
|
||||
"AI", // Anguilla (UK)
|
||||
"AG", // Antigua and Barbuda
|
||||
"AR", // Argentina
|
||||
"AM", // Armenia
|
||||
"AU", // Australia
|
||||
"AT", // Austria
|
||||
"AZ", // Azerbaijan
|
||||
"BS", // Bahamas
|
||||
"BH", // Bahrain
|
||||
"BD", // Bangladesh
|
||||
"BB", // Barbados
|
||||
"BE", // Belgium
|
||||
"BZ", // Belize
|
||||
"BJ", // Benin
|
||||
"BM", // Bermuda (UK)
|
||||
"BT", // Bhutan
|
||||
"BO", // Bolivia
|
||||
"BA", // Bosnia and Herzegovina
|
||||
"BW", // Botswana
|
||||
"BR", // Brazil
|
||||
"IO", // British Indian Ocean Territory (UK)
|
||||
"BN", // Brunei
|
||||
"BG", // Bulgaria
|
||||
"BF", // Burkina Faso
|
||||
"BI", // Burundi
|
||||
"CV", // Cabo Verde
|
||||
"KH", // Cambodia
|
||||
"CM", // Cameroon
|
||||
"CA", // Canada
|
||||
"KY", // Cayman Islands (UK)
|
||||
"TD", // Chad
|
||||
"CL", // Chile
|
||||
"CX", // Christmas Island (AU)
|
||||
"CC", // Cocos (Keeling) Islands (AU)
|
||||
"CO", // Colombia
|
||||
"KM", // Comoros
|
||||
"CG", // Congo (Brazzaville)
|
||||
"CK", // Cook Islands (NZ)
|
||||
"CR", // Costa Rica
|
||||
"CI", // Côte d'Ivoire
|
||||
"HR", // Croatia
|
||||
"CY", // Cyprus
|
||||
"CZ", // Czechia (Czech Republic)
|
||||
"DK", // Denmark
|
||||
"DJ", // Djibouti
|
||||
"DM", // Dominica
|
||||
"DO", // Dominican Republic
|
||||
"EC", // Ecuador
|
||||
"EG", // Egypt
|
||||
"SV", // El Salvador
|
||||
"GQ", // Equatorial Guinea
|
||||
"EE", // Estonia
|
||||
"SZ", // Eswatini
|
||||
"FK", // Falkland Islands (UK)
|
||||
"FJ", // Fiji
|
||||
"FI", // Finland
|
||||
"FR", // France
|
||||
"GF", // French Guiana (FR)
|
||||
"PF", // French Polynesia (FR)
|
||||
"TF", // French Southern Territories
|
||||
"GA", // Gabon
|
||||
"GM", // Gambia
|
||||
"GE", // Georgia
|
||||
"DE", // Germany
|
||||
"GH", // Ghana
|
||||
"GI", // Gibraltar (UK)
|
||||
"GR", // Greece
|
||||
"GD", // Grenada
|
||||
"GT", // Guatemala
|
||||
"GU", // Guam (US)
|
||||
"GN", // Guinea
|
||||
"GW", // Guinea-Bissau
|
||||
"GY", // Guyana
|
||||
"HT", // Haiti
|
||||
"HM", // Heard Island and McDonald Islands (AU)
|
||||
"HN", // Honduras
|
||||
"HU", // Hungary
|
||||
"IS", // Iceland
|
||||
"IN", // India
|
||||
"ID", // Indonesia
|
||||
"IQ", // Iraq
|
||||
"IE", // Ireland
|
||||
"IL", // Israel
|
||||
"IT", // Italy
|
||||
"JM", // Jamaica
|
||||
"JP", // Japan
|
||||
"JO", // Jordan
|
||||
"KZ", // Kazakhstan
|
||||
"KE", // Kenya
|
||||
"KI", // Kiribati
|
||||
"KW", // Kuwait
|
||||
"KG", // Kyrgyzstan
|
||||
"LA", // Laos
|
||||
"LV", // Latvia
|
||||
"LB", // Lebanon
|
||||
"LS", // Lesotho
|
||||
"LR", // Liberia
|
||||
"LI", // Liechtenstein
|
||||
"LT", // Lithuania
|
||||
"LU", // Luxembourg
|
||||
"MG", // Madagascar
|
||||
"MW", // Malawi
|
||||
"MY", // Malaysia
|
||||
"MV", // Maldives
|
||||
"MT", // Malta
|
||||
"MH", // Marshall Islands
|
||||
"MR", // Mauritania
|
||||
"MU", // Mauritius
|
||||
"MX", // Mexico
|
||||
"FM", // Micronesia
|
||||
"MD", // Moldova
|
||||
"MC", // Monaco
|
||||
"MN", // Mongolia
|
||||
"MS", // Montserrat (UK)
|
||||
"ME", // Montenegro
|
||||
"MA", // Morocco
|
||||
"MZ", // Mozambique
|
||||
"NA", // Namibia
|
||||
"NR", // Nauru
|
||||
"NP", // Nepal
|
||||
"NL", // Netherlands
|
||||
"NZ", // New Zealand
|
||||
"NE", // Niger
|
||||
"NG", // Nigeria
|
||||
"NF", // Norfolk Island (AU)
|
||||
"MK", // North Macedonia
|
||||
"MI", // Northern Mariana Islands (UK)
|
||||
"NO", // Norway
|
||||
"NU", // Niue (NZ)
|
||||
"OM", // Oman
|
||||
"PK", // Pakistan
|
||||
"PW", // Palau
|
||||
"PS", // Palestine
|
||||
"PA", // Panama
|
||||
"PG", // Papua New Guinea
|
||||
"PY", // Paraguay
|
||||
"PE", // Peru
|
||||
"PH", // Philippines
|
||||
"PN", // Pitcairn (UK)
|
||||
"PL", // Poland
|
||||
"PT", // Portugal
|
||||
"PR", // Puerto Rico (US)
|
||||
"QA", // Qatar
|
||||
"RO", // Romania
|
||||
"RW", // Rwanda
|
||||
"BL", // Saint Barthélemy (FR)
|
||||
"KN", // Saint Kitts and Nevis
|
||||
"LC", // Saint Lucia
|
||||
"MF", // Saint Martin (FR)
|
||||
"PM", // Saint Pierre and Miquelon (FR)
|
||||
"VC", // Saint Vincent and the Grenadines
|
||||
"WS", // Samoa
|
||||
"SM", // San Marino
|
||||
"ST", // São Tomé and Príncipe
|
||||
"SA", // Saudi Arabia
|
||||
"SN", // Senegal
|
||||
"RS", // Serbia
|
||||
"SC", // Seychelles
|
||||
"SH", // Saint Helena, Ascension and Tristan da Cunha (UK)
|
||||
"SL", // Sierra Leone
|
||||
"SG", // Singapore
|
||||
"SK", // Slovakia
|
||||
"SI", // Slovenia
|
||||
"SB", // Solomon Islands
|
||||
"ZA", // South Africa
|
||||
"KR", // South Korea
|
||||
"ES", // Spain
|
||||
"LK", // Sri Lanka
|
||||
"SR", // Suriname
|
||||
"SE", // Sweden
|
||||
"CH", // Switzerland
|
||||
"TW", // Taiwan
|
||||
"TJ", // Tajikistan
|
||||
"TZ", // Tanzania
|
||||
"TH", // Thailand
|
||||
"TL", // Timor-Leste
|
||||
"TG", // Togo
|
||||
"TK", // Tokelau (NZ)
|
||||
"TO", // Tonga
|
||||
"TT", // Trinidad and Tobago
|
||||
"TN", // Tunisia
|
||||
"TR", // Türkiye (Turkey)
|
||||
"TM", // Turkmenistan
|
||||
"TC", // Turks and Caicos Islands (UK)
|
||||
"TV", // Tuvalu
|
||||
"UG", // Uganda
|
||||
"UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
|
||||
"AE", // United Arab Emirates
|
||||
"GB", // United Kingdom
|
||||
"UM", // United States Minor Outlying Islands (US)
|
||||
"US", // United States of America
|
||||
"UY", // Uruguay
|
||||
"UZ", // Uzbekistan
|
||||
"VU", // Vanuatu
|
||||
"VA", // Vatican City
|
||||
"VN", // Vietnam
|
||||
"VI", // Virgin Islands (US)
|
||||
"VG", // Virgin Islands (UK)
|
||||
"WF", // Wallis and Futuna (FR)
|
||||
"ZM", // Zambia
|
||||
"ZW", // Zimbabwe
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
});
|
||||
@@ -15,6 +15,7 @@ path = "src/askpass.rs"
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
shlex.workspace = true
|
||||
smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
@@ -72,8 +72,7 @@ impl AskPassSession {
|
||||
let (askpass_opened_tx, askpass_opened_rx) = oneshot::channel::<()>();
|
||||
let listener =
|
||||
UnixListener::bind(&askpass_socket).context("failed to create askpass socket")?;
|
||||
let zed_path = std::env::current_exe()
|
||||
.context("Failed to figure out current executable path for use in askpass")?;
|
||||
let zed_path = get_shell_safe_zed_path()?;
|
||||
|
||||
let (askpass_kill_master_tx, askpass_kill_master_rx) = oneshot::channel::<()>();
|
||||
let mut kill_tx = Some(askpass_kill_master_tx);
|
||||
@@ -115,7 +114,7 @@ impl AskPassSession {
|
||||
// Create an askpass script that communicates back to this process.
|
||||
let askpass_script = format!(
|
||||
"{shebang}\n{print_args} | {zed_exe} --askpass={askpass_socket} 2> /dev/null \n",
|
||||
zed_exe = zed_path.display(),
|
||||
zed_exe = zed_path,
|
||||
askpass_socket = askpass_socket.display(),
|
||||
print_args = "printf '%s\\0' \"$@\"",
|
||||
shebang = "#!/bin/sh",
|
||||
@@ -161,6 +160,32 @@ impl AskPassSession {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn get_shell_safe_zed_path() -> anyhow::Result<String> {
|
||||
let zed_path = std::env::current_exe()
|
||||
.context("Failed to figure out current executable path for use in askpass")?
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
// sanity check on unix systems that the path exists and is executable
|
||||
// todo(windows): implement this check for windows (or just use `is-executable` crate)
|
||||
use std::os::unix::fs::MetadataExt;
|
||||
let metadata = std::fs::metadata(&zed_path)
|
||||
.context("Failed to check metadata of Zed executable path for use in askpass")?;
|
||||
let is_executable = metadata.is_file() && metadata.mode() & 0o111 != 0;
|
||||
anyhow::ensure!(
|
||||
is_executable,
|
||||
"Failed to verify Zed executable path for use in askpass"
|
||||
);
|
||||
// As of writing, this can only be fail if the path contains a null byte, which shouldn't be possible
|
||||
// but shlex has annotated the error as #[non_exhaustive] so we can't make it a compile error if other
|
||||
// errors are introduced in the future :(
|
||||
let zed_path_escaped = shlex::try_quote(&zed_path)
|
||||
.context("Failed to shell-escape Zed executable path for use in askpass")?;
|
||||
|
||||
return Ok(zed_path_escaped.to_string());
|
||||
}
|
||||
|
||||
/// The main function for when Zed is running in netcat mode for use in askpass.
|
||||
/// Called from both the remote server binary and the zed binary in their respective main functions.
|
||||
#[cfg(unix)]
|
||||
|
||||
@@ -49,7 +49,7 @@ menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
prompt_library.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
rope.workspace = true
|
||||
|
||||
@@ -101,7 +101,7 @@ pub fn init(
|
||||
SlashCommandSettings::register(cx);
|
||||
|
||||
assistant_context_editor::init(client.clone(), cx);
|
||||
prompt_library::init(cx);
|
||||
rules_library::init(cx);
|
||||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -193,7 +193,7 @@ impl Focusable for ConfigurationView {
|
||||
impl Item for ConfigurationView {
|
||||
type Event = ConfigurationViewEvent;
|
||||
|
||||
fn tab_content_text(&self, _window: &Window, _cx: &App) -> Option<SharedString> {
|
||||
Some("Configuration".into())
|
||||
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
|
||||
"Configuration".into()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
|
||||
use prompt_store::{PromptBuilder, UserPromptId};
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
|
||||
use search::{BufferSearchBar, buffer_search::DivRegistrar};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -43,7 +43,7 @@ use workspace::{
|
||||
dock::{DockPosition, Panel, PanelEvent},
|
||||
pane,
|
||||
};
|
||||
use zed_actions::assistant::{InlineAssist, OpenPromptLibrary, ShowConfiguration, ToggleFocus};
|
||||
use zed_actions::assistant::{InlineAssist, OpenRulesLibrary, ShowConfiguration, ToggleFocus};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
|
||||
@@ -57,11 +57,11 @@ pub fn init(cx: &mut App) {
|
||||
.register_action(AssistantPanel::show_configuration)
|
||||
.register_action(AssistantPanel::create_new_context)
|
||||
.register_action(AssistantPanel::restart_context_servers)
|
||||
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_prompt_library(action, window, cx)
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -272,8 +272,8 @@ impl AssistantPanel {
|
||||
.action("New Chat", Box::new(NewChat))
|
||||
.action("History", Box::new(DeployHistory))
|
||||
.action(
|
||||
"Prompt Library",
|
||||
Box::new(OpenPromptLibrary::default()),
|
||||
"Rules Library",
|
||||
Box::new(OpenRulesLibrary::default()),
|
||||
)
|
||||
.action("Configure", Box::new(ShowConfiguration))
|
||||
.action(zoom_label, Box::new(ToggleZoom))
|
||||
@@ -476,7 +476,7 @@ impl AssistantPanel {
|
||||
{
|
||||
return;
|
||||
}
|
||||
context.custom_summary(new_summary, cx)
|
||||
context.set_custom_summary(new_summary, cx)
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1043,13 +1043,13 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn deploy_prompt_library(
|
||||
fn deploy_rules_library(
|
||||
&mut self,
|
||||
action: &OpenPromptLibrary,
|
||||
action: &OpenRulesLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_prompt_library(
|
||||
open_rules_library(
|
||||
self.languages.clone(),
|
||||
Box::new(PromptLibraryInlineAssist),
|
||||
Arc::new(|| {
|
||||
@@ -1059,9 +1059,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action.prompt_to_select.map(|uuid| PromptId::User {
|
||||
uuid: UserPromptId(uuid),
|
||||
}),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -1235,7 +1235,7 @@ impl Render for AssistantPanel {
|
||||
this.show_configuration_tab(window, cx)
|
||||
}))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_history))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_prompt_library))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_rules_library))
|
||||
.child(registrar.size_full().child(self.pane.clone()))
|
||||
.into_any_element()
|
||||
}
|
||||
@@ -1350,13 +1350,13 @@ impl Focusable for AssistantPanel {
|
||||
|
||||
struct PromptLibraryInlineAssist;
|
||||
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assistant.assist(&prompt_editor, None, None, initial_prompt, window, cx)
|
||||
|
||||
@@ -18,11 +18,11 @@ use editor::{
|
||||
},
|
||||
};
|
||||
use feature_flags::{
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedPro,
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedProFeatureFlag,
|
||||
};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt,
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _,
|
||||
channel::mpsc,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
join,
|
||||
@@ -37,7 +37,7 @@ use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
|
||||
};
|
||||
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
|
||||
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use project::{CodeAction, LspAction, ProjectTransaction};
|
||||
@@ -1652,7 +1652,7 @@ impl Render for PromptEditor {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -1759,6 +1759,7 @@ impl PromptEditor {
|
||||
language_model_selector: cx.new(|cx| {
|
||||
let fs = fs.clone();
|
||||
LanguageModelSelector::new(
|
||||
|cx| LanguageModelRegistry::read_global(cx).default_model(),
|
||||
move |model, cx| {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
@@ -1766,7 +1767,6 @@ impl PromptEditor {
|
||||
move |settings, _| settings.set_model(model.clone()),
|
||||
);
|
||||
},
|
||||
ModelType::Default,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -1966,7 +1966,7 @@ impl PromptEditor {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedPro>()
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
@@ -2981,6 +2981,7 @@ impl CodegenAlternative {
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
mode: None,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
@@ -3023,7 +3024,7 @@ impl CodegenAlternative {
|
||||
}
|
||||
}
|
||||
|
||||
let http_client = cx.http_client().clone();
|
||||
let http_client = cx.http_client();
|
||||
let telemetry = self.telemetry.clone();
|
||||
let language_name = {
|
||||
let multibuffer = self.buffer.read(cx);
|
||||
@@ -3056,7 +3057,8 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks =
|
||||
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -19,7 +19,7 @@ use language_model::{
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
Role, report_assistant_event,
|
||||
};
|
||||
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
|
||||
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
|
||||
use prompt_store::PromptBuilder;
|
||||
use settings::{Settings, update_settings_file};
|
||||
use std::{
|
||||
@@ -294,6 +294,7 @@ impl TerminalInlineAssistant {
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
mode: None,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
@@ -748,6 +749,7 @@ impl PromptEditor {
|
||||
language_model_selector: cx.new(|cx| {
|
||||
let fs = fs.clone();
|
||||
LanguageModelSelector::new(
|
||||
|cx| LanguageModelRegistry::read_global(cx).default_model(),
|
||||
move |model, cx| {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
@@ -755,7 +757,6 @@ impl PromptEditor {
|
||||
move |settings, _| settings.set_model(model.clone()),
|
||||
);
|
||||
},
|
||||
ModelType::Default,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -459,6 +459,7 @@ pub enum ContextEvent {
|
||||
ShowMaxMonthlySpendReachedError,
|
||||
MessagesEdited,
|
||||
SummaryChanged,
|
||||
SummaryGenerated,
|
||||
StreamedCompletion,
|
||||
StartedThoughtProcess(Range<language::Anchor>),
|
||||
EndedThoughtProcess(language::Anchor),
|
||||
@@ -482,7 +483,7 @@ pub enum ContextEvent {
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct ContextSummary {
|
||||
pub text: String,
|
||||
done: bool,
|
||||
pub done: bool,
|
||||
timestamp: clock::Lamport,
|
||||
}
|
||||
|
||||
@@ -640,7 +641,7 @@ pub struct AssistantContext {
|
||||
contents: Vec<Content>,
|
||||
messages_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
summary: Option<ContextSummary>,
|
||||
pending_summary: Task<Option<()>>,
|
||||
summary_task: Task<Option<()>>,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
token_count: Option<usize>,
|
||||
@@ -741,7 +742,7 @@ impl AssistantContext {
|
||||
thought_process_output_sections: Vec::new(),
|
||||
edits_since_last_parse: edits_since_last_slash_command_parse,
|
||||
summary: None,
|
||||
pending_summary: Task::ready(None),
|
||||
summary_task: Task::ready(None),
|
||||
completion_count: Default::default(),
|
||||
pending_completions: Default::default(),
|
||||
token_count: None,
|
||||
@@ -951,7 +952,7 @@ impl AssistantContext {
|
||||
|
||||
fn flush_ops(&mut self, cx: &mut Context<AssistantContext>) {
|
||||
let mut changed_messages = HashSet::default();
|
||||
let mut summary_changed = false;
|
||||
let mut summary_generated = false;
|
||||
|
||||
self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
|
||||
for op in mem::take(&mut self.pending_ops) {
|
||||
@@ -993,7 +994,7 @@ impl AssistantContext {
|
||||
.map_or(true, |summary| new_summary.timestamp > summary.timestamp)
|
||||
{
|
||||
self.summary = Some(new_summary);
|
||||
summary_changed = true;
|
||||
summary_generated = true;
|
||||
}
|
||||
}
|
||||
ContextOperation::SlashCommandStarted {
|
||||
@@ -1072,8 +1073,9 @@ impl AssistantContext {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
if summary_changed {
|
||||
if summary_generated {
|
||||
cx.emit(ContextEvent::SummaryChanged);
|
||||
cx.emit(ContextEvent::SummaryGenerated);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
@@ -2557,6 +2559,7 @@ impl AssistantContext {
|
||||
let mut completion_request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
mode: None,
|
||||
messages: Vec::new(),
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
@@ -2611,7 +2614,9 @@ impl AssistantContext {
|
||||
.map(MessageContent::Text),
|
||||
);
|
||||
|
||||
completion_request.messages.push(request_message);
|
||||
if !request_message.contents_empty() {
|
||||
completion_request.messages.push(request_message);
|
||||
}
|
||||
}
|
||||
|
||||
if let RequestType::SuggestEdits = request_type {
|
||||
@@ -2945,7 +2950,7 @@ impl AssistantContext {
|
||||
self.message_anchors.insert(insertion_ix, new_anchor);
|
||||
}
|
||||
|
||||
pub fn summarize(&mut self, replace_old: bool, cx: &mut Context<Self>) {
|
||||
pub fn summarize(&mut self, mut replace_old: bool, cx: &mut Context<Self>) {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
|
||||
return;
|
||||
};
|
||||
@@ -2965,7 +2970,18 @@ impl AssistantContext {
|
||||
cache: false,
|
||||
});
|
||||
|
||||
self.pending_summary = cx.spawn(async move |this, cx| {
|
||||
// If there is no summary, it is set with `done: false` so that "Loading Summary…" can
|
||||
// be displayed.
|
||||
if self.summary.is_none() {
|
||||
self.summary = Some(ContextSummary {
|
||||
text: "".to_string(),
|
||||
done: false,
|
||||
timestamp: clock::Lamport::default(),
|
||||
});
|
||||
replace_old = true;
|
||||
}
|
||||
|
||||
self.summary_task = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
let stream = model.model.stream_completion_text(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
@@ -2990,6 +3006,7 @@ impl AssistantContext {
|
||||
};
|
||||
this.push_op(operation, cx);
|
||||
cx.emit(ContextEvent::SummaryChanged);
|
||||
cx.emit(ContextEvent::SummaryGenerated);
|
||||
})?;
|
||||
|
||||
// Stop if the LLM generated multiple lines.
|
||||
@@ -3010,6 +3027,7 @@ impl AssistantContext {
|
||||
};
|
||||
this.push_op(operation, cx);
|
||||
cx.emit(ContextEvent::SummaryChanged);
|
||||
cx.emit(ContextEvent::SummaryGenerated);
|
||||
}
|
||||
})?;
|
||||
|
||||
@@ -3182,7 +3200,7 @@ impl AssistantContext {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn custom_summary(&mut self, custom_summary: String, cx: &mut Context<Self>) {
|
||||
pub fn set_custom_summary(&mut self, custom_summary: String, cx: &mut Context<Self>) {
|
||||
let timestamp = self.next_timestamp();
|
||||
let summary = self.summary.get_or_insert(ContextSummary::default());
|
||||
summary.timestamp = timestamp;
|
||||
@@ -3190,6 +3208,15 @@ impl AssistantContext {
|
||||
summary.text = custom_summary;
|
||||
cx.emit(ContextEvent::SummaryChanged);
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Text Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
self.summary
|
||||
.as_ref()
|
||||
.map(|summary| summary.text.clone().into())
|
||||
.unwrap_or(Self::DEFAULT_SUMMARY)
|
||||
}
|
||||
}
|
||||
|
||||
fn trimmed_text_in_range(buffer: &BufferSnapshot, range: Range<text::Anchor>) -> String {
|
||||
|
||||
@@ -39,7 +39,7 @@ use language_model::{
|
||||
Role,
|
||||
};
|
||||
use language_model_selector::{
|
||||
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
|
||||
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use picker::Picker;
|
||||
@@ -48,7 +48,7 @@ use project::{Project, Worktree};
|
||||
use rope::Point;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore, update_settings_file};
|
||||
use std::{any::TypeId, borrow::Cow, cmp, ops::Range, path::PathBuf, sync::Arc, time::Duration};
|
||||
use std::{any::TypeId, cmp, ops::Range, path::PathBuf, sync::Arc, time::Duration};
|
||||
use text::SelectionGoal;
|
||||
use ui::{
|
||||
ButtonLike, Disclosure, ElevationIndex, KeyBinding, PopoverMenuHandle, TintColor, Tooltip,
|
||||
@@ -291,6 +291,7 @@ impl ContextEditor {
|
||||
dragged_file_worktrees: Vec::new(),
|
||||
language_model_selector: cx.new(|cx| {
|
||||
LanguageModelSelector::new(
|
||||
|cx| LanguageModelRegistry::read_global(cx).default_model(),
|
||||
move |model, cx| {
|
||||
update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
@@ -298,7 +299,6 @@ impl ContextEditor {
|
||||
move |settings, _| settings.set_model(model.clone()),
|
||||
);
|
||||
},
|
||||
ModelType::Default,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -618,6 +618,7 @@ impl ContextEditor {
|
||||
context.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
|
||||
});
|
||||
}
|
||||
ContextEvent::SummaryGenerated => {}
|
||||
ContextEvent::StartedThoughtProcess(range) => {
|
||||
let creases = self.insert_thought_process_output_sections(
|
||||
[(
|
||||
@@ -2179,13 +2180,8 @@ impl ContextEditor {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn title(&self, cx: &App) -> Cow<str> {
|
||||
self.context
|
||||
.read(cx)
|
||||
.summary()
|
||||
.map(|summary| summary.text.clone())
|
||||
.map(Cow::Owned)
|
||||
.unwrap_or_else(|| Cow::Borrowed(DEFAULT_TAB_TITLE))
|
||||
pub fn title(&self, cx: &App) -> SharedString {
|
||||
self.context.read(cx).summary_or_default()
|
||||
}
|
||||
|
||||
fn render_patch_block(
|
||||
@@ -3160,8 +3156,8 @@ impl Focusable for ContextEditor {
|
||||
impl Item for ContextEditor {
|
||||
type Event = editor::EditorEvent;
|
||||
|
||||
fn tab_content_text(&self, _window: &Window, cx: &App) -> Option<SharedString> {
|
||||
Some(util::truncate_and_trailoff(&self.title(cx), MAX_TAB_TITLE_LEN).into())
|
||||
fn tab_content_text(&self, _detail: usize, cx: &App) -> SharedString {
|
||||
util::truncate_and_trailoff(&self.title(cx), MAX_TAB_TITLE_LEN).into()
|
||||
}
|
||||
|
||||
fn to_item_events(event: &Self::Event, mut f: impl FnMut(item::ItemEvent)) {
|
||||
@@ -3768,7 +3764,7 @@ pub fn make_lsp_adapter_delegate(
|
||||
let Some(worktree) = project.worktrees(cx).next() else {
|
||||
return Ok(None::<Arc<dyn LspAdapterDelegate>>);
|
||||
};
|
||||
let http_client = project.client().http_client().clone();
|
||||
let http_client = project.client().http_client();
|
||||
project.lsp_store().update(cx, |_, cx| {
|
||||
Ok(Some(LocalLspAdapterDelegate::new(
|
||||
project.languages().clone(),
|
||||
|
||||
@@ -108,8 +108,8 @@ impl EventEmitter<()> for ContextHistory {}
|
||||
impl Item for ContextHistory {
|
||||
type Event = ();
|
||||
|
||||
fn tab_content_text(&self, _window: &Window, _cx: &App) -> Option<SharedString> {
|
||||
Some("History".into())
|
||||
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
|
||||
"History".into()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ impl DocsSlashCommand {
|
||||
.upgrade()
|
||||
.ok_or_else(|| anyhow!("workspace was dropped"))?;
|
||||
let project = workspace.read(cx).project().clone();
|
||||
anyhow::Ok(project.read(cx).client().http_client().clone())
|
||||
anyhow::Ok(project.read(cx).client().http_client())
|
||||
});
|
||||
|
||||
if let Some(http_client) = http_client.log_err() {
|
||||
|
||||
@@ -10,6 +10,11 @@ pub fn adapt_schema_to_format(
|
||||
json: &mut Value,
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<()> {
|
||||
if let Value::Object(obj) = json {
|
||||
obj.remove("$schema");
|
||||
obj.remove("title");
|
||||
}
|
||||
|
||||
match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
|
||||
@@ -30,7 +35,12 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
|
||||
const KEYS_TO_REMOVE: [&str; 4] = [
|
||||
"format",
|
||||
"additionalProperties",
|
||||
"exclusiveMinimum",
|
||||
"exclusiveMaximum",
|
||||
];
|
||||
for key in KEYS_TO_REMOVE {
|
||||
obj.remove(key);
|
||||
}
|
||||
@@ -45,7 +55,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
}
|
||||
|
||||
// If a type is not specified for an input parameter, add a default type
|
||||
if obj.contains_key("description")
|
||||
if matches!(obj.get("description"), Some(Value::String(_)))
|
||||
&& !obj.contains_key("type")
|
||||
&& !(obj.contains_key("anyOf")
|
||||
|| obj.contains_key("oneOf")
|
||||
@@ -117,14 +127,37 @@ mod tests {
|
||||
"type": "string"
|
||||
})
|
||||
);
|
||||
|
||||
// Ensure that we do not add a type if it is an object
|
||||
let mut json = json!({
|
||||
"description": {
|
||||
"value": "abc",
|
||||
"type": "string"
|
||||
}
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"description": {
|
||||
"value": "abc",
|
||||
"type": "string"
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_removes_format() {
|
||||
fn test_transform_removes_unsupported_keys() {
|
||||
let mut json = json!({
|
||||
"description": "A test field",
|
||||
"type": "integer",
|
||||
"format": "uint32"
|
||||
"format": "uint32",
|
||||
"exclusiveMinimum": 0,
|
||||
"exclusiveMaximum": 100,
|
||||
"additionalProperties": false
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
@@ -37,9 +37,8 @@ serde_json.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
web_search.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
worktree.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -59,7 +59,6 @@ use crate::thinking_tool::ThinkingTool;
|
||||
pub use create_file_tool::CreateFileToolInput;
|
||||
pub use edit_file_tool::EditFileToolInput;
|
||||
pub use find_path_tool::FindPathToolInput;
|
||||
pub use list_directory_tool::ListDirectoryToolInput;
|
||||
pub use read_file_tool::ReadFileToolInput;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
@@ -111,11 +110,38 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use client::Client;
|
||||
use clock::FakeSystemClock;
|
||||
use http_client::FakeHttpClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Serialize;
|
||||
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_json_schema() {
|
||||
#[derive(Serialize, JsonSchema)]
|
||||
struct GetWeatherTool {
|
||||
location: String,
|
||||
}
|
||||
|
||||
let schema = schema::json_schema_for::<GetWeatherTool>(
|
||||
language_model::LanguageModelToolSchemaFormat::JsonSchema,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
schema,
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||
|
||||
@@ -14,7 +14,7 @@ use regex::{Regex, RegexBuilder};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CodeSymbolsInput {
|
||||
@@ -102,7 +102,7 @@ impl Tool for CodeSymbolsTool {
|
||||
|
||||
match &input.path {
|
||||
Some(path) => {
|
||||
let path = MarkdownString::inline_code(path);
|
||||
let path = MarkdownInlineCode(path);
|
||||
if page > 1 {
|
||||
format!("List page {page} of code symbols for {path}")
|
||||
} else {
|
||||
|
||||
@@ -11,7 +11,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, path::Path};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
/// If the model requests to read a file whose size exceeds this, then
|
||||
/// the tool will return the file's symbol outline instead of its contents,
|
||||
@@ -82,7 +82,7 @@ impl Tool for ContentsTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<ContentsToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = MarkdownString::inline_code(&input.path);
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
|
||||
match (input.start, input.end) {
|
||||
(Some(start), None) => format!("Read {path} (from line {start})"),
|
||||
|
||||
@@ -10,7 +10,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CopyPathToolInput {
|
||||
@@ -63,8 +63,8 @@ impl Tool for CopyPathTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<CopyPathToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let src = MarkdownString::inline_code(&input.source_path);
|
||||
let dest = MarkdownString::inline_code(&input.destination_path);
|
||||
let src = MarkdownInlineCode(&input.source_path);
|
||||
let dest = MarkdownInlineCode(&input.destination_path);
|
||||
format!("Copy {src} to {dest}")
|
||||
}
|
||||
Err(_) => "Copy path".to_string(),
|
||||
|
||||
@@ -10,7 +10,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CreateDirectoryToolInput {
|
||||
@@ -53,10 +53,7 @@ impl Tool for CreateDirectoryTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<CreateDirectoryToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
format!(
|
||||
"Create directory {}",
|
||||
MarkdownString::inline_code(&input.path)
|
||||
)
|
||||
format!("Create directory {}", MarkdownInlineCode(&input.path))
|
||||
}
|
||||
Err(_) => "Create directory".to_string(),
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CreateFileToolInput {
|
||||
@@ -73,7 +73,7 @@ impl Tool for CreateFileTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = MarkdownString::inline_code(&input.path);
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
format!("Create file {path}")
|
||||
}
|
||||
Err(_) => DEFAULT_UI_TEXT.to_string(),
|
||||
|
||||
@@ -9,7 +9,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, path::Path, sync::Arc};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DiagnosticsToolInput {
|
||||
@@ -66,11 +66,11 @@ impl Tool for DiagnosticsTool {
|
||||
if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
|
||||
.ok()
|
||||
.and_then(|input| match input.path {
|
||||
Some(path) if !path.is_empty() => Some(MarkdownString::inline_code(&path)),
|
||||
Some(path) if !path.is_empty() => Some(path),
|
||||
_ => None,
|
||||
})
|
||||
{
|
||||
format!("Check diagnostics for {path}")
|
||||
format!("Check diagnostics for {}", MarkdownInlineCode(&path))
|
||||
} else {
|
||||
"Check project diagnostics".to_string()
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownEscaped;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
|
||||
enum ContentType {
|
||||
@@ -134,7 +134,7 @@ impl Tool for FetchTool {
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<FetchToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Fetch {}", MarkdownString::escape(&input.url)),
|
||||
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)),
|
||||
Err(_) => "Fetch URL".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use editor::Editor;
|
||||
use futures::channel::oneshot::{self, Receiver};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
};
|
||||
use language;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc};
|
||||
use ui::IconName;
|
||||
use util::paths::PathMatcher;
|
||||
use worktree::Snapshot;
|
||||
use std::fmt::Write;
|
||||
use std::{cmp, path::PathBuf, sync::Arc};
|
||||
use ui::{Disclosure, Tooltip, prelude::*};
|
||||
use util::{ResultExt, paths::PathMatcher};
|
||||
use workspace::Workspace;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FindPathToolInput {
|
||||
@@ -29,7 +35,7 @@ pub struct FindPathToolInput {
|
||||
/// Optional starting position for paginated results (0-based).
|
||||
/// When not provided, starts from the beginning.
|
||||
#[serde(default)]
|
||||
pub offset: u32,
|
||||
pub offset: usize,
|
||||
}
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
@@ -77,13 +83,20 @@ impl Tool for FindPathTool {
|
||||
Ok(input) => (input.offset, input.glob),
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let offset = offset as usize;
|
||||
let task = search_paths(&glob, project, cx);
|
||||
cx.background_spawn(async move {
|
||||
let matches = task.await?;
|
||||
let paginated_matches = &matches[cmp::min(offset, matches.len())
|
||||
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
|
||||
let card = cx.new(|cx| FindPathToolCard::new(glob.clone(), receiver, cx));
|
||||
|
||||
let search_paths_task = search_paths(&glob, project, cx);
|
||||
|
||||
let task = cx.background_spawn(async move {
|
||||
let matches = search_paths_task.await?;
|
||||
let paginated_matches: &[PathBuf] = &matches[cmp::min(offset, matches.len())
|
||||
..cmp::min(offset + RESULTS_PER_PAGE, matches.len())];
|
||||
|
||||
sender.send(paginated_matches.to_vec()).log_err();
|
||||
|
||||
if matches.is_empty() {
|
||||
Ok("No matches found".to_string())
|
||||
} else {
|
||||
@@ -102,8 +115,12 @@ impl Tool for FindPathTool {
|
||||
}
|
||||
Ok(message)
|
||||
}
|
||||
})
|
||||
.into()
|
||||
});
|
||||
|
||||
ToolResult {
|
||||
output: task,
|
||||
card: Some(card.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +132,7 @@ fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Resu
|
||||
Ok(matcher) => matcher,
|
||||
Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))),
|
||||
};
|
||||
let snapshots: Vec<Snapshot> = project
|
||||
let snapshots: Vec<_> = project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
@@ -135,6 +152,209 @@ fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Resu
|
||||
})
|
||||
}
|
||||
|
||||
struct FindPathToolCard {
|
||||
paths: Vec<PathBuf>,
|
||||
expanded: bool,
|
||||
glob: String,
|
||||
_receiver_task: Option<Task<Result<()>>>,
|
||||
}
|
||||
|
||||
impl FindPathToolCard {
|
||||
fn new(glob: String, receiver: Receiver<Vec<PathBuf>>, cx: &mut Context<Self>) -> Self {
|
||||
let _receiver_task = cx.spawn(async move |this, cx| {
|
||||
let paths = receiver.await?;
|
||||
|
||||
this.update(cx, |this, _cx| {
|
||||
this.paths = paths;
|
||||
})
|
||||
.log_err();
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
Self {
|
||||
paths: Vec::new(),
|
||||
expanded: false,
|
||||
glob,
|
||||
_receiver_task: Some(_receiver_task),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCard for FindPathToolCard {
|
||||
fn render(
|
||||
&mut self,
|
||||
_status: &ToolUseStatus,
|
||||
_window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let matches_label: SharedString = if self.paths.len() == 0 {
|
||||
"No matches".into()
|
||||
} else if self.paths.len() == 1 {
|
||||
"1 match".into()
|
||||
} else {
|
||||
format!("{} matches", self.paths.len()).into()
|
||||
};
|
||||
|
||||
let glob_label = self.glob.to_string();
|
||||
|
||||
let content = if !self.paths.is_empty() && self.expanded {
|
||||
Some(
|
||||
v_flex()
|
||||
.relative()
|
||||
.ml_1p5()
|
||||
.px_1p5()
|
||||
.gap_0p5()
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.children(self.paths.iter().enumerate().map(|(index, path)| {
|
||||
let path_clone = path.clone();
|
||||
let workspace_clone = workspace.clone();
|
||||
let button_label = path.to_string_lossy().to_string();
|
||||
|
||||
Button::new(("path", index), button_label)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_position(IconPosition::End)
|
||||
.label_size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.tooltip(Tooltip::text("Jump to File"))
|
||||
.on_click(move |_, window, cx| {
|
||||
workspace_clone
|
||||
.update(cx, |workspace, cx| {
|
||||
let path = PathBuf::from(&path_clone);
|
||||
let Some(project_path) = workspace
|
||||
.project()
|
||||
.read(cx)
|
||||
.find_project_path(&path, cx)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let open_task = workspace.open_path(
|
||||
project_path,
|
||||
None,
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
let item = open_task.await?;
|
||||
if let Some(active_editor) =
|
||||
item.downcast::<Editor>()
|
||||
{
|
||||
active_editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.go_to_singleton_buffer_point(
|
||||
language::Point::new(0, 0),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}))
|
||||
.into_any(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.mb_2()
|
||||
.gap_1()
|
||||
.child(
|
||||
ToolCallCardHeader::new(IconName::SearchCode, matches_label)
|
||||
.with_code_path(glob_label)
|
||||
.disclosure_slot(
|
||||
Disclosure::new("path-search-disclosure", self.expanded)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronDown)
|
||||
.disabled(self.paths.is_empty())
|
||||
.on_click(cx.listener(move |this, _, _, _cx| {
|
||||
this.expanded = !this.expanded;
|
||||
})),
|
||||
),
|
||||
)
|
||||
.children(content)
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for FindPathTool {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
}
|
||||
|
||||
fn sort_name() -> &'static str {
|
||||
"FindPathTool"
|
||||
}
|
||||
|
||||
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
|
||||
let successful_card = cx.new(|_| FindPathToolCard {
|
||||
paths: vec![
|
||||
PathBuf::from("src/main.rs"),
|
||||
PathBuf::from("src/lib.rs"),
|
||||
PathBuf::from("tests/test.rs"),
|
||||
],
|
||||
expanded: true,
|
||||
glob: "*.rs".to_string(),
|
||||
_receiver_task: None,
|
||||
});
|
||||
|
||||
let empty_card = cx.new(|_| FindPathToolCard {
|
||||
paths: Vec::new(),
|
||||
expanded: false,
|
||||
glob: "*.nonexistent".to_string(),
|
||||
_receiver_task: None,
|
||||
});
|
||||
|
||||
Some(
|
||||
v_flex()
|
||||
.gap_6()
|
||||
.children(vec![example_group(vec![
|
||||
single_example(
|
||||
"With Paths",
|
||||
div()
|
||||
.size_full()
|
||||
.child(successful_card.update(cx, |tool, cx| {
|
||||
tool.render(
|
||||
&ToolUseStatus::Finished("".into()),
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"No Paths",
|
||||
div()
|
||||
.size_full()
|
||||
.child(empty_card.update(cx, |tool, cx| {
|
||||
tool.render(
|
||||
&ToolUseStatus::Finished("".into()),
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
])])
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
@@ -13,7 +13,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{cmp, fmt::Write, sync::Arc};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
use util::paths::PathMatcher;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -75,7 +75,7 @@ impl Tool for GrepTool {
|
||||
match serde_json::from_value::<GrepToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let page = input.page();
|
||||
let regex_str = MarkdownString::inline_code(&input.regex);
|
||||
let regex_str = MarkdownInlineCode(&input.regex);
|
||||
let case_info = if input.case_sensitive {
|
||||
" (case-sensitive)"
|
||||
} else {
|
||||
|
||||
@@ -8,7 +8,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, path::Path, sync::Arc};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ListDirectoryToolInput {
|
||||
@@ -63,7 +63,7 @@ impl Tool for ListDirectoryTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = MarkdownString::inline_code(&input.path);
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
format!("List the {path} directory's contents")
|
||||
}
|
||||
Err(_) => "List directory".to_string(),
|
||||
|
||||
@@ -8,7 +8,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::Path, sync::Arc};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct MovePathToolInput {
|
||||
@@ -61,8 +61,8 @@ impl Tool for MovePathTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<MovePathToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let src = MarkdownString::inline_code(&input.source_path);
|
||||
let dest = MarkdownString::inline_code(&input.destination_path);
|
||||
let src = MarkdownInlineCode(&input.source_path);
|
||||
let dest = MarkdownInlineCode(&input.destination_path);
|
||||
let src_path = Path::new(&input.source_path);
|
||||
let dest_path = Path::new(&input.destination_path);
|
||||
|
||||
@@ -71,7 +71,7 @@ impl Tool for MovePathTool {
|
||||
.and_then(|os_str| os_str.to_os_string().into_string().ok())
|
||||
{
|
||||
Some(filename) if src_path.parent() == dest_path.parent() => {
|
||||
let filename = MarkdownString::inline_code(&filename);
|
||||
let filename = MarkdownInlineCode(&filename);
|
||||
format!("Rename {src} to {filename}")
|
||||
}
|
||||
_ => {
|
||||
|
||||
@@ -8,7 +8,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownEscaped;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OpenToolInput {
|
||||
@@ -41,7 +41,7 @@ impl Tool for OpenTool {
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<OpenToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Open `{}`", MarkdownString::escape(&input.path_or_url)),
|
||||
Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)),
|
||||
Err(_) => "Open file or URL".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
/// If the model requests to read a file whose size exceeds this, then
|
||||
/// the tool will return an error along with the model's symbol outline,
|
||||
@@ -40,7 +40,7 @@ pub struct ReadFileToolInput {
|
||||
#[serde(default)]
|
||||
pub start_line: Option<usize>,
|
||||
|
||||
/// Optional line number to end reading on (1-based index)
|
||||
/// Optional line number to end reading on (1-based index, inclusive)
|
||||
#[serde(default)]
|
||||
pub end_line: Option<usize>,
|
||||
}
|
||||
@@ -71,7 +71,7 @@ impl Tool for ReadFileTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<ReadFileToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let path = MarkdownString::inline_code(&input.path);
|
||||
let path = MarkdownInlineCode(&input.path);
|
||||
match (input.start_line, input.end_line) {
|
||||
(Some(start), None) => format!("Read file {path} (from line {start})"),
|
||||
(Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"),
|
||||
@@ -128,7 +128,7 @@ impl Tool for ReadFileTool {
|
||||
let start = input.start_line.unwrap_or(1);
|
||||
let lines = text.split('\n').skip(start - 1);
|
||||
if let Some(end) = input.end_line {
|
||||
let count = end.saturating_sub(start).max(1); // Ensure at least 1 line
|
||||
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
|
||||
Itertools::intersperse(lines.take(count), "\n").collect()
|
||||
} else {
|
||||
Itertools::intersperse(lines, "\n").collect()
|
||||
@@ -329,7 +329,7 @@ mod test {
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3");
|
||||
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
|
||||
@@ -8,7 +8,7 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, ops::Range, sync::Arc};
|
||||
use ui::IconName;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
|
||||
@@ -91,7 +91,7 @@ impl Tool for SymbolInfoTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<SymbolInfoToolInput>(input.clone()) {
|
||||
Ok(input) => {
|
||||
let symbol = MarkdownString::inline_code(&input.symbol);
|
||||
let symbol = MarkdownInlineCode(&input.symbol);
|
||||
|
||||
match input.command {
|
||||
Info::Definition => {
|
||||
|
||||
@@ -15,7 +15,7 @@ use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use ui::IconName;
|
||||
use util::command::new_smol_command;
|
||||
use util::markdown::MarkdownString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct TerminalToolInput {
|
||||
@@ -55,17 +55,14 @@ impl Tool for TerminalTool {
|
||||
let first_line = lines.next().unwrap_or_default();
|
||||
let remaining_line_count = lines.count();
|
||||
match remaining_line_count {
|
||||
0 => MarkdownString::inline_code(&first_line).0,
|
||||
1 => {
|
||||
MarkdownString::inline_code(&format!(
|
||||
"{} - {} more line",
|
||||
first_line, remaining_line_count
|
||||
))
|
||||
.0
|
||||
}
|
||||
n => {
|
||||
MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
|
||||
}
|
||||
0 => MarkdownInlineCode(&first_line).to_string(),
|
||||
1 => MarkdownInlineCode(&format!(
|
||||
"{} - {} more line",
|
||||
first_line, remaining_line_count
|
||||
))
|
||||
.to_string(),
|
||||
n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n))
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
Err(_) => "Run terminal command".to_string(),
|
||||
@@ -205,39 +202,52 @@ async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<
|
||||
consume_reader(out_reader, truncated).await?;
|
||||
consume_reader(err_reader, truncated).await?;
|
||||
|
||||
let status = cmd.status().await.context("Failed to get command status")?;
|
||||
// Handle potential errors during status retrieval, including interruption.
|
||||
match cmd.status().await {
|
||||
Ok(status) => {
|
||||
let output_string = if truncated {
|
||||
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
|
||||
// multi-byte characters.
|
||||
let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
|
||||
let buffer_content =
|
||||
&combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
|
||||
|
||||
let output_string = if truncated {
|
||||
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
|
||||
// multi-byte characters.
|
||||
let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
|
||||
let combined_buffer = &combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
|
||||
format!(
|
||||
"Command output too long. The first {} bytes:\n\n{}",
|
||||
buffer_content.len(),
|
||||
output_block(buffer_content),
|
||||
)
|
||||
} else {
|
||||
output_block(&combined_buffer)
|
||||
};
|
||||
|
||||
format!(
|
||||
"Command output too long. The first {} bytes:\n\n{}",
|
||||
combined_buffer.len(),
|
||||
output_block(&combined_buffer),
|
||||
)
|
||||
} else {
|
||||
output_block(&combined_buffer)
|
||||
};
|
||||
let output_with_status = if status.success() {
|
||||
if output_string.is_empty() {
|
||||
"Command executed successfully.".to_string()
|
||||
} else {
|
||||
output_string
|
||||
}
|
||||
} else {
|
||||
format!(
|
||||
"Command failed with exit code {} (shell: {}).\n\n{}",
|
||||
status.code().unwrap_or(-1),
|
||||
shell,
|
||||
output_string,
|
||||
)
|
||||
};
|
||||
|
||||
let output_with_status = if status.success() {
|
||||
if output_string.is_empty() {
|
||||
"Command executed successfully.".to_string()
|
||||
} else {
|
||||
output_string.to_string()
|
||||
Ok(output_with_status)
|
||||
}
|
||||
} else {
|
||||
format!(
|
||||
"Command failed with exit code {} (shell: {}).\n\n{}",
|
||||
status.code().unwrap_or(-1),
|
||||
shell,
|
||||
output_string,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(output_with_status)
|
||||
Err(err) => {
|
||||
// Error occurred getting status (potential interruption). Include partial output.
|
||||
let partial_output = output_block(&combined_buffer);
|
||||
let error_message = format!(
|
||||
"Command failed or was interrupted.\nPartial output captured:\n\n{}",
|
||||
partial_output
|
||||
);
|
||||
Err(anyhow!(err).context(error_message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn consume_reader<T: AsyncReadExt + Unpin>(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use gpui::{Animation, AnimationExt, App, IntoElement, pulsating_between};
|
||||
use gpui::{Animation, AnimationExt, AnyElement, App, IntoElement, pulsating_between};
|
||||
use std::time::Duration;
|
||||
use ui::{Tooltip, prelude::*};
|
||||
|
||||
@@ -8,6 +8,8 @@ pub struct ToolCallCardHeader {
|
||||
icon: IconName,
|
||||
primary_text: SharedString,
|
||||
secondary_text: Option<SharedString>,
|
||||
code_path: Option<SharedString>,
|
||||
disclosure_slot: Option<AnyElement>,
|
||||
is_loading: bool,
|
||||
error: Option<String>,
|
||||
}
|
||||
@@ -18,6 +20,8 @@ impl ToolCallCardHeader {
|
||||
icon,
|
||||
primary_text: primary_text.into(),
|
||||
secondary_text: None,
|
||||
code_path: None,
|
||||
disclosure_slot: None,
|
||||
is_loading: false,
|
||||
error: None,
|
||||
}
|
||||
@@ -28,6 +32,16 @@ impl ToolCallCardHeader {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_code_path(mut self, text: impl Into<SharedString>) -> Self {
|
||||
self.code_path = Some(text.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn disclosure_slot(mut self, element: impl IntoElement) -> Self {
|
||||
self.disclosure_slot = Some(element.into_any_element());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn loading(mut self) -> Self {
|
||||
self.is_loading = true;
|
||||
self
|
||||
@@ -42,26 +56,36 @@ impl ToolCallCardHeader {
|
||||
impl RenderOnce for ToolCallCardHeader {
|
||||
fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let font_size = rems(0.8125);
|
||||
let line_height = window.line_height();
|
||||
|
||||
let secondary_text = self.secondary_text;
|
||||
let code_path = self.code_path;
|
||||
|
||||
let bullet_divider = || {
|
||||
div()
|
||||
.size(px(3.))
|
||||
.rounded_full()
|
||||
.bg(cx.theme().colors().text)
|
||||
};
|
||||
|
||||
h_flex()
|
||||
.id("tool-label-container")
|
||||
.gap_1p5()
|
||||
.gap_2()
|
||||
.max_w_full()
|
||||
.overflow_x_scroll()
|
||||
.opacity(0.8)
|
||||
.child(
|
||||
h_flex().h(window.line_height()).justify_center().child(
|
||||
Icon::new(self.icon)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.h(window.line_height())
|
||||
.h(line_height)
|
||||
.gap_1p5()
|
||||
.text_size(font_size)
|
||||
.child(
|
||||
h_flex().h(line_height).justify_center().child(
|
||||
Icon::new(self.icon)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.map(|this| {
|
||||
if let Some(error) = &self.error {
|
||||
this.child(format!("{} failed", self.primary_text)).child(
|
||||
@@ -76,13 +100,15 @@ impl RenderOnce for ToolCallCardHeader {
|
||||
}
|
||||
})
|
||||
.when_some(secondary_text, |this, secondary_text| {
|
||||
this.child(
|
||||
div()
|
||||
.size(px(3.))
|
||||
.rounded_full()
|
||||
.bg(cx.theme().colors().text),
|
||||
this.child(bullet_divider())
|
||||
.child(div().text_size(font_size).child(secondary_text.clone()))
|
||||
})
|
||||
.when_some(code_path, |this, code_path| {
|
||||
this.child(bullet_divider()).child(
|
||||
Label::new(code_path.clone())
|
||||
.size(LabelSize::Small)
|
||||
.inline_code(cx),
|
||||
)
|
||||
.child(div().text_size(font_size).child(secondary_text.clone()))
|
||||
})
|
||||
.with_animation(
|
||||
"loading-label",
|
||||
@@ -98,5 +124,11 @@ impl RenderOnce for ToolCallCardHeader {
|
||||
},
|
||||
),
|
||||
)
|
||||
.when_some(self.disclosure_slot, |container, disclosure_slot| {
|
||||
container
|
||||
.group("disclosure")
|
||||
.justify_between()
|
||||
.child(div().visible_on_hover("disclosure").child(disclosure_slot))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ fn view_release_notes_locally(
|
||||
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
|
||||
let tab_description = SharedString::from(body.title.to_string());
|
||||
let tab_content = SharedString::from(body.title.to_string());
|
||||
let editor = cx.new(|cx| {
|
||||
Editor::for_multibuffer(buffer, Some(project), window, cx)
|
||||
});
|
||||
@@ -102,7 +102,7 @@ fn view_release_notes_locally(
|
||||
editor,
|
||||
workspace_handle,
|
||||
language_registry,
|
||||
Some(tab_description),
|
||||
tab_content,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
@@ -564,7 +564,7 @@ impl Client {
|
||||
|
||||
pub fn production(cx: &mut App) -> Arc<Self> {
|
||||
let clock = Arc::new(clock::RealSystemClock);
|
||||
let http = Arc::new(HttpClientWithUrl::new_uri(
|
||||
let http = Arc::new(HttpClientWithUrl::new_url(
|
||||
cx.http_client(),
|
||||
&ClientSettings::get_global(cx).server_url,
|
||||
cx.http_client().proxy().cloned(),
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
//! socks proxy
|
||||
use anyhow::{Result, anyhow};
|
||||
use http_client::Uri;
|
||||
use http_client::Url;
|
||||
use tokio_socks::tcp::{Socks4Stream, Socks5Stream};
|
||||
|
||||
pub(crate) async fn connect_socks_proxy_stream(
|
||||
proxy: Option<&Uri>,
|
||||
proxy: Option<&Url>,
|
||||
rpc_host: (&str, u16),
|
||||
) -> Result<Box<dyn AsyncReadWrite>> {
|
||||
let stream = match parse_socks_proxy(proxy) {
|
||||
@@ -32,9 +32,9 @@ pub(crate) async fn connect_socks_proxy_stream(
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn parse_socks_proxy(proxy: Option<&Uri>) -> Option<((String, u16), SocksVersion)> {
|
||||
let proxy_uri = proxy?;
|
||||
let scheme = proxy_uri.scheme_str()?;
|
||||
fn parse_socks_proxy(proxy: Option<&Url>) -> Option<((String, u16), SocksVersion)> {
|
||||
let proxy_url = proxy?;
|
||||
let scheme = proxy_url.scheme();
|
||||
let socks_version = if scheme.starts_with("socks4") {
|
||||
// socks4
|
||||
SocksVersion::V4
|
||||
@@ -44,7 +44,7 @@ fn parse_socks_proxy(proxy: Option<&Uri>) -> Option<((String, u16), SocksVersion
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
if let (Some(host), Some(port)) = (proxy_uri.host(), proxy_uri.port_u16()) {
|
||||
if let Some((host, port)) = proxy_url.host().zip(proxy_url.port_or_known_default()) {
|
||||
Some(((host.to_string(), port), socks_version))
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
alter table billing_preferences
|
||||
add column model_request_overages_enabled bool not null default false,
|
||||
add column model_request_overages_spend_limit_in_cents integer not null default 0;
|
||||
@@ -0,0 +1,8 @@
|
||||
create table subscription_usage_meters (
|
||||
id serial primary key,
|
||||
subscription_usage_id integer not null references subscription_usages (id) on delete cascade,
|
||||
model_id integer not null references models (id) on delete cascade,
|
||||
requests integer not null default 0
|
||||
);
|
||||
|
||||
create unique index uix_subscription_usage_meters_on_subscription_usage_model on subscription_usage_meters (subscription_usage_id, model_id);
|
||||
@@ -0,0 +1,6 @@
|
||||
alter table subscription_usage_meters
|
||||
add column mode text not null default 'normal';
|
||||
|
||||
drop index uix_subscription_usage_meters_on_subscription_usage_model;
|
||||
|
||||
create unique index uix_subscription_usage_meters_on_subscription_usage_model_mode on subscription_usage_meters (subscription_usage_id, model_id, mode);
|
||||
@@ -152,6 +152,7 @@ struct AuthenticatedUserParams {
|
||||
struct AuthenticatedUserResponse {
|
||||
user: User,
|
||||
metrics_id: String,
|
||||
feature_flags: Vec<String>,
|
||||
}
|
||||
|
||||
async fn get_authenticated_user(
|
||||
@@ -172,7 +173,12 @@ async fn get_authenticated_user(
|
||||
)
|
||||
.await?;
|
||||
let metrics_id = app.db.get_user_metrics_id(user.id).await?;
|
||||
Ok(Json(AuthenticatedUserResponse { user, metrics_id }))
|
||||
let feature_flags = app.db.get_user_flags(user.id).await?;
|
||||
Ok(Json(AuthenticatedUserResponse {
|
||||
user,
|
||||
metrics_id,
|
||||
feature_flags,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
|
||||
@@ -65,6 +65,8 @@ struct GetBillingPreferencesParams {
|
||||
#[derive(Debug, Serialize)]
|
||||
struct BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
model_request_overages_enabled: bool,
|
||||
model_request_overages_spend_limit_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn get_billing_preferences(
|
||||
@@ -81,16 +83,30 @@ async fn get_billing_preferences(
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: preferences
|
||||
.as_ref()
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents
|
||||
}),
|
||||
model_request_overages_enabled: preferences.as_ref().map_or(false, |preferences| {
|
||||
preferences.model_request_overages_enabled
|
||||
}),
|
||||
model_request_overages_spend_limit_in_cents: preferences
|
||||
.as_ref()
|
||||
.map_or(0, |preferences| {
|
||||
preferences.model_request_overages_spend_limit_in_cents
|
||||
}),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpdateBillingPreferencesBody {
|
||||
github_user_id: i32,
|
||||
#[serde(default)]
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
#[serde(default)]
|
||||
model_request_overages_enabled: bool,
|
||||
#[serde(default)]
|
||||
model_request_overages_spend_limit_in_cents: i32,
|
||||
}
|
||||
|
||||
async fn update_billing_preferences(
|
||||
@@ -106,6 +122,8 @@ async fn update_billing_preferences(
|
||||
|
||||
let max_monthly_llm_usage_spending_in_cents =
|
||||
body.max_monthly_llm_usage_spending_in_cents.max(0);
|
||||
let model_request_overages_spend_limit_in_cents =
|
||||
body.model_request_overages_spend_limit_in_cents.max(0);
|
||||
|
||||
let billing_preferences =
|
||||
if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
|
||||
@@ -116,6 +134,12 @@ async fn update_billing_preferences(
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
model_request_overages_enabled: ActiveValue::set(
|
||||
body.model_request_overages_enabled,
|
||||
),
|
||||
model_request_overages_spend_limit_in_cents: ActiveValue::set(
|
||||
model_request_overages_spend_limit_in_cents,
|
||||
),
|
||||
},
|
||||
)
|
||||
.await?
|
||||
@@ -125,18 +149,22 @@ async fn update_billing_preferences(
|
||||
user.id,
|
||||
&crate::db::CreateBillingPreferencesParams {
|
||||
max_monthly_llm_usage_spending_in_cents,
|
||||
model_request_overages_enabled: body.model_request_overages_enabled,
|
||||
model_request_overages_spend_limit_in_cents,
|
||||
},
|
||||
)
|
||||
.await?
|
||||
};
|
||||
|
||||
SnowflakeRow::new(
|
||||
"Spend Limit Updated",
|
||||
"Billing Preferences Updated",
|
||||
Some(user.metrics_id),
|
||||
user.admin,
|
||||
None,
|
||||
json!({
|
||||
"user_id": user.id,
|
||||
"model_request_overages_enabled": billing_preferences.model_request_overages_enabled,
|
||||
"model_request_overages_spend_limit_in_cents": billing_preferences.model_request_overages_spend_limit_in_cents,
|
||||
"max_monthly_llm_usage_spending_in_cents": billing_preferences.max_monthly_llm_usage_spending_in_cents,
|
||||
}),
|
||||
)
|
||||
@@ -149,6 +177,9 @@ async fn update_billing_preferences(
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
|
||||
model_request_overages_spend_limit_in_cents: billing_preferences
|
||||
.model_request_overages_spend_limit_in_cents,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -291,16 +322,35 @@ async fn create_billing_subscription(
|
||||
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
||||
.context("failed to parse customer ID")?
|
||||
} else {
|
||||
let customer = Customer::create(
|
||||
&stripe_client,
|
||||
CreateCustomer {
|
||||
email: user.email_address.as_deref(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let existing_customer = if let Some(email) = user.email_address.as_deref() {
|
||||
let customers = Customer::list(
|
||||
&stripe_client,
|
||||
&stripe::ListCustomers {
|
||||
email: Some(email),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
customers.data.first().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(existing_customer) = existing_customer {
|
||||
existing_customer.id
|
||||
} else {
|
||||
let customer = Customer::create(
|
||||
&stripe_client,
|
||||
CreateCustomer {
|
||||
email: user.email_address.as_deref(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
}
|
||||
};
|
||||
|
||||
let success_url = format!(
|
||||
@@ -343,7 +393,9 @@ async fn create_billing_subscription(
|
||||
zed_llm_client::LanguageModelProvider::Anthropic,
|
||||
"claude-3-7-sonnet",
|
||||
)?;
|
||||
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||
let stripe_model = stripe_billing
|
||||
.register_model_for_token_based_usage(default_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||
.await?
|
||||
@@ -1193,9 +1245,9 @@ async fn find_or_create_billing_customer(
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
||||
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
return;
|
||||
@@ -1210,17 +1262,19 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.context("failed to sync LLM usage to Stripe")
|
||||
.trace_err();
|
||||
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
|
||||
executor
|
||||
.sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn sync_with_stripe(
|
||||
async fn sync_token_usage_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
@@ -1251,15 +1305,120 @@ async fn sync_with_stripe(
|
||||
.parse()
|
||||
.context("failed to parse stripe customer id from db")?;
|
||||
|
||||
let stripe_model = stripe_billing.register_model(&model).await?;
|
||||
let stripe_model = stripe_billing
|
||||
.register_model_for_token_based_usage(&model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
|
||||
.bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
|
||||
.await?;
|
||||
llm_db.consume_billing_event(event.id).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
return;
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::warn!("failed to retrieve LLM database");
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_model_request_usage_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.context("failed to sync LLM request usage to Stripe")
|
||||
.trace_err();
|
||||
executor
|
||||
.sleep(SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn sync_model_request_usage_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let usage_meters = llm_db
|
||||
.get_current_subscription_usage_meters(Utc::now())
|
||||
.await?;
|
||||
let user_ids = usage_meters
|
||||
.iter()
|
||||
.map(|(_, usage)| usage.user_id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
let billing_subscriptions = app
|
||||
.db
|
||||
.get_active_zed_pro_billing_subscriptions(user_ids)
|
||||
.await?;
|
||||
|
||||
let claude_3_5_sonnet = stripe_billing
|
||||
.find_price_by_lookup_key("claude-3-5-sonnet-requests")
|
||||
.await?;
|
||||
let claude_3_7_sonnet = stripe_billing
|
||||
.find_price_by_lookup_key("claude-3-7-sonnet-requests")
|
||||
.await?;
|
||||
|
||||
for (usage_meter, usage) in usage_meters {
|
||||
maybe!(async {
|
||||
let Some((billing_customer, billing_subscription)) =
|
||||
billing_subscriptions.get(&usage.user_id)
|
||||
else {
|
||||
bail!(
|
||||
"Attempted to sync usage meter for user who is not a Stripe customer: {}",
|
||||
usage.user_id
|
||||
);
|
||||
};
|
||||
|
||||
let stripe_customer_id = billing_customer
|
||||
.stripe_customer_id
|
||||
.parse::<stripe::CustomerId>()
|
||||
.context("failed to parse Stripe customer ID from database")?;
|
||||
let stripe_subscription_id = billing_subscription
|
||||
.stripe_subscription_id
|
||||
.parse::<stripe::SubscriptionId>()
|
||||
.context("failed to parse Stripe subscription ID from database")?;
|
||||
|
||||
let model = llm_db.model_by_id(usage_meter.model_id)?;
|
||||
|
||||
let (price_id, meter_event_name) = match model.name.as_str() {
|
||||
"claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
|
||||
"claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
|
||||
model_name => {
|
||||
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
|
||||
}
|
||||
};
|
||||
|
||||
stripe_billing
|
||||
.subscribe_to_price(&stripe_subscription_id, price_id)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_request_usage(
|
||||
&stripe_customer_id,
|
||||
meter_event_name,
|
||||
usage_meter.requests,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -800,6 +800,7 @@ impl LocalSettingsKind {
|
||||
proto::LocalSettingsKind::Settings => Self::Settings,
|
||||
proto::LocalSettingsKind::Tasks => Self::Tasks,
|
||||
proto::LocalSettingsKind::Editorconfig => Self::Editorconfig,
|
||||
proto::LocalSettingsKind::Debug => Self::Debug,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -808,6 +809,7 @@ impl LocalSettingsKind {
|
||||
Self::Settings => proto::LocalSettingsKind::Settings,
|
||||
Self::Tasks => proto::LocalSettingsKind::Tasks,
|
||||
Self::Editorconfig => proto::LocalSettingsKind::Editorconfig,
|
||||
Self::Debug => proto::LocalSettingsKind::Debug,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ pub mod messages;
|
||||
pub mod notifications;
|
||||
pub mod processed_stripe_events;
|
||||
pub mod projects;
|
||||
pub mod rate_buckets;
|
||||
pub mod rooms;
|
||||
pub mod servers;
|
||||
pub mod users;
|
||||
|
||||
@@ -3,11 +3,15 @@ use super::*;
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
pub model_request_overages_enabled: bool,
|
||||
pub model_request_overages_spend_limit_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingPreferencesParams {
|
||||
pub max_monthly_llm_usage_spending_in_cents: ActiveValue<i32>,
|
||||
pub model_request_overages_enabled: ActiveValue<bool>,
|
||||
pub model_request_overages_spend_limit_in_cents: ActiveValue<i32>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
@@ -37,6 +41,12 @@ impl Database {
|
||||
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
|
||||
params.max_monthly_llm_usage_spending_in_cents,
|
||||
),
|
||||
model_request_overages_enabled: ActiveValue::set(
|
||||
params.model_request_overages_enabled,
|
||||
),
|
||||
model_request_overages_spend_limit_in_cents: ActiveValue::set(
|
||||
params.model_request_overages_spend_limit_in_cents,
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
@@ -59,6 +69,10 @@ impl Database {
|
||||
max_monthly_llm_usage_spending_in_cents: params
|
||||
.max_monthly_llm_usage_spending_in_cents
|
||||
.clone(),
|
||||
model_request_overages_enabled: params.model_request_overages_enabled.clone(),
|
||||
model_request_overages_spend_limit_in_cents: params
|
||||
.model_request_overages_spend_limit_in_cents
|
||||
.clone(),
|
||||
..Default::default()
|
||||
})
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
|
||||
@@ -191,6 +191,38 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_zed_pro_billing_subscriptions(
|
||||
&self,
|
||||
user_ids: HashSet<UserId>,
|
||||
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
|
||||
self.transaction(|tx| {
|
||||
let user_ids = user_ids.clone();
|
||||
async move {
|
||||
let mut rows = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.select_also(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.is_in(user_ids))
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.filter(billing_subscription::Column::Kind.eq(SubscriptionKind::ZedPro))
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
let mut subscriptions = HashMap::default();
|
||||
while let Some(row) = rows.next().await {
|
||||
if let (subscription, Some(customer)) = row? {
|
||||
subscriptions.insert(customer.user_id, (customer, subscription));
|
||||
}
|
||||
}
|
||||
Ok(subscriptions)
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns whether the user has an active billing subscription.
|
||||
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
|
||||
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
use super::*;
|
||||
use crate::db::tables::rate_buckets;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
|
||||
impl Database {
|
||||
/// Saves the rate limit for the given user and rate limit name if the last_refill is later
|
||||
/// than the currently saved timestamp.
|
||||
pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
|
||||
if buckets.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
|
||||
rate_buckets::ActiveModel {
|
||||
user_id: ActiveValue::Set(bucket.user_id),
|
||||
rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
|
||||
token_count: ActiveValue::Set(bucket.token_count),
|
||||
last_refill: ActiveValue::Set(bucket.last_refill),
|
||||
}
|
||||
}))
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
rate_buckets::Column::UserId,
|
||||
rate_buckets::Column::RateLimitName,
|
||||
])
|
||||
.update_columns([
|
||||
rate_buckets::Column::TokenCount,
|
||||
rate_buckets::Column::LastRefill,
|
||||
])
|
||||
.to_owned(),
|
||||
)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieves the rate limit for the given user and rate limit name.
|
||||
pub async fn get_rate_bucket(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
rate_limit_name: &str,
|
||||
) -> Result<Option<rate_buckets::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
let rate_limit = rate_buckets::Entity::find()
|
||||
.filter(rate_buckets::Column::UserId.eq(user_id))
|
||||
.filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(rate_limit)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,6 @@ pub mod project;
|
||||
pub mod project_collaborator;
|
||||
pub mod project_repository;
|
||||
pub mod project_repository_statuses;
|
||||
pub mod rate_buckets;
|
||||
pub mod room;
|
||||
pub mod room_participant;
|
||||
pub mod server;
|
||||
|
||||
@@ -9,6 +9,8 @@ pub struct Model {
|
||||
pub created_at: DateTime,
|
||||
pub user_id: UserId,
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
pub model_request_overages_enabled: bool,
|
||||
pub model_request_overages_spend_limit_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use crate::db::UserId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "rate_buckets")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub user_id: UserId,
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub rate_limit_name: String,
|
||||
pub token_count: i32,
|
||||
pub last_refill: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -32,4 +32,6 @@ pub enum LocalSettingsKind {
|
||||
Tasks,
|
||||
#[sea_orm(string_value = "editorconfig")]
|
||||
Editorconfig,
|
||||
#[sea_orm(string_value = "debug")]
|
||||
Debug,
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ pub mod env;
|
||||
pub mod executor;
|
||||
pub mod llm;
|
||||
pub mod migrations;
|
||||
mod rate_limiter;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod stripe_billing;
|
||||
@@ -25,7 +24,6 @@ pub use cents::*;
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
use llm::db::LlmDatabase;
|
||||
pub use rate_limiter::*;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
@@ -295,7 +293,6 @@ pub struct AppState {
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub rate_limiter: Arc<RateLimiter>,
|
||||
pub executor: Executor,
|
||||
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
|
||||
pub config: Config,
|
||||
@@ -348,7 +345,6 @@ impl AppState {
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
|
||||
@@ -2,5 +2,6 @@ use super::*;
|
||||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod subscription_usage_meters;
|
||||
pub mod subscription_usages;
|
||||
pub mod usages;
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
/// Returns all current subscription usage meters as of the given timestamp.
|
||||
pub async fn get_current_subscription_usage_meters(
|
||||
&self,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
|
||||
let now = convert_chrono_to_time(now)?;
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
let result = subscription_usage_meter::Entity::find()
|
||||
.inner_join(subscription_usage::Entity)
|
||||
.filter(
|
||||
subscription_usage::Column::PeriodStartAt
|
||||
.lte(now)
|
||||
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
|
||||
)
|
||||
.select_also(subscription_usage::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
|
||||
let result = result
|
||||
.into_iter()
|
||||
.filter_map(|(meter, usage)| {
|
||||
let usage = usage?;
|
||||
Some((meter, usage))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ use crate::db::{UserId, billing_subscription};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
|
||||
pub fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
|
||||
use chrono::{Datelike as _, Timelike as _};
|
||||
|
||||
let date = time::Date::from_calendar_date(
|
||||
|
||||
@@ -3,5 +3,6 @@ pub mod model;
|
||||
pub mod monthly_usage;
|
||||
pub mod provider;
|
||||
pub mod subscription_usage;
|
||||
pub mod subscription_usage_meter;
|
||||
pub mod usage;
|
||||
pub mod usage_measure;
|
||||
|
||||
43
crates/collab/src/llm/db/tables/subscription_usage_meter.rs
Normal file
43
crates/collab/src/llm/db/tables/subscription_usage_meter.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
use crate::llm::db::ModelId;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "subscription_usage_meters")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
pub subscription_usage_id: i32,
|
||||
pub model_id: ModelId,
|
||||
pub requests: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::subscription_usage::Entity",
|
||||
from = "Column::SubscriptionUsageId",
|
||||
to = "super::subscription_usage::Column::Id"
|
||||
)]
|
||||
SubscriptionUsage,
|
||||
#[sea_orm(
|
||||
belongs_to = "super::model::Entity",
|
||||
from = "Column::ModelId",
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
}
|
||||
|
||||
impl Related<super::subscription_usage::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::SubscriptionUsage.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Model.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -34,6 +34,10 @@ pub struct LlmTokenClaims {
|
||||
#[serde(default)]
|
||||
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
|
||||
#[serde(default)]
|
||||
pub enable_model_request_overages: bool,
|
||||
#[serde(default)]
|
||||
pub model_request_overages_spend_limit_in_cents: u32,
|
||||
#[serde(default)]
|
||||
pub can_use_web_search_tool: bool,
|
||||
}
|
||||
|
||||
@@ -75,6 +79,7 @@ impl LlmTokenClaims {
|
||||
can_use_web_search_tool: feature_flags.iter().any(|flag| flag == "assistant2"),
|
||||
has_llm_subscription: has_legacy_llm_subscription,
|
||||
max_monthly_spend_in_cents: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents as u32
|
||||
}),
|
||||
@@ -96,6 +101,16 @@ impl LlmTokenClaims {
|
||||
|
||||
Some((period_start_at.naive_utc(), period_end_at.naive_utc()))
|
||||
}),
|
||||
enable_model_request_overages: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(false, |preferences| {
|
||||
preferences.model_request_overages_enabled
|
||||
}),
|
||||
model_request_overages_spend_limit_in_cents: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(0, |preferences| {
|
||||
preferences.model_request_overages_spend_limit_in_cents as u32
|
||||
}),
|
||||
};
|
||||
|
||||
Ok(jsonwebtoken::encode(
|
||||
|
||||
@@ -8,13 +8,15 @@ use axum::{
|
||||
};
|
||||
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::api::billing::sync_llm_usage_with_stripe_periodically;
|
||||
use collab::api::billing::{
|
||||
sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically,
|
||||
};
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::user_backfiller::spawn_user_backfiller;
|
||||
use collab::{
|
||||
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
|
||||
env, executor::Executor, rpc::ResultExt,
|
||||
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
|
||||
executor::Executor, rpc::ResultExt,
|
||||
};
|
||||
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
|
||||
use db::Database;
|
||||
@@ -111,10 +113,6 @@ async fn main() -> Result<()> {
|
||||
|
||||
if mode.is_collab() {
|
||||
state.db.purge_old_embeddings().await.trace_err();
|
||||
RateLimiter::save_periodically(
|
||||
state.rate_limiter.clone(),
|
||||
state.executor.clone(),
|
||||
);
|
||||
|
||||
let epoch = state
|
||||
.db
|
||||
@@ -156,7 +154,8 @@ async fn main() -> Result<()> {
|
||||
|
||||
if let Some(mut llm_db) = llm_db {
|
||||
llm_db.initialize().await?;
|
||||
sync_llm_usage_with_stripe_periodically(state.clone());
|
||||
sync_llm_request_usage_with_stripe_periodically(state.clone());
|
||||
sync_llm_token_usage_with_stripe_periodically(state.clone());
|
||||
}
|
||||
|
||||
app = app
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
use crate::{Database, Error, Result, db::UserId, executor::Executor};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use rpc::ErrorCodeExt;
|
||||
use sea_orm::prelude::DateTimeUtc;
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
pub trait RateLimit: Send + Sync {
|
||||
fn capacity(&self) -> usize;
|
||||
fn refill_duration(&self) -> Duration;
|
||||
fn db_name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
/// Used to enforce per-user rate limits
|
||||
pub struct RateLimiter {
|
||||
buckets: DashMap<(UserId, String), RateBucket>,
|
||||
dirty_buckets: DashSet<(UserId, String)>,
|
||||
db: Arc<Database>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(db: Arc<Database>) -> Self {
|
||||
RateLimiter {
|
||||
buckets: DashMap::new(),
|
||||
dirty_buckets: DashSet::new(),
|
||||
db,
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a new task that periodically saves rate limit data to the database.
|
||||
pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
|
||||
const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
executor.clone().spawn_detached(async move {
|
||||
loop {
|
||||
executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
|
||||
rate_limiter.save().await.log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns an error if the user has exceeded the specified `RateLimit`.
|
||||
/// Attempts to read the from the database if no cached RateBucket currently exists.
|
||||
pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
|
||||
self.check_internal(limit, user_id, Utc::now()).await
|
||||
}
|
||||
|
||||
async fn check_internal(
|
||||
&self,
|
||||
limit: &dyn RateLimit,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<()> {
|
||||
let bucket_key = (user_id, limit.db_name().to_string());
|
||||
|
||||
// Attempt to fetch the bucket from the database if it hasn't been cached.
|
||||
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
|
||||
// but this enforces limits across restarts so long as the database is reachable.
|
||||
if !self.buckets.contains_key(&bucket_key) {
|
||||
if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
|
||||
self.buckets.insert(bucket_key.clone(), bucket);
|
||||
self.dirty_buckets.insert(bucket_key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut bucket = self
|
||||
.buckets
|
||||
.entry(bucket_key.clone())
|
||||
.or_insert_with(|| RateBucket::new(limit, now));
|
||||
|
||||
if bucket.value_mut().allow(now) {
|
||||
self.dirty_buckets.insert(bucket_key);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(rpc::proto::ErrorCode::RateLimitExceeded
|
||||
.message("rate limit exceeded".into())
|
||||
.anyhow())?
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_bucket(
|
||||
&self,
|
||||
limit: &dyn RateLimit,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<RateBucket>, Error> {
|
||||
Ok(self
|
||||
.db
|
||||
.get_rate_bucket(user_id, limit.db_name())
|
||||
.await?
|
||||
.map(|saved_bucket| {
|
||||
RateBucket::from_db(
|
||||
limit,
|
||||
saved_bucket.token_count as usize,
|
||||
DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn save(&self) -> Result<()> {
|
||||
let mut buckets = Vec::new();
|
||||
self.dirty_buckets.retain(|key| {
|
||||
if let Some(bucket) = self.buckets.get(key) {
|
||||
buckets.push(crate::db::rate_buckets::Model {
|
||||
user_id: key.0,
|
||||
rate_limit_name: key.1.clone(),
|
||||
token_count: bucket.token_count as i32,
|
||||
last_refill: bucket.last_refill.naive_utc(),
|
||||
});
|
||||
}
|
||||
false
|
||||
});
|
||||
|
||||
match self.db.save_rate_buckets(&buckets).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(err) => {
|
||||
for bucket in buckets {
|
||||
self.dirty_buckets
|
||||
.insert((bucket.user_id, bucket.rate_limit_name));
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct RateBucket {
|
||||
capacity: usize,
|
||||
token_count: usize,
|
||||
refill_time_per_token: Duration,
|
||||
last_refill: DateTimeUtc,
|
||||
}
|
||||
|
||||
impl RateBucket {
|
||||
fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
|
||||
Self {
|
||||
capacity: limit.capacity(),
|
||||
token_count: limit.capacity(),
|
||||
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
|
||||
last_refill: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
|
||||
Self {
|
||||
capacity: limit.capacity(),
|
||||
token_count,
|
||||
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
|
||||
last_refill,
|
||||
}
|
||||
}
|
||||
|
||||
fn allow(&mut self, now: DateTimeUtc) -> bool {
|
||||
self.refill(now);
|
||||
if self.token_count > 0 {
|
||||
self.token_count -= 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self, now: DateTimeUtc) {
|
||||
let elapsed = now - self.last_refill;
|
||||
if elapsed >= self.refill_time_per_token {
|
||||
let new_tokens =
|
||||
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
|
||||
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
|
||||
|
||||
let unused_refill_time = Duration::milliseconds(
|
||||
elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
|
||||
);
|
||||
self.last_refill = now - unused_refill_time;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::db::{NewUserParams, TestDb};
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_rate_limiter(cx: &mut TestAppContext) {
|
||||
let test_db = TestDb::sqlite(cx.executor().clone());
|
||||
let db = test_db.db().clone();
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user-1@zed.dev",
|
||||
None,
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user-1".into(),
|
||||
github_user_id: 1,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user-2@zed.dev",
|
||||
None,
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user-2".into(),
|
||||
github_user_id: 2,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let mut now = Utc::now();
|
||||
|
||||
let rate_limiter = RateLimiter::new(db.clone());
|
||||
let rate_limit_a = Box::new(RateLimitA);
|
||||
let rate_limit_b = Box::new(RateLimitB);
|
||||
|
||||
// User 1 can access resource A two times before being rate-limited.
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// User 2 can access resource A and user 1 can access resource B.
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_b, user_2, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_b, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// After 1.5s, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::milliseconds(1500);
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// After 500ms, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::milliseconds(500);
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
rate_limiter.save().await.unwrap();
|
||||
|
||||
// Rate limits are reloaded from the database, so user A is still rate-limited
|
||||
// for resource A.
|
||||
let rate_limiter = RateLimiter::new(db.clone());
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// After 1s, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::seconds(1);
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
struct RateLimitA;
|
||||
|
||||
impl RateLimit for RateLimitA {
|
||||
fn capacity(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> Duration {
|
||||
Duration::seconds(2)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"rate-limit-a"
|
||||
}
|
||||
}
|
||||
|
||||
struct RateLimitB;
|
||||
|
||||
impl RateLimit for RateLimitB {
|
||||
fn capacity(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> Duration {
|
||||
Duration::seconds(3)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"rate-limit-b"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
mod connection_pool;
|
||||
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::LlmTokenClaims;
|
||||
use crate::{
|
||||
AppState, Error, Result, auth,
|
||||
@@ -178,15 +179,23 @@ impl Session {
|
||||
Ok(db.has_active_billing_subscription(user_id).await?)
|
||||
}
|
||||
|
||||
pub async fn current_plan(
|
||||
&self,
|
||||
_db: &MutexGuard<'_, DbHandle>,
|
||||
) -> anyhow::Result<proto::Plan> {
|
||||
if self.is_staff() {
|
||||
Ok(proto::Plan::ZedPro)
|
||||
pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
|
||||
let user_id = self.user_id();
|
||||
|
||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
|
||||
|
||||
let plan = if let Some(subscription_kind) = subscription_kind {
|
||||
match subscription_kind {
|
||||
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => proto::Plan::Free,
|
||||
}
|
||||
} else {
|
||||
Ok(proto::Plan::Free)
|
||||
}
|
||||
proto::Plan::Free
|
||||
};
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
fn user_id(&self) -> UserId {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user