Compare commits
88 Commits
message-ed
...
configure-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de9d470e4f | ||
|
|
f17f63ec84 | ||
|
|
15a1eb2a2e | ||
|
|
332626e582 | ||
|
|
7b3fe0a474 | ||
|
|
36184a71df | ||
|
|
ea7bc96c05 | ||
|
|
d1958aa439 | ||
|
|
5620e359af | ||
|
|
6f2e7c355e | ||
|
|
864d4bc1d1 | ||
|
|
7784fac288 | ||
|
|
f5f14111ef | ||
|
|
e664a9bc48 | ||
|
|
bf34e185d5 | ||
|
|
b9c110e63e | ||
|
|
f642f7615f | ||
|
|
3d77ad7e1a | ||
|
|
f365403618 | ||
|
|
9eb1ff2726 | ||
|
|
239e479aed | ||
|
|
3e0a755486 | ||
|
|
7199c733b2 | ||
|
|
65f64aa513 | ||
|
|
2a9d4599cd | ||
|
|
75f85b3aaa | ||
|
|
b3cad8b527 | ||
|
|
1931889759 | ||
|
|
3c5d5a1d57 | ||
|
|
bd1fda6782 | ||
|
|
e452aba9da | ||
|
|
75b832029a | ||
|
|
257e0991d8 | ||
|
|
c39f294bcb | ||
|
|
7671f34f88 | ||
|
|
7993ee9c07 | ||
|
|
485802b9e5 | ||
|
|
1e41d86b31 | ||
|
|
10a2426a58 | ||
|
|
91e6b38285 | ||
|
|
f63036548c | ||
|
|
846ed6adf9 | ||
|
|
708c434bd4 | ||
|
|
6f3cd42411 | ||
|
|
f8b0105258 | ||
|
|
2a57b160b0 | ||
|
|
d891348442 | ||
|
|
4f0b00b0d9 | ||
|
|
a3dcc76687 | ||
|
|
8d6982e78f | ||
|
|
23d0433158 | ||
|
|
4d27b228f7 | ||
|
|
8366b6ce54 | ||
|
|
b1e806442a | ||
|
|
e2ce787c05 | ||
|
|
b7c562f359 | ||
|
|
3a711d0814 | ||
|
|
b65e9af3e9 | ||
|
|
eb9bbaacb1 | ||
|
|
43ee604179 | ||
|
|
2acfa5e948 | ||
|
|
1a169e0b16 | ||
|
|
5a9546ff4b | ||
|
|
9a2b7ef372 | ||
|
|
20be133713 | ||
|
|
528d56e807 | ||
|
|
f514c7cc18 | ||
|
|
ba2c45bc53 | ||
|
|
e5402d5464 | ||
|
|
ffac8c5128 | ||
|
|
b3d048d6dc | ||
|
|
8e4f30abcb | ||
|
|
0291db0d78 | ||
|
|
5bbdd1a262 | ||
|
|
ab9fa03d55 | ||
|
|
5a6df38ccf | ||
|
|
32f9de6124 | ||
|
|
e67b2da20c | ||
|
|
293992f5b1 | ||
|
|
665006c414 | ||
|
|
09e90fb023 | ||
|
|
8452532c8f | ||
|
|
1d2eaf210a | ||
|
|
a6e2e0d24a | ||
|
|
9be44517cb | ||
|
|
389d24d7e5 | ||
|
|
389d382f42 | ||
|
|
bd61eb0889 |
@@ -25,6 +25,8 @@ third-party = [
|
||||
{ name = "reqwest", version = "0.11.27" },
|
||||
# build of remote_server should not include scap / its x11 dependency
|
||||
{ name = "scap", git = "https://github.com/zed-industries/scap", rev = "808aa5c45b41e8f44729d02e38fd00a2fe2722e7" },
|
||||
# build of remote_server should not need to include on libalsa through rodio
|
||||
{ name = "rodio" },
|
||||
]
|
||||
|
||||
[final-excludes]
|
||||
@@ -32,7 +34,6 @@ workspace-members = [
|
||||
"zed_extension_api",
|
||||
|
||||
# exclude all extensions
|
||||
"zed_emmet",
|
||||
"zed_glsl",
|
||||
"zed_html",
|
||||
"zed_proto",
|
||||
|
||||
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -718,7 +718,7 @@ jobs:
|
||||
timeout-minutes: 60
|
||||
runs-on: github-8vcpu-ubuntu-2404
|
||||
if: |
|
||||
( startsWith(github.ref, 'refs/tags/v')
|
||||
false && ( startsWith(github.ref, 'refs/tags/v')
|
||||
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
|
||||
needs: [linux_tests]
|
||||
name: Build Zed on FreeBSD
|
||||
|
||||
230
Cargo.lock
generated
230
Cargo.lock
generated
@@ -7,12 +7,14 @@ name = "acp_thread"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"action_log",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"anyhow",
|
||||
"buffer_diff",
|
||||
"collections",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"file_icons",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indoc",
|
||||
@@ -21,6 +23,7 @@ dependencies = [
|
||||
"markdown",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -169,9 +172,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.23"
|
||||
version = "0.0.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8"
|
||||
checksum = "2ab66add8be8d6a963f5bf4070045c1bbf36472837654c73e2298dd16bda5bf7"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
@@ -344,7 +347,6 @@ dependencies = [
|
||||
"gpui",
|
||||
"html_to_markdown",
|
||||
"http_client",
|
||||
"indexed_docs",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"itertools 0.14.0",
|
||||
@@ -392,6 +394,7 @@ dependencies = [
|
||||
"ui",
|
||||
"ui_input",
|
||||
"unindent",
|
||||
"url",
|
||||
"urlencoding",
|
||||
"util",
|
||||
"uuid",
|
||||
@@ -868,7 +871,6 @@ dependencies = [
|
||||
"gpui",
|
||||
"html_to_markdown",
|
||||
"http_client",
|
||||
"indexed_docs",
|
||||
"language",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
@@ -1258,26 +1260,6 @@ dependencies = [
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stripe"
|
||||
version = "0.40.0"
|
||||
source = "git+https://github.com/zed-industries/async-stripe?rev=3672dd4efb7181aa597bf580bf5a2f5d23db6735#3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"futures-util",
|
||||
"http-types",
|
||||
"hyper 0.14.32",
|
||||
"hyper-rustls 0.24.2",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_qs 0.10.1",
|
||||
"smart-default 0.6.0",
|
||||
"smol_str 0.1.24",
|
||||
"thiserror 1.0.69",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-tar"
|
||||
version = "0.5.0"
|
||||
@@ -1300,9 +1282,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.88"
|
||||
version = "0.1.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
|
||||
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2079,12 +2061,6 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "349a06037c7bf932dd7e7d1f653678b2038b9ad46a74102f1fc7bd7872678cce"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
@@ -3277,7 +3253,6 @@ dependencies = [
|
||||
"anyhow",
|
||||
"assistant_context",
|
||||
"assistant_slash_command",
|
||||
"async-stripe",
|
||||
"async-trait",
|
||||
"async-tungstenite",
|
||||
"audio",
|
||||
@@ -3293,7 +3268,6 @@ dependencies = [
|
||||
"chrono",
|
||||
"client",
|
||||
"clock",
|
||||
"cloud_llm_client",
|
||||
"collab_ui",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
@@ -3304,7 +3278,6 @@ dependencies = [
|
||||
"dap_adapters",
|
||||
"dashmap 6.1.0",
|
||||
"debugger_ui",
|
||||
"derive_more 0.99.19",
|
||||
"editor",
|
||||
"envy",
|
||||
"extension",
|
||||
@@ -3320,7 +3293,6 @@ dependencies = [
|
||||
"http_client",
|
||||
"hyper 0.14.32",
|
||||
"indoc",
|
||||
"jsonwebtoken",
|
||||
"language",
|
||||
"language_model",
|
||||
"livekit_api",
|
||||
@@ -3366,7 +3338,6 @@ dependencies = [
|
||||
"telemetry_events",
|
||||
"text",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"tokio",
|
||||
"toml 0.8.20",
|
||||
@@ -3868,7 +3839,7 @@ dependencies = [
|
||||
"rustc-hash 1.1.0",
|
||||
"rustybuzz 0.14.1",
|
||||
"self_cell",
|
||||
"smol_str 0.2.2",
|
||||
"smol_str",
|
||||
"swash",
|
||||
"sys-locale",
|
||||
"ttf-parser 0.21.1",
|
||||
@@ -4065,6 +4036,8 @@ dependencies = [
|
||||
"minidumper",
|
||||
"paths",
|
||||
"release_channel",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"workspace-hack",
|
||||
]
|
||||
@@ -6372,17 +6345,6 @@ dependencies = [
|
||||
"windows-targets 0.48.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi 0.9.0+wasi-snapshot-preview1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.15"
|
||||
@@ -7879,6 +7841,12 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hound"
|
||||
version = "3.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f"
|
||||
|
||||
[[package]]
|
||||
name = "html5ever"
|
||||
version = "0.27.0"
|
||||
@@ -7980,27 +7948,6 @@ version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f"
|
||||
|
||||
[[package]]
|
||||
name = "http-types"
|
||||
version = "2.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-channel 1.9.0",
|
||||
"base64 0.13.1",
|
||||
"futures-lite 1.13.0",
|
||||
"http 0.2.12",
|
||||
"infer",
|
||||
"pin-project-lite",
|
||||
"rand 0.7.3",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_qs 0.8.5",
|
||||
"serde_urlencoded",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http_client"
|
||||
version = "0.1.0"
|
||||
@@ -8434,34 +8381,6 @@ version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408"
|
||||
|
||||
[[package]]
|
||||
name = "indexed_docs"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"cargo_metadata",
|
||||
"collections",
|
||||
"derive_more 0.99.19",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"heed",
|
||||
"html_to_markdown",
|
||||
"http_client",
|
||||
"indexmap",
|
||||
"indoc",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"strum 0.27.1",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.9.0"
|
||||
@@ -8479,12 +8398,6 @@ version = "2.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd"
|
||||
|
||||
[[package]]
|
||||
name = "infer"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac"
|
||||
|
||||
[[package]]
|
||||
name = "inherent"
|
||||
version = "1.0.12"
|
||||
@@ -9707,6 +9620,7 @@ dependencies = [
|
||||
"objc",
|
||||
"parking_lot",
|
||||
"postage",
|
||||
"rodio",
|
||||
"scap",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -10260,7 +10174,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"range-map",
|
||||
"scroll",
|
||||
"smart-default 0.7.1",
|
||||
"smart-default",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -13134,19 +13048,6 @@ version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
|
||||
dependencies = [
|
||||
"getrandom 0.1.16",
|
||||
"libc",
|
||||
"rand_chacha 0.2.2",
|
||||
"rand_core 0.5.1",
|
||||
"rand_hc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
@@ -13168,16 +13069,6 @@ dependencies = [
|
||||
"rand_core 0.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
@@ -13198,15 +13089,6 @@ dependencies = [
|
||||
"rand_core 0.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
dependencies = [
|
||||
"getrandom 0.1.16",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
@@ -13225,15 +13107,6 @@ dependencies = [
|
||||
"getrandom 0.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
dependencies = [
|
||||
"rand_core 0.5.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "range-map"
|
||||
version = "0.2.0"
|
||||
@@ -13968,6 +13841,7 @@ checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183"
|
||||
dependencies = [
|
||||
"cpal",
|
||||
"dasp_sample",
|
||||
"hound",
|
||||
"num-rational",
|
||||
"symphonia",
|
||||
"tracing",
|
||||
@@ -14887,28 +14761,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_qs"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_qs"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8cac3f1e2ca2fe333923a1ae72caca910b98ed0630bb35ef6f8c8517d6e81afa"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_repr"
|
||||
version = "0.1.20"
|
||||
@@ -15050,8 +14902,10 @@ dependencies = [
|
||||
"ui",
|
||||
"ui_input",
|
||||
"util",
|
||||
"vim",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -15283,17 +15137,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smart-default"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "133659a15339456eeeb07572eb02a91c91e9815e9cbc89566944d2c8d3efdbf6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smart-default"
|
||||
version = "0.7.1"
|
||||
@@ -15322,15 +15165,6 @@ dependencies = [
|
||||
"futures-lite 2.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fad6c857cbab2627dcf01ec85a623ca4e7dcb5691cbaa3d7fb7653671f0d09c9"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
version = "0.2.2"
|
||||
@@ -18179,12 +18013,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.9.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
@@ -20271,7 +20099,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "xim"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd"
|
||||
source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d"
|
||||
dependencies = [
|
||||
"ahash 0.8.11",
|
||||
"hashbrown 0.14.5",
|
||||
@@ -20284,7 +20112,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "xim-ctext"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd"
|
||||
source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d"
|
||||
dependencies = [
|
||||
"encoding_rs",
|
||||
]
|
||||
@@ -20292,7 +20120,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "xim-parser"
|
||||
version = "0.2.1"
|
||||
source = "git+https://github.com/XDeme1/xim-rs?rev=d50d461764c2213655cd9cf65a0ea94c70d3c4fd#d50d461764c2213655cd9cf65a0ea94c70d3c4fd"
|
||||
source = "git+https://github.com/zed-industries/xim-rs?rev=c0a70c1bd2ce197364216e5e818a2cb3adb99a8d#c0a70c1bd2ce197364216e5e818a2cb3adb99a8d"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
]
|
||||
@@ -20570,6 +20398,7 @@ dependencies = [
|
||||
"language_tools",
|
||||
"languages",
|
||||
"libc",
|
||||
"livekit_client",
|
||||
"log",
|
||||
"markdown",
|
||||
"markdown_preview",
|
||||
@@ -20661,13 +20490,6 @@ dependencies = [
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zed_emmet"
|
||||
version = "0.0.6"
|
||||
dependencies = [
|
||||
"zed_extension_api 0.1.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zed_extension_api"
|
||||
version = "0.1.0"
|
||||
|
||||
21
Cargo.toml
21
Cargo.toml
@@ -81,7 +81,6 @@ members = [
|
||||
"crates/http_client_tls",
|
||||
"crates/icons",
|
||||
"crates/image_viewer",
|
||||
"crates/indexed_docs",
|
||||
"crates/edit_prediction",
|
||||
"crates/edit_prediction_button",
|
||||
"crates/inspector_ui",
|
||||
@@ -199,7 +198,6 @@ members = [
|
||||
# Extensions
|
||||
#
|
||||
|
||||
"extensions/emmet",
|
||||
"extensions/glsl",
|
||||
"extensions/html",
|
||||
"extensions/proto",
|
||||
@@ -306,7 +304,6 @@ http_client = { path = "crates/http_client" }
|
||||
http_client_tls = { path = "crates/http_client_tls" }
|
||||
icons = { path = "crates/icons" }
|
||||
image_viewer = { path = "crates/image_viewer" }
|
||||
indexed_docs = { path = "crates/indexed_docs" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
edit_prediction_button = { path = "crates/edit_prediction_button" }
|
||||
inspector_ui = { path = "crates/inspector_ui" }
|
||||
@@ -363,6 +360,7 @@ remote_server = { path = "crates/remote_server" }
|
||||
repl = { path = "crates/repl" }
|
||||
reqwest_client = { path = "crates/reqwest_client" }
|
||||
rich_text = { path = "crates/rich_text" }
|
||||
rodio = { version = "0.21.1", default-features = false }
|
||||
rope = { path = "crates/rope" }
|
||||
rpc = { path = "crates/rpc" }
|
||||
rules_library = { path = "crates/rules_library" }
|
||||
@@ -425,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.23"
|
||||
agent-client-protocol = "0.0.25"
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
@@ -564,7 +562,6 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
|
||||
"socks",
|
||||
"stream",
|
||||
] }
|
||||
rodio = { version = "0.21.1", default-features = false }
|
||||
rsa = "0.9.6"
|
||||
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
@@ -667,20 +664,6 @@ workspace-hack = "0.1.0"
|
||||
yawc = { git = "https://github.com/deviant-forks/yawc", rev = "1899688f3e69ace4545aceb97b2a13881cf26142" }
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
git = "https://github.com/zed-industries/async-stripe"
|
||||
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
|
||||
default-features = false
|
||||
features = [
|
||||
"runtime-tokio-hyper-rustls",
|
||||
"billing",
|
||||
"checkout",
|
||||
"events",
|
||||
# The features below are only enabled to get the `events` feature to build.
|
||||
"chrono",
|
||||
"connect",
|
||||
]
|
||||
|
||||
[workspace.dependencies.windows]
|
||||
version = "0.61"
|
||||
features = [
|
||||
|
||||
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Bold.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Bold.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-BoldItalic.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-BoldItalic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Italic.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Italic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf
Normal file
BIN
assets/fonts/ibm-plex-sans/IBMPlexSans-Regular.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-Bold.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-Bold.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-BoldItalic.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-BoldItalic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-Italic.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-Italic.ttf
Normal file
Binary file not shown.
BIN
assets/fonts/lilex/Lilex-Regular.ttf
Normal file
BIN
assets/fonts/lilex/Lilex-Regular.ttf
Normal file
Binary file not shown.
@@ -1,8 +1,9 @@
|
||||
Copyright © 2017 IBM Corp. with Reserved Font Name "Plex"
|
||||
Copyright 2019 The Lilex Project Authors (https://github.com/mishamyrt/Lilex)
|
||||
|
||||
This Font Software is licensed under the SIL Open Font License, Version 1.1.
|
||||
This license is copied below, and is also available with a FAQ at:
|
||||
http://scripts.sil.org/OFL
|
||||
https://scripts.sil.org/OFL
|
||||
|
||||
|
||||
-----------------------------------------------------------
|
||||
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
|
||||
@@ -89,4 +90,4 @@ COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
|
||||
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
|
||||
OTHER DEALINGS IN THE FONT SOFTWARE.
|
||||
OTHER DEALINGS IN THE FONT SOFTWARE.
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4
assets/icons/json.svg
Normal file
4
assets/icons/json.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M5.78125 3C3.90625 3 3.90625 4.5 3.90625 5.5C3.90625 6.5 3.40625 7.50106 2.40625 8C3.40625 8.50106 3.90625 9.5 3.90625 10.5C3.90625 11.5 3.90625 13 5.78125 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M10.2422 3C12.1172 3 12.1172 4.5 12.1172 5.5C12.1172 6.5 12.6172 7.50106 13.6172 8C12.6172 8.50106 12.1172 9.5 12.1172 10.5C12.1172 11.5 12.1172 13 10.2422 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 607 B |
@@ -58,6 +58,8 @@
|
||||
"[ space": "vim::InsertEmptyLineAbove",
|
||||
"[ e": "editor::MoveLineUp",
|
||||
"] e": "editor::MoveLineDown",
|
||||
"[ f": "workspace::FollowNextCollaborator",
|
||||
"] f": "workspace::FollowNextCollaborator",
|
||||
|
||||
// Word motions
|
||||
"w": "vim::NextWordStart",
|
||||
@@ -390,7 +392,7 @@
|
||||
"right": "vim::WrappingRight",
|
||||
"h": "vim::WrappingLeft",
|
||||
"l": "vim::WrappingRight",
|
||||
"y": "editor::Copy",
|
||||
"y": "vim::HelixYank",
|
||||
"alt-;": "vim::OtherEnd",
|
||||
"ctrl-r": "vim::Redo",
|
||||
"f": ["vim::PushFindForward", { "before": false, "multiline": true }],
|
||||
@@ -407,6 +409,7 @@
|
||||
"g w": "vim::PushRewrap",
|
||||
"insert": "vim::InsertBefore",
|
||||
"alt-.": "vim::RepeatFind",
|
||||
"alt-s": ["editor::SplitSelectionIntoLines", { "keep_selections": true }],
|
||||
// tree-sitter related commands
|
||||
"[ x": "editor::SelectLargerSyntaxNode",
|
||||
"] x": "editor::SelectSmallerSyntaxNode",
|
||||
|
||||
@@ -28,7 +28,9 @@
|
||||
"edit_prediction_provider": "zed"
|
||||
},
|
||||
// The name of a font to use for rendering text in the editor
|
||||
"buffer_font_family": "Zed Plex Mono",
|
||||
// ".ZedMono" currently aliases to Lilex
|
||||
// but this may change in the future.
|
||||
"buffer_font_family": ".ZedMono",
|
||||
// Set the buffer text's font fallbacks, this will be merged with
|
||||
// the platform's default fallbacks.
|
||||
"buffer_font_fallbacks": null,
|
||||
@@ -54,7 +56,9 @@
|
||||
"buffer_line_height": "comfortable",
|
||||
// The name of a font to use for rendering text in the UI
|
||||
// You can set this to ".SystemUIFont" to use the system font
|
||||
"ui_font_family": "Zed Plex Sans",
|
||||
// ".ZedSans" currently aliases to "IBM Plex Sans", but this may
|
||||
// change in the future
|
||||
"ui_font_family": ".ZedSans",
|
||||
// Set the UI's font fallbacks, this will be merged with the platform's
|
||||
// default font fallbacks.
|
||||
"ui_font_fallbacks": null,
|
||||
@@ -67,8 +71,8 @@
|
||||
"ui_font_weight": 400,
|
||||
// The default font size for text in the UI
|
||||
"ui_font_size": 16,
|
||||
// The default font size for text in the agent panel
|
||||
"agent_font_size": 16,
|
||||
// The default font size for text in the agent panel. Falls back to the UI font size if unset.
|
||||
"agent_font_size": null,
|
||||
// How much to fade out unused code.
|
||||
"unnecessary_code_fade": 0.3,
|
||||
// Active pane styling settings.
|
||||
@@ -883,11 +887,6 @@
|
||||
},
|
||||
// The settings for slash commands.
|
||||
"slash_commands": {
|
||||
// Settings for the `/docs` slash command.
|
||||
"docs": {
|
||||
// Whether `/docs` is enabled.
|
||||
"enabled": false
|
||||
},
|
||||
// Settings for the `/project` slash command.
|
||||
"project": {
|
||||
// Whether `/project` is enabled.
|
||||
@@ -1252,7 +1251,9 @@
|
||||
// Status bar-related settings.
|
||||
"status_bar": {
|
||||
// Whether to show the active language button in the status bar.
|
||||
"active_language_button": true
|
||||
"active_language_button": true,
|
||||
// Whether to show the cursor position button in the status bar.
|
||||
"cursor_position_button": true
|
||||
},
|
||||
// Settings specific to the terminal
|
||||
"terminal": {
|
||||
@@ -1402,7 +1403,7 @@
|
||||
// "font_size": 15,
|
||||
// Set the terminal's font family. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font family.
|
||||
// "font_family": "Zed Plex Mono",
|
||||
// "font_family": ".ZedMono",
|
||||
// Set the terminal's font fallbacks. If this option is not included,
|
||||
// the terminal will default to matching the buffer's font fallbacks.
|
||||
// This will be merged with the platform's default font fallbacks
|
||||
|
||||
@@ -13,21 +13,25 @@ path = "src/acp_thread.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support", "project/test-support"]
|
||||
test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"]
|
||||
|
||||
[dependencies]
|
||||
action_log.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent.workspace = true
|
||||
anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
file_icons.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
markdown.workspace = true
|
||||
parking_lot = { workspace = true, optional = true }
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
|
||||
@@ -32,13 +32,24 @@ use util::ResultExt;
|
||||
pub struct UserMessage {
|
||||
pub id: Option<UserMessageId>,
|
||||
pub content: ContentBlock,
|
||||
pub checkpoint: Option<GitStoreCheckpoint>,
|
||||
pub chunks: Vec<acp::ContentBlock>,
|
||||
pub checkpoint: Option<Checkpoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Checkpoint {
|
||||
git_checkpoint: GitStoreCheckpoint,
|
||||
pub show: bool,
|
||||
}
|
||||
|
||||
impl UserMessage {
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
let mut markdown = String::new();
|
||||
if let Some(_) = self.checkpoint {
|
||||
if self
|
||||
.checkpoint
|
||||
.as_ref()
|
||||
.map_or(false, |checkpoint| checkpoint.show)
|
||||
{
|
||||
writeln!(markdown, "## User (checkpoint)").unwrap();
|
||||
} else {
|
||||
writeln!(markdown, "## User").unwrap();
|
||||
@@ -98,7 +109,7 @@ pub enum AgentThreadEntry {
|
||||
}
|
||||
|
||||
impl AgentThreadEntry {
|
||||
fn to_markdown(&self, cx: &App) -> String {
|
||||
pub fn to_markdown(&self, cx: &App) -> String {
|
||||
match self {
|
||||
Self::UserMessage(message) => message.to_markdown(cx),
|
||||
Self::AssistantMessage(message) => message.to_markdown(cx),
|
||||
@@ -106,6 +117,14 @@ impl AgentThreadEntry {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user_message(&self) -> Option<&UserMessage> {
|
||||
if let AgentThreadEntry::UserMessage(message) = self {
|
||||
Some(message)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
|
||||
if let AgentThreadEntry::ToolCall(call) = self {
|
||||
itertools::Either::Left(call.diffs())
|
||||
@@ -399,7 +418,7 @@ impl ContentBlock {
|
||||
}
|
||||
}
|
||||
|
||||
let new_content = self.extract_content_from_block(block);
|
||||
let new_content = self.block_string_contents(block);
|
||||
|
||||
match self {
|
||||
ContentBlock::Empty => {
|
||||
@@ -409,7 +428,7 @@ impl ContentBlock {
|
||||
markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
|
||||
}
|
||||
ContentBlock::ResourceLink { resource_link } => {
|
||||
let existing_content = Self::resource_link_to_content(&resource_link.uri);
|
||||
let existing_content = Self::resource_link_md(&resource_link.uri);
|
||||
let combined = format!("{}\n{}", existing_content, new_content);
|
||||
|
||||
*self = Self::create_markdown_block(combined, language_registry, cx);
|
||||
@@ -417,14 +436,6 @@ impl ContentBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn resource_link_to_content(uri: &str) -> String {
|
||||
if let Some(uri) = MentionUri::parse(&uri).log_err() {
|
||||
uri.to_link()
|
||||
} else {
|
||||
uri.to_string().clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn create_markdown_block(
|
||||
content: String,
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
@@ -436,11 +447,11 @@ impl ContentBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
|
||||
fn block_string_contents(&self, block: acp::ContentBlock) -> String {
|
||||
match block {
|
||||
acp::ContentBlock::Text(text_content) => text_content.text.clone(),
|
||||
acp::ContentBlock::ResourceLink(resource_link) => {
|
||||
Self::resource_link_to_content(&resource_link.uri)
|
||||
Self::resource_link_md(&resource_link.uri)
|
||||
}
|
||||
acp::ContentBlock::Resource(acp::EmbeddedResource {
|
||||
resource:
|
||||
@@ -449,13 +460,24 @@ impl ContentBlock {
|
||||
..
|
||||
}),
|
||||
..
|
||||
}) => Self::resource_link_to_content(&uri),
|
||||
acp::ContentBlock::Image(_)
|
||||
| acp::ContentBlock::Audio(_)
|
||||
| acp::ContentBlock::Resource(_) => String::new(),
|
||||
}) => Self::resource_link_md(&uri),
|
||||
acp::ContentBlock::Image(image) => Self::image_md(&image),
|
||||
acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn resource_link_md(uri: &str) -> String {
|
||||
if let Some(uri) = MentionUri::parse(&uri).log_err() {
|
||||
uri.as_link().to_string()
|
||||
} else {
|
||||
uri.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn image_md(_image: &acp::ImageContent) -> String {
|
||||
"`Image`".into()
|
||||
}
|
||||
|
||||
fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
|
||||
match self {
|
||||
ContentBlock::Empty => "",
|
||||
@@ -770,7 +792,7 @@ impl AcpThread {
|
||||
&mut self,
|
||||
update: acp::SessionUpdate,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
) -> Result<(), acp::Error> {
|
||||
match update {
|
||||
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||
self.push_user_content_block(None, content, cx);
|
||||
@@ -782,7 +804,7 @@ impl AcpThread {
|
||||
self.push_assistant_content_block(content, true, cx);
|
||||
}
|
||||
acp::SessionUpdate::ToolCall(tool_call) => {
|
||||
self.upsert_tool_call(tool_call, cx);
|
||||
self.upsert_tool_call(tool_call, cx)?;
|
||||
}
|
||||
acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
|
||||
self.update_tool_call(tool_call_update, cx)?;
|
||||
@@ -804,18 +826,25 @@ impl AcpThread {
|
||||
let entries_len = self.entries.len();
|
||||
|
||||
if let Some(last_entry) = self.entries.last_mut()
|
||||
&& let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry
|
||||
&& let AgentThreadEntry::UserMessage(UserMessage {
|
||||
id,
|
||||
content,
|
||||
chunks,
|
||||
..
|
||||
}) = last_entry
|
||||
{
|
||||
*id = message_id.or(id.take());
|
||||
content.append(chunk, &language_registry, cx);
|
||||
content.append(chunk.clone(), &language_registry, cx);
|
||||
chunks.push(chunk);
|
||||
let idx = entries_len - 1;
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(idx));
|
||||
} else {
|
||||
let content = ContentBlock::new(chunk, &language_registry, cx);
|
||||
let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
|
||||
self.push_entry(
|
||||
AgentThreadEntry::UserMessage(UserMessage {
|
||||
id: message_id,
|
||||
content,
|
||||
chunks: vec![chunk],
|
||||
checkpoint: None,
|
||||
}),
|
||||
cx,
|
||||
@@ -911,32 +940,40 @@ impl AcpThread {
|
||||
}
|
||||
|
||||
/// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
|
||||
pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
|
||||
pub fn upsert_tool_call(
|
||||
&mut self,
|
||||
tool_call: acp::ToolCall,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<(), acp::Error> {
|
||||
let status = ToolCallStatus::Allowed {
|
||||
status: tool_call.status,
|
||||
};
|
||||
self.upsert_tool_call_inner(tool_call, status, cx)
|
||||
self.upsert_tool_call_inner(tool_call.into(), status, cx)
|
||||
}
|
||||
|
||||
/// Fails if id does not match an existing entry.
|
||||
pub fn upsert_tool_call_inner(
|
||||
&mut self,
|
||||
tool_call: acp::ToolCall,
|
||||
tool_call_update: acp::ToolCallUpdate,
|
||||
status: ToolCallStatus,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
) -> Result<(), acp::Error> {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
|
||||
let id = call.id.clone();
|
||||
let id = tool_call_update.id.clone();
|
||||
|
||||
if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
|
||||
*current_call = call;
|
||||
if let Some((ix, current_call)) = self.tool_call_mut(&id) {
|
||||
current_call.update_fields(tool_call_update.fields, language_registry, cx);
|
||||
current_call.status = status;
|
||||
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
} else {
|
||||
let call =
|
||||
ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
|
||||
self.push_entry(AgentThreadEntry::ToolCall(call), cx);
|
||||
};
|
||||
|
||||
self.resolve_locations(id, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
|
||||
@@ -1005,10 +1042,10 @@ impl AcpThread {
|
||||
|
||||
pub fn request_tool_call_authorization(
|
||||
&mut self,
|
||||
tool_call: acp::ToolCall,
|
||||
tool_call: acp::ToolCallUpdate,
|
||||
options: Vec<acp::PermissionOption>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> oneshot::Receiver<acp::PermissionOptionId> {
|
||||
) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let status = ToolCallStatus::WaitingForConfirmation {
|
||||
@@ -1016,9 +1053,9 @@ impl AcpThread {
|
||||
respond_tx: tx,
|
||||
};
|
||||
|
||||
self.upsert_tool_call_inner(tool_call, status, cx);
|
||||
self.upsert_tool_call_inner(tool_call, status, cx)?;
|
||||
cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
|
||||
rx
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
pub fn authorize_tool_call(
|
||||
@@ -1134,9 +1171,12 @@ impl AcpThread {
|
||||
self.project.read(cx).languages().clone(),
|
||||
cx,
|
||||
);
|
||||
let request = acp::PromptRequest {
|
||||
prompt: message.clone(),
|
||||
session_id: self.session_id.clone(),
|
||||
};
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
||||
let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
||||
let message_id = if self
|
||||
.connection
|
||||
.session_editor(&self.session_id, cx)
|
||||
@@ -1150,67 +1190,63 @@ impl AcpThread {
|
||||
AgentThreadEntry::UserMessage(UserMessage {
|
||||
id: message_id.clone(),
|
||||
content: block,
|
||||
chunks: message,
|
||||
checkpoint: None,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
|
||||
self.run_turn(cx, async move |this, cx| {
|
||||
let old_checkpoint = git_store
|
||||
.update(cx, |git, cx| git.checkpoint(cx))?
|
||||
.await
|
||||
.context("failed to get old checkpoint")
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some((_ix, message)) = this.last_user_message() {
|
||||
message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
|
||||
git_checkpoint,
|
||||
show: false,
|
||||
});
|
||||
}
|
||||
this.connection.prompt(message_id, request, cx)
|
||||
})?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
|
||||
self.run_turn(cx, async move |this, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.connection
|
||||
.resume(&this.session_id, cx)
|
||||
.map(|resume| resume.run(cx))
|
||||
})?
|
||||
.context("resuming a session is not supported")?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
fn run_turn(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
|
||||
) -> BoxFuture<'static, Result<()>> {
|
||||
self.clear_completed_plan_entries(cx);
|
||||
|
||||
let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let cancel_task = self.cancel(cx);
|
||||
let request = acp::PromptRequest {
|
||||
prompt: message,
|
||||
session_id: self.session_id.clone(),
|
||||
};
|
||||
|
||||
self.send_task = Some(cx.spawn({
|
||||
let message_id = message_id.clone();
|
||||
async move |this, cx| {
|
||||
cancel_task.await;
|
||||
|
||||
old_checkpoint_tx.send(old_checkpoint.await).ok();
|
||||
if let Ok(result) = this.update(cx, |this, cx| {
|
||||
this.connection.prompt(message_id, request, cx)
|
||||
}) {
|
||||
tx.send(result.await).log_err();
|
||||
}
|
||||
}
|
||||
self.send_task = Some(cx.spawn(async move |this, cx| {
|
||||
cancel_task.await;
|
||||
tx.send(f(this, cx).await).ok();
|
||||
}));
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let old_checkpoint = old_checkpoint_rx
|
||||
.await
|
||||
.map_err(|_| anyhow!("send canceled"))
|
||||
.flatten()
|
||||
.context("failed to get old checkpoint")
|
||||
.log_err();
|
||||
|
||||
let response = rx.await;
|
||||
|
||||
if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
|
||||
let new_checkpoint = git_store
|
||||
.update(cx, |git, cx| git.checkpoint(cx))?
|
||||
.await
|
||||
.context("failed to get new checkpoint")
|
||||
.log_err();
|
||||
if let Some(new_checkpoint) = new_checkpoint {
|
||||
let equal = git_store
|
||||
.update(cx, |git, cx| {
|
||||
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
||||
})?
|
||||
.await
|
||||
.unwrap_or(true);
|
||||
if !equal {
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some((ix, message)) = this.user_message_mut(&message_id) {
|
||||
message.checkpoint = Some(old_checkpoint);
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
|
||||
.await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
match response {
|
||||
@@ -1282,7 +1318,10 @@ impl AcpThread {
|
||||
return Task::ready(Err(anyhow!("message not found")));
|
||||
};
|
||||
|
||||
let checkpoint = message.checkpoint.clone();
|
||||
let checkpoint = message
|
||||
.checkpoint
|
||||
.as_ref()
|
||||
.map(|c| c.git_checkpoint.clone());
|
||||
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
@@ -1304,6 +1343,59 @@ impl AcpThread {
|
||||
})
|
||||
}
|
||||
|
||||
fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
|
||||
let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
|
||||
if let Some(checkpoint) = message.checkpoint.as_ref() {
|
||||
checkpoint.git_checkpoint.clone()
|
||||
} else {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
} else {
|
||||
return Task::ready(Ok(()));
|
||||
};
|
||||
|
||||
let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
|
||||
cx.spawn(async move |this, cx| {
|
||||
let new_checkpoint = new_checkpoint
|
||||
.await
|
||||
.context("failed to get new checkpoint")
|
||||
.log_err();
|
||||
if let Some(new_checkpoint) = new_checkpoint {
|
||||
let equal = git_store
|
||||
.update(cx, |git, cx| {
|
||||
git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
|
||||
})?
|
||||
.await
|
||||
.unwrap_or(true);
|
||||
this.update(cx, |this, cx| {
|
||||
let (ix, message) = this.last_user_message().context("no user message")?;
|
||||
let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
|
||||
checkpoint.show = !equal;
|
||||
cx.emit(AcpThreadEvent::EntryUpdated(ix));
|
||||
anyhow::Ok(())
|
||||
})??;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
|
||||
self.entries
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.find_map(|(ix, entry)| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
Some((ix, message))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
|
||||
self.entries.iter().find_map(|entry| {
|
||||
if let AgentThreadEntry::UserMessage(message) = entry {
|
||||
@@ -1540,6 +1632,7 @@ mod tests {
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::{
|
||||
any::Any,
|
||||
cell::RefCell,
|
||||
path::Path,
|
||||
rc::Rc,
|
||||
@@ -1566,11 +1659,7 @@ mod tests {
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let connection = Rc::new(FakeAgentConnection::new());
|
||||
let thread = cx
|
||||
.spawn(async move |mut cx| {
|
||||
connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||
.await
|
||||
})
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -1690,11 +1779,7 @@ mod tests {
|
||||
));
|
||||
|
||||
let thread = cx
|
||||
.spawn(async move |mut cx| {
|
||||
connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||
.await
|
||||
})
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -1777,7 +1862,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let thread = cx
|
||||
.spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -1840,11 +1925,7 @@ mod tests {
|
||||
}));
|
||||
|
||||
let thread = cx
|
||||
.spawn(async move |mut cx| {
|
||||
connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx)
|
||||
.await
|
||||
})
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -1952,10 +2033,11 @@ mod tests {
|
||||
}
|
||||
}));
|
||||
|
||||
let thread = connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
||||
let thread = cx
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -2012,8 +2094,8 @@ mod tests {
|
||||
.boxed_local()
|
||||
}
|
||||
}));
|
||||
let thread = connection
|
||||
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
|
||||
let thread = cx
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -2218,7 +2300,7 @@ mod tests {
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
cx: &mut gpui::App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = acp::SessionId(
|
||||
rand::thread_rng()
|
||||
@@ -2228,9 +2310,8 @@ mod tests {
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
);
|
||||
let thread = cx
|
||||
.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
|
||||
.unwrap();
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
@@ -2284,6 +2365,10 @@ mod tests {
|
||||
_session_id: session_id.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct FakeAgentSessionEditor {
|
||||
|
||||
@@ -2,9 +2,9 @@ use crate::AcpThread;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use gpui::{AsyncApp, Entity, SharedString, Task};
|
||||
use gpui::{Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -22,7 +22,7 @@ pub trait AgentConnection {
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
@@ -36,6 +36,14 @@ pub trait AgentConnection {
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn resume(
|
||||
&self,
|
||||
_session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionResume>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
fn session_editor(
|
||||
@@ -53,12 +61,24 @@ pub trait AgentConnection {
|
||||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
|
||||
}
|
||||
|
||||
impl dyn AgentConnection {
|
||||
pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
|
||||
self.into_any().downcast().ok()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentSessionEditor {
|
||||
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
pub trait AgentSessionResume {
|
||||
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
@@ -160,3 +180,159 @@ impl AgentModelList {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
mod test_support {
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::HashMap;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{AppContext as _, WeakEntity};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct StubAgentConnection {
|
||||
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
|
||||
}
|
||||
|
||||
impl StubAgentConnection {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
next_prompt_updates: Default::default(),
|
||||
permission_requests: HashMap::default(),
|
||||
sessions: Arc::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
|
||||
*self.next_prompt_updates.lock() = updates;
|
||||
}
|
||||
|
||||
pub fn with_permission_requests(
|
||||
mut self,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
) -> Self {
|
||||
self.permission_requests = permission_requests;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn send_update(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
update: acp::SessionUpdate,
|
||||
cx: &mut App,
|
||||
) {
|
||||
self.sessions
|
||||
.lock()
|
||||
.get(&session_id)
|
||||
.unwrap()
|
||||
.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for StubAgentConnection {
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::App,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method_id: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<gpui::Result<acp::PromptResponse>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
let mut tasks = vec![];
|
||||
for update in self.next_prompt_updates.lock().drain(..) {
|
||||
let thread = thread.clone();
|
||||
let update = update.clone();
|
||||
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
|
||||
&& let Some(options) = self.permission_requests.get(&tool_call.id)
|
||||
{
|
||||
Some((tool_call.clone(), options.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let task = cx.spawn(async move |cx| {
|
||||
if let Some((tool_call, options)) = permission_request {
|
||||
let permission = thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(
|
||||
tool_call.clone().into(),
|
||||
options.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
permission?.await?;
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
});
|
||||
tasks.push(task);
|
||||
}
|
||||
cx.spawn(async move |_| {
|
||||
try_join_all(tasks).await?;
|
||||
Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
_session_id: &agent_client_protocol::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||
Some(Rc::new(StubAgentSessionEditor))
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct StubAgentSessionEditor;
|
||||
|
||||
impl AgentSessionEditor for StubAgentSessionEditor {
|
||||
fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
pub use test_support::*;
|
||||
|
||||
@@ -1,13 +1,46 @@
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Result, bail};
|
||||
use std::path::PathBuf;
|
||||
use agent::ThreadId;
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use file_icons::FileIcons;
|
||||
use prompt_store::{PromptId, UserPromptId};
|
||||
use std::{
|
||||
fmt,
|
||||
ops::Range,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
};
|
||||
use ui::{App, IconName, SharedString};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum MentionUri {
|
||||
File(PathBuf),
|
||||
Symbol(PathBuf, String),
|
||||
Thread(acp::SessionId),
|
||||
Rule(String),
|
||||
File {
|
||||
abs_path: PathBuf,
|
||||
is_directory: bool,
|
||||
},
|
||||
Symbol {
|
||||
path: PathBuf,
|
||||
name: String,
|
||||
line_range: Range<u32>,
|
||||
},
|
||||
Thread {
|
||||
id: ThreadId,
|
||||
name: String,
|
||||
},
|
||||
TextThread {
|
||||
path: PathBuf,
|
||||
name: String,
|
||||
},
|
||||
Rule {
|
||||
id: PromptId,
|
||||
name: String,
|
||||
},
|
||||
Selection {
|
||||
path: PathBuf,
|
||||
line_range: Range<u32>,
|
||||
},
|
||||
Fetch {
|
||||
url: Url,
|
||||
},
|
||||
}
|
||||
|
||||
impl MentionUri {
|
||||
@@ -17,58 +50,219 @@ impl MentionUri {
|
||||
match url.scheme() {
|
||||
"file" => {
|
||||
if let Some(fragment) = url.fragment() {
|
||||
Ok(Self::Symbol(path.into(), fragment.into()))
|
||||
let range = fragment
|
||||
.strip_prefix("L")
|
||||
.context("Line range must start with \"L\"")?;
|
||||
let (start, end) = range
|
||||
.split_once(":")
|
||||
.context("Line range must use colon as separator")?;
|
||||
let line_range = start
|
||||
.parse::<u32>()
|
||||
.context("Parsing line range start")?
|
||||
.checked_sub(1)
|
||||
.context("Line numbers should be 1-based")?
|
||||
..end
|
||||
.parse::<u32>()
|
||||
.context("Parsing line range end")?
|
||||
.checked_sub(1)
|
||||
.context("Line numbers should be 1-based")?;
|
||||
if let Some(name) = single_query_param(&url, "symbol")? {
|
||||
Ok(Self::Symbol {
|
||||
name,
|
||||
path: path.into(),
|
||||
line_range,
|
||||
})
|
||||
} else {
|
||||
Ok(Self::Selection {
|
||||
path: path.into(),
|
||||
line_range,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
let file_path =
|
||||
PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
|
||||
let is_directory = input.ends_with("/");
|
||||
|
||||
Ok(Self::File(file_path))
|
||||
Ok(Self::File {
|
||||
abs_path: file_path,
|
||||
is_directory,
|
||||
})
|
||||
}
|
||||
}
|
||||
"zed" => {
|
||||
if let Some(thread) = path.strip_prefix("/agent/thread/") {
|
||||
Ok(Self::Thread(acp::SessionId(thread.into())))
|
||||
} else if let Some(rule) = path.strip_prefix("/agent/rule/") {
|
||||
Ok(Self::Rule(rule.into()))
|
||||
if let Some(thread_id) = path.strip_prefix("/agent/thread/") {
|
||||
let name = single_query_param(&url, "name")?.context("Missing thread name")?;
|
||||
Ok(Self::Thread {
|
||||
id: thread_id.into(),
|
||||
name,
|
||||
})
|
||||
} else if let Some(path) = path.strip_prefix("/agent/text-thread/") {
|
||||
let name = single_query_param(&url, "name")?.context("Missing thread name")?;
|
||||
Ok(Self::TextThread {
|
||||
path: path.into(),
|
||||
name,
|
||||
})
|
||||
} else if let Some(rule_id) = path.strip_prefix("/agent/rule/") {
|
||||
let name = single_query_param(&url, "name")?.context("Missing rule name")?;
|
||||
let rule_id = UserPromptId(rule_id.parse()?);
|
||||
Ok(Self::Rule {
|
||||
id: rule_id.into(),
|
||||
name,
|
||||
})
|
||||
} else {
|
||||
bail!("invalid zed url: {:?}", input);
|
||||
}
|
||||
}
|
||||
"http" | "https" => Ok(MentionUri::Fetch { url }),
|
||||
other => bail!("unrecognized scheme {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
MentionUri::File(path) => path.file_name().unwrap().to_string_lossy().into_owned(),
|
||||
MentionUri::Symbol(_path, name) => name.clone(),
|
||||
MentionUri::Thread(thread) => thread.to_string(),
|
||||
MentionUri::Rule(rule) => rule.clone(),
|
||||
MentionUri::File { abs_path, .. } => abs_path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.into_owned(),
|
||||
MentionUri::Symbol { name, .. } => name.clone(),
|
||||
MentionUri::Thread { name, .. } => name.clone(),
|
||||
MentionUri::TextThread { name, .. } => name.clone(),
|
||||
MentionUri::Rule { name, .. } => name.clone(),
|
||||
MentionUri::Selection {
|
||||
path, line_range, ..
|
||||
} => selection_name(path, line_range),
|
||||
MentionUri::Fetch { url } => url.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_link(&self) -> String {
|
||||
let name = self.name();
|
||||
let uri = self.to_uri();
|
||||
format!("[{name}]({uri})")
|
||||
}
|
||||
|
||||
pub fn to_uri(&self) -> String {
|
||||
pub fn icon_path(&self, cx: &mut App) -> SharedString {
|
||||
match self {
|
||||
MentionUri::File(path) => {
|
||||
format!("file://{}", path.display())
|
||||
}
|
||||
MentionUri::Symbol(path, name) => {
|
||||
format!("file://{}#{}", path.display(), name)
|
||||
}
|
||||
MentionUri::Thread(thread) => {
|
||||
format!("zed:///agent/thread/{}", thread.0)
|
||||
}
|
||||
MentionUri::Rule(rule) => {
|
||||
format!("zed:///agent/rule/{}", rule)
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
if *is_directory {
|
||||
FileIcons::get_folder_icon(false, cx)
|
||||
.unwrap_or_else(|| IconName::Folder.path().into())
|
||||
} else {
|
||||
FileIcons::get_icon(&abs_path, cx)
|
||||
.unwrap_or_else(|| IconName::File.path().into())
|
||||
}
|
||||
}
|
||||
MentionUri::Symbol { .. } => IconName::Code.path().into(),
|
||||
MentionUri::Thread { .. } => IconName::Thread.path().into(),
|
||||
MentionUri::TextThread { .. } => IconName::Thread.path().into(),
|
||||
MentionUri::Rule { .. } => IconName::Reader.path().into(),
|
||||
MentionUri::Selection { .. } => IconName::Reader.path().into(),
|
||||
MentionUri::Fetch { .. } => IconName::ToolWeb.path().into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_link<'a>(&'a self) -> MentionLink<'a> {
|
||||
MentionLink(self)
|
||||
}
|
||||
|
||||
pub fn to_uri(&self) -> Url {
|
||||
match self {
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
let mut url = Url::parse("file:///").unwrap();
|
||||
let mut path = abs_path.to_string_lossy().to_string();
|
||||
if *is_directory && !path.ends_with("/") {
|
||||
path.push_str("/");
|
||||
}
|
||||
url.set_path(&path);
|
||||
url
|
||||
}
|
||||
MentionUri::Symbol {
|
||||
path,
|
||||
name,
|
||||
line_range,
|
||||
} => {
|
||||
let mut url = Url::parse("file:///").unwrap();
|
||||
url.set_path(&path.to_string_lossy());
|
||||
url.query_pairs_mut().append_pair("symbol", name);
|
||||
url.set_fragment(Some(&format!(
|
||||
"L{}:{}",
|
||||
line_range.start + 1,
|
||||
line_range.end + 1
|
||||
)));
|
||||
url
|
||||
}
|
||||
MentionUri::Selection { path, line_range } => {
|
||||
let mut url = Url::parse("file:///").unwrap();
|
||||
url.set_path(&path.to_string_lossy());
|
||||
url.set_fragment(Some(&format!(
|
||||
"L{}:{}",
|
||||
line_range.start + 1,
|
||||
line_range.end + 1
|
||||
)));
|
||||
url
|
||||
}
|
||||
MentionUri::Thread { name, id } => {
|
||||
let mut url = Url::parse("zed:///").unwrap();
|
||||
url.set_path(&format!("/agent/thread/{id}"));
|
||||
url.query_pairs_mut().append_pair("name", name);
|
||||
url
|
||||
}
|
||||
MentionUri::TextThread { path, name } => {
|
||||
let mut url = Url::parse("zed:///").unwrap();
|
||||
url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy()));
|
||||
url.query_pairs_mut().append_pair("name", name);
|
||||
url
|
||||
}
|
||||
MentionUri::Rule { name, id } => {
|
||||
let mut url = Url::parse("zed:///").unwrap();
|
||||
url.set_path(&format!("/agent/rule/{id}"));
|
||||
url.query_pairs_mut().append_pair("name", name);
|
||||
url
|
||||
}
|
||||
MentionUri::Fetch { url } => url.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for MentionUri {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> anyhow::Result<Self> {
|
||||
Self::parse(s)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MentionLink<'a>(&'a MentionUri);
|
||||
|
||||
impl fmt::Display for MentionLink<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "[@{}]({})", self.0.name(), self.0.to_uri())
|
||||
}
|
||||
}
|
||||
|
||||
fn single_query_param(url: &Url, name: &'static str) -> Result<Option<String>> {
|
||||
let pairs = url.query_pairs().collect::<Vec<_>>();
|
||||
match pairs.as_slice() {
|
||||
[] => Ok(None),
|
||||
[(k, v)] => {
|
||||
if k != name {
|
||||
bail!("invalid query parameter")
|
||||
}
|
||||
|
||||
Ok(Some(v.to_string()))
|
||||
}
|
||||
_ => bail!("too many query pairs"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn selection_name(path: &Path, line_range: &Range<u32>) -> String {
|
||||
format!(
|
||||
"{} ({}:{})",
|
||||
path.file_name().unwrap_or_default().display(),
|
||||
line_range.start + 1,
|
||||
line_range.end + 1
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -76,50 +270,191 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mention_uri_parse_and_display() {
|
||||
// Test file URI
|
||||
fn test_parse_file_uri() {
|
||||
let file_uri = "file:///path/to/file.rs";
|
||||
let parsed = MentionUri::parse(file_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"),
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
assert_eq!(abs_path.to_str().unwrap(), "/path/to/file.rs");
|
||||
assert!(!is_directory);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), file_uri);
|
||||
assert_eq!(parsed.to_uri().to_string(), file_uri);
|
||||
}
|
||||
|
||||
// Test symbol URI
|
||||
let symbol_uri = "file:///path/to/file.rs#MySymbol";
|
||||
#[test]
|
||||
fn test_parse_directory_uri() {
|
||||
let file_uri = "file:///path/to/dir/";
|
||||
let parsed = MentionUri::parse(file_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::File {
|
||||
abs_path,
|
||||
is_directory,
|
||||
} => {
|
||||
assert_eq!(abs_path.to_str().unwrap(), "/path/to/dir/");
|
||||
assert!(is_directory);
|
||||
}
|
||||
_ => panic!("Expected File variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), file_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_directory_uri_with_slash() {
|
||||
let uri = MentionUri::File {
|
||||
abs_path: PathBuf::from("/path/to/dir/"),
|
||||
is_directory: true,
|
||||
};
|
||||
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_directory_uri_without_slash() {
|
||||
let uri = MentionUri::File {
|
||||
abs_path: PathBuf::from("/path/to/dir"),
|
||||
is_directory: true,
|
||||
};
|
||||
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_symbol_uri() {
|
||||
let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20";
|
||||
let parsed = MentionUri::parse(symbol_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Symbol(path, symbol) => {
|
||||
MentionUri::Symbol {
|
||||
path,
|
||||
name,
|
||||
line_range,
|
||||
} => {
|
||||
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
|
||||
assert_eq!(symbol, "MySymbol");
|
||||
assert_eq!(name, "MySymbol");
|
||||
assert_eq!(line_range.start, 9);
|
||||
assert_eq!(line_range.end, 19);
|
||||
}
|
||||
_ => panic!("Expected Symbol variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), symbol_uri);
|
||||
assert_eq!(parsed.to_uri().to_string(), symbol_uri);
|
||||
}
|
||||
|
||||
// Test thread URI
|
||||
let thread_uri = "zed:///agent/thread/session123";
|
||||
#[test]
|
||||
fn test_parse_selection_uri() {
|
||||
let selection_uri = "file:///path/to/file.rs#L5:15";
|
||||
let parsed = MentionUri::parse(selection_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Selection { path, line_range } => {
|
||||
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
|
||||
assert_eq!(line_range.start, 4);
|
||||
assert_eq!(line_range.end, 14);
|
||||
}
|
||||
_ => panic!("Expected Selection variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), selection_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_thread_uri() {
|
||||
let thread_uri = "zed:///agent/thread/session123?name=Thread+name";
|
||||
let parsed = MentionUri::parse(thread_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Thread(session_id) => assert_eq!(session_id.0.as_ref(), "session123"),
|
||||
MentionUri::Thread {
|
||||
id: thread_id,
|
||||
name,
|
||||
} => {
|
||||
assert_eq!(thread_id.to_string(), "session123");
|
||||
assert_eq!(name, "Thread name");
|
||||
}
|
||||
_ => panic!("Expected Thread variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), thread_uri);
|
||||
assert_eq!(parsed.to_uri().to_string(), thread_uri);
|
||||
}
|
||||
|
||||
// Test rule URI
|
||||
let rule_uri = "zed:///agent/rule/my_rule";
|
||||
#[test]
|
||||
fn test_parse_rule_uri() {
|
||||
let rule_uri = "zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52?name=Some+rule";
|
||||
let parsed = MentionUri::parse(rule_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Rule(rule) => assert_eq!(rule, "my_rule"),
|
||||
MentionUri::Rule { id, name } => {
|
||||
assert_eq!(id.to_string(), "d8694ff2-90d5-4b6f-be33-33c1763acd52");
|
||||
assert_eq!(name, "Some rule");
|
||||
}
|
||||
_ => panic!("Expected Rule variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), rule_uri);
|
||||
assert_eq!(parsed.to_uri().to_string(), rule_uri);
|
||||
}
|
||||
|
||||
// Test invalid scheme
|
||||
assert!(MentionUri::parse("http://example.com").is_err());
|
||||
#[test]
|
||||
fn test_parse_fetch_http_uri() {
|
||||
let http_uri = "http://example.com/path?query=value#fragment";
|
||||
let parsed = MentionUri::parse(http_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Fetch { url } => {
|
||||
assert_eq!(url.to_string(), http_uri);
|
||||
}
|
||||
_ => panic!("Expected Fetch variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), http_uri);
|
||||
}
|
||||
|
||||
// Test invalid zed path
|
||||
#[test]
|
||||
fn test_parse_fetch_https_uri() {
|
||||
let https_uri = "https://example.com/api/endpoint";
|
||||
let parsed = MentionUri::parse(https_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Fetch { url } => {
|
||||
assert_eq!(url.to_string(), https_uri);
|
||||
}
|
||||
_ => panic!("Expected Fetch variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri().to_string(), https_uri);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_scheme() {
|
||||
assert!(MentionUri::parse("ftp://example.com").is_err());
|
||||
assert!(MentionUri::parse("ssh://example.com").is_err());
|
||||
assert!(MentionUri::parse("unknown://example.com").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_zed_path() {
|
||||
assert!(MentionUri::parse("zed:///invalid/path").is_err());
|
||||
assert!(MentionUri::parse("zed:///agent/unknown/test").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_line_range_format() {
|
||||
// Missing L prefix
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err());
|
||||
|
||||
// Missing colon separator
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err());
|
||||
|
||||
// Invalid numbers
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err());
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_query_parameters() {
|
||||
// Invalid query parameter name
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err());
|
||||
|
||||
// Too many query parameters
|
||||
assert!(
|
||||
MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_based_line_numbers() {
|
||||
// Test that 0-based line numbers are rejected (should be 1-based)
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err());
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err());
|
||||
assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -844,11 +844,17 @@ impl Thread {
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !equal {
|
||||
this.update(cx, |this, cx| {
|
||||
this.insert_checkpoint(pending_checkpoint, cx)
|
||||
})?;
|
||||
}
|
||||
this.update(cx, |this, cx| {
|
||||
this.pending_checkpoint = if equal {
|
||||
Some(pending_checkpoint)
|
||||
} else {
|
||||
this.insert_checkpoint(pending_checkpoint, cx);
|
||||
Some(ThreadCheckpoint {
|
||||
message_id: this.next_message_id,
|
||||
git_checkpoint: final_checkpoint,
|
||||
})
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -205,6 +205,22 @@ impl ThreadStore {
|
||||
(this, ready_rx)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
|
||||
Self {
|
||||
project,
|
||||
tools: cx.new(|_| ToolWorkingSet::default()),
|
||||
prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
prompt_store: None,
|
||||
context_server_tool_ids: HashMap::default(),
|
||||
threads: Vec::new(),
|
||||
project_context: SharedProjectContext::default(),
|
||||
reload_system_prompt_tx: mpsc::channel(0).0,
|
||||
_reload_system_prompt_task: Task::ready(()),
|
||||
_subscriptions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
&mut self,
|
||||
_project: Entity<Project>,
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||
use crate::{
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
||||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
|
||||
ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
|
||||
WebSearchTool,
|
||||
AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
|
||||
DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
|
||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
|
||||
ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
|
||||
};
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
@@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
@@ -21,6 +21,7 @@ use prompt_store::{
|
||||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use std::any::Any;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
@@ -426,9 +427,9 @@ impl NativeAgent {
|
||||
self.models.refresh_list(cx);
|
||||
for session in self.sessions.values_mut() {
|
||||
session.thread.update(cx, |thread, _| {
|
||||
let model_id = LanguageModels::model_id(&thread.selected_model);
|
||||
let model_id = LanguageModels::model_id(&thread.model());
|
||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||
thread.selected_model = model.clone();
|
||||
thread.set_model(model.clone());
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -439,6 +440,125 @@ impl NativeAgent {
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl NativeAgentConnection {
|
||||
pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
|
||||
self.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
}
|
||||
|
||||
fn run_turn(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
cx: &mut App,
|
||||
f: impl 'static
|
||||
+ FnOnce(
|
||||
Entity<Thread>,
|
||||
&mut App,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
|
||||
agent
|
||||
.sessions
|
||||
.get_mut(&session_id)
|
||||
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
||||
}) else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
let mut response_stream = match f(thread, cx) {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
};
|
||||
cx.spawn(async move |cx| {
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
AgentResponseEvent::Text(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Thinking(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
options,
|
||||
response,
|
||||
}) => {
|
||||
let recv = acp_thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, options, cx)
|
||||
})?;
|
||||
cx.background_spawn(async move {
|
||||
if let Some(recv) = recv.log_err()
|
||||
&& let Some(option) = recv
|
||||
.await
|
||||
.context("authorization sender was dropped")
|
||||
.log_err()
|
||||
{
|
||||
response
|
||||
.send(option)
|
||||
.map(|_| anyhow!("authorization receiver was dropped"))
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.upsert_tool_call(tool_call, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error in model response stream: {:?}", e);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Response stream completed");
|
||||
anyhow::Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
@@ -472,7 +592,7 @@ impl AgentModelSelector for NativeAgentConnection {
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model.clone();
|
||||
thread.set_model(model.clone());
|
||||
});
|
||||
|
||||
update_settings_file::<AgentSettings>(
|
||||
@@ -502,7 +622,7 @@ impl AgentModelSelector for NativeAgentConnection {
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
let model = thread.read(cx).selected_model.clone();
|
||||
let model = thread.read(cx).model().clone();
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Provider not found")));
|
||||
@@ -522,7 +642,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||
let agent = self.0.clone();
|
||||
log::info!("Creating new thread for project at: {:?}", cwd);
|
||||
@@ -583,22 +703,22 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
default_model,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(MovePathTool::new(project.clone()));
|
||||
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(OpenTool::new(project.clone()));
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(MovePathTool::new(project.clone()));
|
||||
thread.add_tool(NowTool);
|
||||
thread.add_tool(OpenTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread.add_tool(TerminalTool::new(project.clone(), cx));
|
||||
// TODO: Needs to be conditional based on zed model or not
|
||||
thread.add_tool(WebSearchTool);
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
|
||||
thread
|
||||
});
|
||||
|
||||
@@ -644,25 +764,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let id = id.expect("UserMessageId is required");
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Get session
|
||||
let (thread, acp_thread) = agent
|
||||
.update(cx, |agent, _| {
|
||||
agent
|
||||
.sessions
|
||||
.get_mut(&session_id)
|
||||
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
log::error!("Session not found: {}", session_id);
|
||||
anyhow::anyhow!("Session not found")
|
||||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
self.run_turn(session_id, cx, |thread, cx| {
|
||||
let content: Vec<UserMessageContent> = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
@@ -672,99 +777,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
log::debug!("Message id: {:?}", id);
|
||||
log::debug!("Message content: {:?}", content);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let mut response_stream =
|
||||
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
AgentResponseEvent::Text(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
false,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::Thinking(text) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.push_assistant_content_block(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text,
|
||||
annotations: None,
|
||||
}),
|
||||
true,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
|
||||
tool_call,
|
||||
options,
|
||||
response,
|
||||
}) => {
|
||||
let recv = acp_thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, options, cx)
|
||||
})?;
|
||||
cx.background_spawn(async move {
|
||||
if let Some(option) = recv
|
||||
.await
|
||||
.context("authorization sender was dropped")
|
||||
.log_err()
|
||||
{
|
||||
response
|
||||
.send(option)
|
||||
.map(|_| anyhow!("authorization receiver was dropped"))
|
||||
.log_err();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
AgentResponseEvent::ToolCall(tool_call) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.upsert_tool_call(tool_call, cx)
|
||||
})?;
|
||||
}
|
||||
AgentResponseEvent::ToolCallUpdate(update) => {
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.update_tool_call(update, cx)
|
||||
})??;
|
||||
}
|
||||
AgentResponseEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
return Ok(acp::PromptResponse { stop_reason });
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error in model response stream: {:?}", e);
|
||||
// TODO: Consider sending an error message to the UI
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Response stream completed");
|
||||
anyhow::Ok(acp::PromptResponse {
|
||||
stop_reason: acp::StopReason::EndTurn,
|
||||
})
|
||||
Ok(thread.update(cx, |thread, cx| {
|
||||
log::info!(
|
||||
"Sending message to thread with model: {:?}",
|
||||
thread.model().name()
|
||||
);
|
||||
thread.send(id, content, cx)
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
fn resume(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
|
||||
Some(Rc::new(NativeAgentSessionResume {
|
||||
connection: self.clone(),
|
||||
session_id: session_id.clone(),
|
||||
}) as _)
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
log::info!("Cancelling on session: {}", session_id);
|
||||
self.0.update(cx, |agent, cx| {
|
||||
@@ -786,6 +819,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
||||
})
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct NativeAgentSessionEditor(Entity<Thread>);
|
||||
@@ -796,6 +833,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||
}
|
||||
}
|
||||
|
||||
struct NativeAgentSessionResume {
|
||||
connection: NativeAgentConnection,
|
||||
session_id: acp::SessionId,
|
||||
}
|
||||
|
||||
impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
|
||||
fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
|
||||
self.connection
|
||||
.run_turn(self.session_id.clone(), cx, |thread, cx| {
|
||||
thread.update(cx, |thread, cx| thread.resume(cx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -940,11 +991,7 @@ mod tests {
|
||||
// Create a thread/session
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
Rc::new(connection.clone()).new_thread(
|
||||
project.clone(),
|
||||
Path::new("/a"),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -961,7 +1008,7 @@ mod tests {
|
||||
agent.read_with(cx, |agent, _| {
|
||||
let session = agent.sessions.get(&session_id).unwrap();
|
||||
session.thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.selected_model.id().0, "fake");
|
||||
assert_eq!(thread.model().id().0, "fake");
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -12,10 +12,11 @@ use gpui::{
|
||||
};
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
|
||||
fake_provider::FakeLanguageModel,
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
|
||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
|
||||
Role, StopReason, fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use reqwest_client::ReqwestClient;
|
||||
@@ -129,6 +130,134 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_prompt_caching(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
// Send initial user message and verify it's cached
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Message 1"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
completion.messages[1..],
|
||||
vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Message 1".into()],
|
||||
cache: true
|
||||
}]
|
||||
);
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
|
||||
"Response to Message 1".into(),
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Send another user message and verify only the latest is cached
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Message 2"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
completion.messages[1..],
|
||||
vec![
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Message 1".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec!["Response to Message 1".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Message 2".into()],
|
||||
cache: true
|
||||
}
|
||||
]
|
||||
);
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
|
||||
"Response to Message 2".into(),
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Simulate a tool call and verify that the latest tool result is cached
|
||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: "tool_1".into(),
|
||||
name: EchoTool.name().into(),
|
||||
raw_input: json!({"text": "test"}).to_string(),
|
||||
input: json!({"text": "test"}),
|
||||
is_input_complete: true,
|
||||
};
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
let tool_result = LanguageModelToolResult {
|
||||
tool_use_id: "tool_1".into(),
|
||||
tool_name: EchoTool.name().into(),
|
||||
is_error: false,
|
||||
content: "test".into(),
|
||||
output: Some("test".into()),
|
||||
};
|
||||
assert_eq!(
|
||||
completion.messages[1..],
|
||||
vec![
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Message 1".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec!["Response to Message 1".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Message 2".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec!["Response to Message 2".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Use the echo tool".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use)],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::ToolResult(tool_result)],
|
||||
cache: true
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
@@ -394,8 +523,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: "tool_id_1".into(),
|
||||
name: EchoTool.name().into(),
|
||||
raw_input: "{}".into(),
|
||||
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
|
||||
is_input_complete: true,
|
||||
};
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
cx.run_until_parked();
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
let tool_result = LanguageModelToolResult {
|
||||
tool_use_id: "tool_id_1".into(),
|
||||
tool_name: EchoTool.name().into(),
|
||||
is_error: false,
|
||||
content: "def".into(),
|
||||
output: Some("def".into()),
|
||||
};
|
||||
assert_eq!(
|
||||
completion.messages[1..],
|
||||
vec![
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["abc".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use.clone())],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::ToolResult(tool_result.clone())],
|
||||
cache: true
|
||||
},
|
||||
]
|
||||
);
|
||||
|
||||
// Simulate reaching tool use limit.
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
||||
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||
assert!(
|
||||
last_event
|
||||
.unwrap_err()
|
||||
.is::<language_model::ToolUseLimitReachedError>()
|
||||
);
|
||||
|
||||
let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
|
||||
cx.run_until_parked();
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
completion.messages[1..],
|
||||
vec![
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["abc".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use)],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::ToolResult(tool_result)],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Continue where you left off".into()],
|
||||
cache: true
|
||||
}
|
||||
]
|
||||
);
|
||||
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
|
||||
fake_model.end_last_completion_stream();
|
||||
events.collect::<Vec<_>>().await;
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.last_message().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## Assistant
|
||||
|
||||
Done
|
||||
"}
|
||||
)
|
||||
});
|
||||
|
||||
// Ensure we error if calling resume when tool use limit was *not* reached.
|
||||
let error = thread
|
||||
.update(cx, |thread, cx| thread.resume(cx))
|
||||
.unwrap_err();
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"can only resume after tool use limit is reached"
|
||||
)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let tool_use = LanguageModelToolUse {
|
||||
id: "tool_id_1".into(),
|
||||
name: EchoTool.name().into(),
|
||||
raw_input: "{}".into(),
|
||||
input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
|
||||
is_input_complete: true,
|
||||
};
|
||||
let tool_result = LanguageModelToolResult {
|
||||
tool_use_id: "tool_id_1".into(),
|
||||
tool_name: EchoTool.name().into(),
|
||||
is_error: false,
|
||||
content: "def".into(),
|
||||
output: Some("def".into()),
|
||||
};
|
||||
fake_model
|
||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
||||
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
||||
));
|
||||
fake_model.end_last_completion_stream();
|
||||
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||
assert!(
|
||||
last_event
|
||||
.unwrap_err()
|
||||
.is::<language_model::ToolUseLimitReachedError>()
|
||||
);
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), vec!["ghi"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let completion = fake_model.pending_completions().pop().unwrap();
|
||||
assert_eq!(
|
||||
completion.messages[1..],
|
||||
vec![
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["abc".into()],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::ToolUse(tool_use)],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::ToolResult(tool_result)],
|
||||
cache: false
|
||||
},
|
||||
LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["ghi".into()],
|
||||
cache: true
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
async fn expect_tool_call(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||
) -> acp::ToolCall {
|
||||
let event = events
|
||||
.next()
|
||||
@@ -411,7 +726,7 @@ async fn expect_tool_call(
|
||||
}
|
||||
|
||||
async fn expect_tool_call_update_fields(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||
) -> acp::ToolCallUpdate {
|
||||
let event = events
|
||||
.next()
|
||||
@@ -429,7 +744,7 @@ async fn expect_tool_call_update_fields(
|
||||
}
|
||||
|
||||
async fn next_tool_call_authorization(
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
|
||||
) -> ToolCallAuthorization {
|
||||
loop {
|
||||
let event = events
|
||||
@@ -841,7 +1156,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
// Create a thread using new_thread
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
|
||||
.update(|cx| connection_rc.new_thread(project, cwd, cx))
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
@@ -1007,9 +1322,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> Vec<acp::StopReason> {
|
||||
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::future;
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct EchoToolInput {
|
||||
/// The text to echo.
|
||||
text: String,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub struct EchoTool;
|
||||
|
||||
@@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
|
||||
use acp_thread::{MentionUri, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::{AgentProfileId, AgentSettings};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::adapt_schema_to_format;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
|
||||
use collections::IndexMap;
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
@@ -14,10 +14,10 @@ use futures::{
|
||||
};
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
|
||||
LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
@@ -25,14 +25,57 @@ use schemars::{JsonSchema, Schema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use smol::stream::StreamExt;
|
||||
use std::fmt::Write;
|
||||
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
|
||||
use std::{fmt::Write, ops::Range};
|
||||
use util::{ResultExt, markdown::MarkdownCodeBlock};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(
|
||||
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
|
||||
)]
|
||||
pub struct ThreadId(Arc<str>);
|
||||
|
||||
impl ThreadId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ThreadId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ThreadId {
|
||||
fn from(value: &str) -> Self {
|
||||
Self(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
/// The ID of the user prompt that initiated a request.
|
||||
///
|
||||
/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key).
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptId(Arc<str>);
|
||||
|
||||
impl PromptId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PromptId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Message {
|
||||
User(UserMessage),
|
||||
Agent(AgentMessage),
|
||||
Resume,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
@@ -47,6 +90,7 @@ impl Message {
|
||||
match self {
|
||||
Message::User(message) => message.to_markdown(),
|
||||
Message::Agent(message) => message.to_markdown(),
|
||||
Message::Resume => "[resumed after tool use limit was reached]".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -79,9 +123,9 @@ impl UserMessage {
|
||||
}
|
||||
UserMessageContent::Mention { uri, content } => {
|
||||
if !content.is_empty() {
|
||||
markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content));
|
||||
let _ = write!(&mut markdown, "{}\n\n{}\n", uri.as_link(), content);
|
||||
} else {
|
||||
markdown.push_str(&format!("{}\n", uri.to_link()));
|
||||
let _ = write!(&mut markdown, "{}\n", uri.as_link());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,12 +148,14 @@ impl UserMessage {
|
||||
const OPEN_FILES_TAG: &str = "<files>";
|
||||
const OPEN_SYMBOLS_TAG: &str = "<symbols>";
|
||||
const OPEN_THREADS_TAG: &str = "<threads>";
|
||||
const OPEN_FETCH_TAG: &str = "<fetched_urls>";
|
||||
const OPEN_RULES_TAG: &str =
|
||||
"<rules>\nThe user has specified the following rules that should be applied:\n";
|
||||
|
||||
let mut file_context = OPEN_FILES_TAG.to_string();
|
||||
let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
|
||||
let mut thread_context = OPEN_THREADS_TAG.to_string();
|
||||
let mut fetch_context = OPEN_FETCH_TAG.to_string();
|
||||
let mut rules_context = OPEN_RULES_TAG.to_string();
|
||||
|
||||
for chunk in &self.content {
|
||||
@@ -122,21 +168,40 @@ impl UserMessage {
|
||||
}
|
||||
UserMessageContent::Mention { uri, content } => {
|
||||
match uri {
|
||||
MentionUri::File(path) | MentionUri::Symbol(path, _) => {
|
||||
MentionUri::File { abs_path, .. } => {
|
||||
write!(
|
||||
&mut symbol_context,
|
||||
"\n{}",
|
||||
MarkdownCodeBlock {
|
||||
tag: &codeblock_tag(&path),
|
||||
tag: &codeblock_tag(&abs_path, None),
|
||||
text: &content.to_string(),
|
||||
}
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
MentionUri::Thread(_session_id) => {
|
||||
MentionUri::Symbol {
|
||||
path, line_range, ..
|
||||
}
|
||||
| MentionUri::Selection {
|
||||
path, line_range, ..
|
||||
} => {
|
||||
write!(
|
||||
&mut rules_context,
|
||||
"\n{}",
|
||||
MarkdownCodeBlock {
|
||||
tag: &codeblock_tag(&path, Some(line_range)),
|
||||
text: &content
|
||||
}
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
MentionUri::Thread { .. } => {
|
||||
write!(&mut thread_context, "\n{}\n", content).ok();
|
||||
}
|
||||
MentionUri::Rule(_user_prompt_id) => {
|
||||
MentionUri::TextThread { .. } => {
|
||||
write!(&mut thread_context, "\n{}\n", content).ok();
|
||||
}
|
||||
MentionUri::Rule { .. } => {
|
||||
write!(
|
||||
&mut rules_context,
|
||||
"\n{}",
|
||||
@@ -147,9 +212,12 @@ impl UserMessage {
|
||||
)
|
||||
.ok();
|
||||
}
|
||||
MentionUri::Fetch { url } => {
|
||||
write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok();
|
||||
}
|
||||
}
|
||||
|
||||
language_model::MessageContent::Text(uri.to_link())
|
||||
language_model::MessageContent::Text(uri.as_link().to_string())
|
||||
}
|
||||
};
|
||||
|
||||
@@ -179,6 +247,13 @@ impl UserMessage {
|
||||
.push(language_model::MessageContent::Text(thread_context));
|
||||
}
|
||||
|
||||
if fetch_context.len() > OPEN_FETCH_TAG.len() {
|
||||
fetch_context.push_str("</fetched_urls>\n");
|
||||
message
|
||||
.content
|
||||
.push(language_model::MessageContent::Text(fetch_context));
|
||||
}
|
||||
|
||||
if rules_context.len() > OPEN_RULES_TAG.len() {
|
||||
rules_context.push_str("</user_rules>\n");
|
||||
message
|
||||
@@ -200,6 +275,26 @@ impl UserMessage {
|
||||
}
|
||||
}
|
||||
|
||||
fn codeblock_tag(full_path: &Path, line_range: Option<&Range<u32>>) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
|
||||
let _ = write!(result, "{} ", extension);
|
||||
}
|
||||
|
||||
let _ = write!(result, "{}", full_path.display());
|
||||
|
||||
if let Some(range) = line_range {
|
||||
if range.start == range.end {
|
||||
let _ = write!(result, ":{}", range.start + 1);
|
||||
} else {
|
||||
let _ = write!(result, ":{}-{}", range.start + 1, range.end + 1);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
impl AgentMessage {
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut markdown = String::from("## Assistant\n\n");
|
||||
@@ -269,7 +364,11 @@ impl AgentMessage {
|
||||
}
|
||||
|
||||
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
let mut content = Vec::with_capacity(self.content.len());
|
||||
let mut assistant_message = LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::with_capacity(self.content.len()),
|
||||
cache: false,
|
||||
};
|
||||
for chunk in &self.content {
|
||||
let chunk = match chunk {
|
||||
AgentMessageContent::Text(text) => {
|
||||
@@ -291,29 +390,30 @@ impl AgentMessage {
|
||||
language_model::MessageContent::Image(value.clone())
|
||||
}
|
||||
};
|
||||
content.push(chunk);
|
||||
assistant_message.content.push(chunk);
|
||||
}
|
||||
|
||||
let mut messages = vec![LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
let mut user_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
}];
|
||||
};
|
||||
|
||||
if !self.tool_results.is_empty() {
|
||||
let mut tool_results = Vec::with_capacity(self.tool_results.len());
|
||||
for tool_result in self.tool_results.values() {
|
||||
tool_results.push(language_model::MessageContent::ToolResult(
|
||||
for tool_result in self.tool_results.values() {
|
||||
user_message
|
||||
.content
|
||||
.push(language_model::MessageContent::ToolResult(
|
||||
tool_result.clone(),
|
||||
));
|
||||
}
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: tool_results,
|
||||
cache: false,
|
||||
});
|
||||
}
|
||||
|
||||
let mut messages = Vec::new();
|
||||
if !assistant_message.content.is_empty() {
|
||||
messages.push(assistant_message);
|
||||
}
|
||||
if !user_message.content.is_empty() {
|
||||
messages.push(user_message);
|
||||
}
|
||||
messages
|
||||
}
|
||||
}
|
||||
@@ -348,25 +448,28 @@ pub enum AgentResponseEvent {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolCallAuthorization {
|
||||
pub tool_call: acp::ToolCall,
|
||||
pub tool_call: acp::ToolCallUpdate,
|
||||
pub options: Vec<acp::PermissionOption>,
|
||||
pub response: oneshot::Sender<acp::PermissionOptionId>,
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
id: ThreadId,
|
||||
prompt_id: PromptId,
|
||||
messages: Vec<Message>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
pending_agent_message: Option<AgentMessage>,
|
||||
pending_message: Option<AgentMessage>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
tool_use_limit_reached: bool,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
profile_id: AgentProfileId,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
}
|
||||
@@ -378,21 +481,24 @@ impl Thread {
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
default_model: Arc<dyn LanguageModel>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
prompt_id: PromptId::new(),
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
running_turn: None,
|
||||
pending_agent_message: None,
|
||||
pending_message: None,
|
||||
tools: BTreeMap::default(),
|
||||
tool_use_limit_reached: false,
|
||||
context_server_registry,
|
||||
profile_id,
|
||||
project_context,
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
model,
|
||||
project,
|
||||
action_log,
|
||||
}
|
||||
@@ -406,13 +512,25 @@ impl Thread {
|
||||
&self.action_log
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
pub fn model(&self) -> &Arc<dyn LanguageModel> {
|
||||
&self.model
|
||||
}
|
||||
|
||||
pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
pub fn completion_mode(&self) -> CompletionMode {
|
||||
self.completion_mode
|
||||
}
|
||||
|
||||
pub fn set_completion_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn last_message(&self) -> Option<Message> {
|
||||
if let Some(message) = self.pending_agent_message.clone() {
|
||||
if let Some(message) = self.pending_message.clone() {
|
||||
Some(Message::Agent(message))
|
||||
} else {
|
||||
self.messages.last().cloned()
|
||||
@@ -427,6 +545,10 @@ impl Thread {
|
||||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
pub fn profile(&self) -> &AgentProfileId {
|
||||
&self.profile_id
|
||||
}
|
||||
|
||||
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
|
||||
self.profile_id = profile_id;
|
||||
}
|
||||
@@ -434,7 +556,7 @@ impl Thread {
|
||||
pub fn cancel(&mut self) {
|
||||
// TODO: do we need to emit a stop::cancel for ACP?
|
||||
self.running_turn.take();
|
||||
self.flush_pending_agent_message();
|
||||
self.flush_pending_message();
|
||||
}
|
||||
|
||||
pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
|
||||
@@ -448,96 +570,108 @@ impl Thread {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn resume(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
|
||||
anyhow::ensure!(
|
||||
self.tool_use_limit_reached,
|
||||
"can only resume after tool use limit is reached"
|
||||
);
|
||||
|
||||
self.messages.push(Message::Resume);
|
||||
cx.notify();
|
||||
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
Ok(self.run_turn(cx))
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send<T>(
|
||||
&mut self,
|
||||
message_id: UserMessageId,
|
||||
id: UserMessageId,
|
||||
content: impl IntoIterator<Item = T>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
|
||||
where
|
||||
T: Into<UserMessageContent>,
|
||||
{
|
||||
let model = self.selected_model.clone();
|
||||
log::info!("Thread::send called with model: {:?}", self.model.name());
|
||||
self.advance_prompt_id();
|
||||
|
||||
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
|
||||
log::info!("Thread::send called with model: {:?}", model.name());
|
||||
log::debug!("Thread::send content: {:?}", content);
|
||||
|
||||
self.messages
|
||||
.push(Message::User(UserMessage { id, content }));
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
let event_stream = AgentResponseEventStream(events_tx);
|
||||
|
||||
let user_message_ix = self.messages.len();
|
||||
self.messages.push(Message::User(UserMessage {
|
||||
id: message_id,
|
||||
content,
|
||||
}));
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
self.run_turn(cx)
|
||||
}
|
||||
|
||||
fn run_turn(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
|
||||
let model = self.model.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||
let event_stream = AgentResponseEventStream(events_tx);
|
||||
let message_ix = self.messages.len().saturating_sub(1);
|
||||
self.tool_use_limit_reached = false;
|
||||
self.running_turn = Some(cx.spawn(async move |this, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let turn_result = async {
|
||||
// Perform one request, then keep looping if the model makes tool calls.
|
||||
let turn_result: Result<()> = async {
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
'outer: loop {
|
||||
loop {
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(completion_intent, cx)
|
||||
let request = this.update(cx, |this, cx| {
|
||||
this.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
|
||||
let mut tool_use_limit_reached = false;
|
||||
let mut tool_uses = FuturesUnordered::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(LanguageModelCompletionEvent::Stop(reason)) => {
|
||||
match event? {
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
CompletionRequestStatus::ToolUseLimitReached,
|
||||
) => {
|
||||
tool_use_limit_reached = true;
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
event_stream.send_stop(reason);
|
||||
if reason == StopReason::Refusal {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.pending_agent_message = None;
|
||||
thread.messages.truncate(user_message_ix);
|
||||
this.update(cx, |this, _cx| {
|
||||
this.flush_pending_message();
|
||||
this.messages.truncate(message_ix);
|
||||
})?;
|
||||
break 'outer;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Ok(event) => {
|
||||
event => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||
event,
|
||||
&event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("Error in completion stream: {:?}", error);
|
||||
event_stream.send_error(error);
|
||||
break;
|
||||
this.update(cx, |this, cx| {
|
||||
tool_uses.extend(this.handle_streamed_completion_event(
|
||||
event,
|
||||
&event_stream,
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no tool uses, the turn is done.
|
||||
if tool_uses.is_empty() {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
break;
|
||||
}
|
||||
log::info!("Found {} tool uses to execute", tool_uses.len());
|
||||
|
||||
// As tool results trickle in, insert them in the last user
|
||||
// message so that they can be sent on the next tick of the
|
||||
// agentic loop.
|
||||
let used_tools = tool_uses.is_empty();
|
||||
while let Some(tool_result) = tool_uses.next().await {
|
||||
log::info!("Tool finished {:?}", tool_result);
|
||||
|
||||
@@ -553,29 +687,30 @@ impl Thread {
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread
|
||||
.pending_agent_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})
|
||||
.ok();
|
||||
this.update(cx, |this, _cx| {
|
||||
this.pending_message()
|
||||
.tool_results
|
||||
.insert(tool_result.tool_use_id.clone(), tool_result);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?;
|
||||
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
if tool_use_limit_reached {
|
||||
log::info!("Tool use limit reached, completing turn");
|
||||
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
|
||||
return Err(language_model::ToolUseLimitReachedError.into());
|
||||
} else if used_tools {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
return Ok(());
|
||||
} else {
|
||||
this.update(cx, |this, _| this.flush_pending_message())?;
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _cx| thread.flush_pending_agent_message())
|
||||
.ok();
|
||||
|
||||
this.update(cx, |this, _| this.flush_pending_message()).ok();
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
event_stream.send_error(error);
|
||||
@@ -617,7 +752,8 @@ impl Thread {
|
||||
|
||||
match event {
|
||||
StartMessage { .. } => {
|
||||
self.messages.push(Message::Agent(AgentMessage::default()));
|
||||
self.flush_pending_message();
|
||||
self.pending_message = Some(AgentMessage::default());
|
||||
}
|
||||
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
|
||||
Thinking { text, signature } => {
|
||||
@@ -650,12 +786,12 @@ impl Thread {
|
||||
fn handle_text_event(
|
||||
&mut self,
|
||||
new_text: String,
|
||||
events_stream: &AgentResponseEventStream,
|
||||
event_stream: &AgentResponseEventStream,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
events_stream.send_text(&new_text);
|
||||
event_stream.send_text(&new_text);
|
||||
|
||||
let last_message = self.pending_agent_message();
|
||||
let last_message = self.pending_message();
|
||||
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
@@ -676,7 +812,7 @@ impl Thread {
|
||||
) {
|
||||
event_stream.send_thinking(&new_text);
|
||||
|
||||
let last_message = self.pending_agent_message();
|
||||
let last_message = self.pending_message();
|
||||
if let Some(AgentMessageContent::Thinking { text, signature }) =
|
||||
last_message.content.last_mut()
|
||||
{
|
||||
@@ -693,7 +829,7 @@ impl Thread {
|
||||
}
|
||||
|
||||
fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
|
||||
let last_message = self.pending_agent_message();
|
||||
let last_message = self.pending_message();
|
||||
last_message
|
||||
.content
|
||||
.push(AgentMessageContent::RedactedThinking(data));
|
||||
@@ -717,7 +853,7 @@ impl Thread {
|
||||
}
|
||||
|
||||
// Ensure the last message ends in the current tool use
|
||||
let last_message = self.pending_agent_message();
|
||||
let last_message = self.pending_message();
|
||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||
if let AgentMessageContent::ToolUse(last_tool_use) = content {
|
||||
if last_tool_use.id == tool_use.id {
|
||||
@@ -765,13 +901,14 @@ impl Thread {
|
||||
|
||||
let fs = self.project.read(cx).fs().clone();
|
||||
let tool_event_stream =
|
||||
ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
|
||||
ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs));
|
||||
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::InProgress),
|
||||
..Default::default()
|
||||
});
|
||||
let supports_images = self.selected_model.supports_images();
|
||||
let supports_images = self.model.supports_images();
|
||||
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
|
||||
log::info!("Running tool {}", tool_use.name);
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
let tool_result = tool_result.await.and_then(|output| {
|
||||
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
|
||||
@@ -820,12 +957,12 @@ impl Thread {
|
||||
}
|
||||
}
|
||||
|
||||
fn pending_agent_message(&mut self) -> &mut AgentMessage {
|
||||
self.pending_agent_message.get_or_insert_default()
|
||||
fn pending_message(&mut self) -> &mut AgentMessage {
|
||||
self.pending_message.get_or_insert_default()
|
||||
}
|
||||
|
||||
fn flush_pending_agent_message(&mut self) {
|
||||
let Some(mut message) = self.pending_agent_message.take() else {
|
||||
fn flush_pending_message(&mut self) {
|
||||
let Some(mut message) = self.pending_message.take() else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -874,7 +1011,7 @@ impl Thread {
|
||||
name: tool_name,
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(self.selected_model.tool_input_format())
|
||||
.input_schema(self.model.tool_input_format())
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
@@ -886,15 +1023,15 @@ impl Thread {
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
thread_id: Some(self.id.to_string()),
|
||||
prompt_id: Some(self.prompt_id.to_string()),
|
||||
intent: Some(completion_intent),
|
||||
mode: Some(self.completion_mode),
|
||||
mode: Some(self.completion_mode.into()),
|
||||
messages,
|
||||
tools,
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
temperature: AgentSettings::temperature_for_model(self.model(), cx),
|
||||
thinking_allowed: true,
|
||||
};
|
||||
|
||||
@@ -907,7 +1044,7 @@ impl Thread {
|
||||
.profiles
|
||||
.get(&self.profile_id)
|
||||
.context("profile not found")?;
|
||||
let provider_id = self.selected_model.provider_id();
|
||||
let provider_id = self.model.provider_id();
|
||||
|
||||
Ok(self
|
||||
.tools
|
||||
@@ -943,13 +1080,26 @@ impl Thread {
|
||||
match message {
|
||||
Message::User(message) => messages.push(message.to_request()),
|
||||
Message::Agent(message) => messages.extend(message.to_request()),
|
||||
Message::Resume => messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec!["Continue where you left off".into()],
|
||||
cache: false,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(message) = self.pending_agent_message.as_ref() {
|
||||
if let Some(message) = self.pending_message.as_ref() {
|
||||
messages.extend(message.to_request());
|
||||
}
|
||||
|
||||
if let Some(last_user_message) = messages
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find(|message| message.role == Role::User)
|
||||
{
|
||||
last_user_message.cache = true;
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
@@ -962,13 +1112,17 @@ impl Thread {
|
||||
markdown.push_str(&message.to_markdown());
|
||||
}
|
||||
|
||||
if let Some(message) = self.pending_agent_message.as_ref() {
|
||||
if let Some(message) = self.pending_message.as_ref() {
|
||||
markdown.push('\n');
|
||||
markdown.push_str(&message.to_markdown());
|
||||
}
|
||||
|
||||
markdown
|
||||
}
|
||||
|
||||
fn advance_prompt_id(&mut self) {
|
||||
self.prompt_id = PromptId::new();
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
@@ -1095,9 +1249,7 @@ where
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AgentResponseEventStream(
|
||||
mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
|
||||
|
||||
impl AgentResponseEventStream {
|
||||
fn send_text(&self, text: &str) {
|
||||
@@ -1184,16 +1336,14 @@ impl AgentResponseEventStream {
|
||||
}
|
||||
}
|
||||
|
||||
fn send_error(&self, error: LanguageModelCompletionError) {
|
||||
self.0.unbounded_send(Err(error)).ok();
|
||||
fn send_error(&self, error: impl Into<anyhow::Error>) {
|
||||
self.0.unbounded_send(Err(error.into())).ok();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolCallEventStream {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
kind: acp::ToolKind,
|
||||
input: serde_json::Value,
|
||||
stream: AgentResponseEventStream,
|
||||
fs: Option<Arc<dyn Fs>>,
|
||||
}
|
||||
@@ -1201,35 +1351,21 @@ pub struct ToolCallEventStream {
|
||||
impl ToolCallEventStream {
|
||||
#[cfg(test)]
|
||||
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
|
||||
|
||||
let stream = ToolCallEventStream::new(
|
||||
&LanguageModelToolUse {
|
||||
id: "test_id".into(),
|
||||
name: "test_tool".into(),
|
||||
raw_input: String::new(),
|
||||
input: serde_json::Value::Null,
|
||||
is_input_complete: true,
|
||||
},
|
||||
acp::ToolKind::Other,
|
||||
AgentResponseEventStream(events_tx),
|
||||
None,
|
||||
);
|
||||
let stream =
|
||||
ToolCallEventStream::new("test_id".into(), AgentResponseEventStream(events_tx), None);
|
||||
|
||||
(stream, ToolCallEventStreamReceiver(events_rx))
|
||||
}
|
||||
|
||||
fn new(
|
||||
tool_use: &LanguageModelToolUse,
|
||||
kind: acp::ToolKind,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
stream: AgentResponseEventStream,
|
||||
fs: Option<Arc<dyn Fs>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
kind,
|
||||
input: tool_use.input.clone(),
|
||||
tool_use_id,
|
||||
stream,
|
||||
fs,
|
||||
}
|
||||
@@ -1276,12 +1412,13 @@ impl ToolCallEventStream {
|
||||
.0
|
||||
.unbounded_send(Ok(AgentResponseEvent::ToolCallAuthorization(
|
||||
ToolCallAuthorization {
|
||||
tool_call: AgentResponseEventStream::initial_tool_call(
|
||||
&self.tool_use_id,
|
||||
title.into(),
|
||||
self.kind.clone(),
|
||||
self.input.clone(),
|
||||
),
|
||||
tool_call: acp::ToolCallUpdate {
|
||||
id: acp::ToolCallId(self.tool_use_id.to_string().into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
title: Some(title.into()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
options: vec![
|
||||
acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("always_allow".into()),
|
||||
@@ -1323,9 +1460,7 @@ impl ToolCallEventStream {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub struct ToolCallEventStreamReceiver(
|
||||
mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
);
|
||||
pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
|
||||
|
||||
#[cfg(test)]
|
||||
impl ToolCallEventStreamReceiver {
|
||||
@@ -1353,7 +1488,7 @@ impl ToolCallEventStreamReceiver {
|
||||
|
||||
#[cfg(test)]
|
||||
impl std::ops::Deref for ToolCallEventStreamReceiver {
|
||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
|
||||
type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
@@ -1367,18 +1502,6 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver {
|
||||
}
|
||||
}
|
||||
|
||||
fn codeblock_tag(full_path: &Path) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) {
|
||||
let _ = write!(result, "{} ", extension);
|
||||
}
|
||||
|
||||
let _ = write!(result, "{}", full_path.display());
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
impl From<&str> for UserMessageContent {
|
||||
fn from(text: &str) -> Self {
|
||||
Self::Text(text.into())
|
||||
|
||||
@@ -241,7 +241,7 @@ impl AgentTool for EditFileTool {
|
||||
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
||||
});
|
||||
let thread = self.thread.read(cx);
|
||||
let model = thread.selected_model.clone();
|
||||
let model = thread.model().clone();
|
||||
let action_log = thread.action_log().clone();
|
||||
|
||||
let authorize = self.authorize(&input, &event_stream, cx);
|
||||
@@ -1001,7 +1001,10 @@ mod tests {
|
||||
});
|
||||
|
||||
let event = stream_rx.expect_authorization().await;
|
||||
assert_eq!(event.tool_call.title, "test 1 (local settings)");
|
||||
assert_eq!(
|
||||
event.tool_call.fields.title,
|
||||
Some("test 1 (local settings)".into())
|
||||
);
|
||||
|
||||
// Test 2: Path outside project should require confirmation
|
||||
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||
@@ -1018,7 +1021,7 @@ mod tests {
|
||||
});
|
||||
|
||||
let event = stream_rx.expect_authorization().await;
|
||||
assert_eq!(event.tool_call.title, "test 2");
|
||||
assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
|
||||
|
||||
// Test 3: Relative path without .zed should not require confirmation
|
||||
let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
|
||||
@@ -1051,7 +1054,10 @@ mod tests {
|
||||
)
|
||||
});
|
||||
let event = stream_rx.expect_authorization().await;
|
||||
assert_eq!(event.tool_call.title, "test 4 (local settings)");
|
||||
assert_eq!(
|
||||
event.tool_call.fields.title,
|
||||
Some("test 4 (local settings)".into())
|
||||
);
|
||||
|
||||
// Test 5: When always_allow_tool_actions is enabled, no confirmation needed
|
||||
cx.update(|cx| {
|
||||
|
||||
@@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use std::{cell::RefCell, path::Path, rc::Rc};
|
||||
use std::{any::Any, cell::RefCell, path::Path, rc::Rc};
|
||||
use ui::App;
|
||||
use util::ResultExt as _;
|
||||
|
||||
@@ -135,9 +135,9 @@ impl acp_old::Client for OldAcpClientDelegate {
|
||||
let response = cx
|
||||
.update(|cx| {
|
||||
self.thread.borrow().update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(tool_call, acp_options, cx)
|
||||
thread.request_tool_call_authorization(tool_call.into(), acp_options, cx)
|
||||
})
|
||||
})?
|
||||
})??
|
||||
.context("Failed to update thread")?
|
||||
.await;
|
||||
|
||||
@@ -168,7 +168,7 @@ impl acp_old::Client for OldAcpClientDelegate {
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?
|
||||
})??
|
||||
.context("Failed to update thread")?;
|
||||
|
||||
Ok(acp_old::PushToolCallResponse {
|
||||
@@ -423,7 +423,7 @@ impl AgentConnection for AcpConnection {
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let task = self.connection.request_any(
|
||||
acp_old::InitializeParams {
|
||||
@@ -507,4 +507,8 @@ impl AgentConnection for AcpConnection {
|
||||
})
|
||||
.detach_and_log_err(cx)
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use agent_client_protocol::{self as acp, Agent as _};
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use futures::AsyncBufReadExt as _;
|
||||
use futures::channel::oneshot;
|
||||
use futures::io::BufReader;
|
||||
use project::Project;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::{any::Any, cell::RefCell};
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
@@ -40,12 +42,13 @@ impl AcpConnection {
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdout = child.stdout.take().expect("Failed to take stdout");
|
||||
let stdin = child.stdin.take().expect("Failed to take stdin");
|
||||
let stdout = child.stdout.take().context("Failed to take stdout")?;
|
||||
let stdin = child.stdin.take().context("Failed to take stdin")?;
|
||||
let stderr = child.stderr.take().context("Failed to take stderr")?;
|
||||
log::trace!("Spawned (pid: {})", child.id());
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
@@ -63,6 +66,18 @@ impl AcpConnection {
|
||||
|
||||
let io_task = cx.background_spawn(io_task);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let mut stderr = BufReader::new(stderr);
|
||||
let mut line = String::new();
|
||||
while let Ok(n) = stderr.read_line(&mut line).await
|
||||
&& n > 0
|
||||
{
|
||||
log::warn!("agent stderr: {}", &line);
|
||||
line.clear();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
async move |cx| {
|
||||
@@ -111,7 +126,7 @@ impl AgentConnection for AcpConnection {
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let conn = self.connection.clone();
|
||||
let sessions = self.sessions.clone();
|
||||
@@ -191,6 +206,10 @@ impl AgentConnection for AcpConnection {
|
||||
.spawn(async move { conn.cancel(params).await })
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientDelegate {
|
||||
@@ -214,7 +233,7 @@ impl acp::Client for ClientDelegate {
|
||||
thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
|
||||
})?;
|
||||
|
||||
let result = rx.await;
|
||||
let result = rx?.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
|
||||
|
||||
@@ -6,6 +6,7 @@ use context_server::listener::McpServerTool;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::process::Child;
|
||||
use std::any::Any;
|
||||
use std::cell::RefCell;
|
||||
use std::fmt::Display;
|
||||
use std::path::Path;
|
||||
@@ -13,7 +14,7 @@ use std::rc::Rc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use futures::channel::oneshot;
|
||||
use futures::{AsyncBufReadExt, AsyncWriteExt};
|
||||
use futures::{
|
||||
@@ -74,7 +75,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let cwd = cwd.to_owned();
|
||||
cx.spawn(async move |cx| {
|
||||
@@ -129,12 +130,25 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
&cwd,
|
||||
)?;
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
let stdout = child.stdout.take().context("Failed to take stdout")?;
|
||||
let stdin = child.stdin.take().context("Failed to take stdin")?;
|
||||
let stderr = child.stderr.take().context("Failed to take stderr")?;
|
||||
|
||||
let pid = child.id();
|
||||
log::trace!("Spawned (pid: {})", pid);
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let mut stderr = BufReader::new(stderr);
|
||||
let mut line = String::new();
|
||||
while let Ok(n) = stderr.read_line(&mut line).await
|
||||
&& n > 0
|
||||
{
|
||||
log::warn!("agent stderr: {}", &line);
|
||||
line.clear();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let mut outgoing_rx = Some(outgoing_rx);
|
||||
|
||||
@@ -289,6 +303,10 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@@ -340,7 +358,7 @@ fn spawn_claude(
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
@@ -542,8 +560,9 @@ impl ClaudeAgentSession {
|
||||
thread.upsert_tool_call(
|
||||
claude_tool.as_acp(acp::ToolCallId(id.into())),
|
||||
cx,
|
||||
);
|
||||
)?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ impl McpServerTool for PermissionTool {
|
||||
let chosen_option = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_authorization(
|
||||
claude_tool.as_acp(tool_call_id),
|
||||
claude_tool.as_acp(tool_call_id).into(),
|
||||
vec![
|
||||
acp::PermissionOption {
|
||||
id: allow_option_id.clone(),
|
||||
@@ -169,7 +169,7 @@ impl McpServerTool for PermissionTool {
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
})??
|
||||
.await?;
|
||||
|
||||
let response = if chosen_option == allow_option_id {
|
||||
|
||||
@@ -422,8 +422,8 @@ pub async fn new_test_thread(
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let thread = connection
|
||||
.new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
|
||||
let thread = cx
|
||||
.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -309,7 +309,7 @@ pub struct AgentSettingsContent {
|
||||
///
|
||||
/// Default: true
|
||||
expand_terminal_card: Option<bool>,
|
||||
/// Whether to always use cmd-enter (or ctrl-enter on Linux) to send messages in the agent panel.
|
||||
/// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel.
|
||||
///
|
||||
/// Default: false
|
||||
use_modifier_to_send: Option<bool>,
|
||||
|
||||
@@ -50,7 +50,6 @@ fuzzy.workspace = true
|
||||
gpui.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indexed_docs.workspace = true
|
||||
indoc.workspace = true
|
||||
inventory.workspace = true
|
||||
itertools.workspace = true
|
||||
@@ -93,6 +92,7 @@ time.workspace = true
|
||||
time_format.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
url.workspace = true
|
||||
urlencoding.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
@@ -102,6 +102,9 @@ workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
acp_thread = { workspace = true, features = ["test-support"] }
|
||||
agent = { workspace = true, features = ["test-support"] }
|
||||
assistant_context = { workspace = true, features = ["test-support"] }
|
||||
assistant_tools.workspace = true
|
||||
buffer_diff = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod completion_provider;
|
||||
mod entry_view_state;
|
||||
mod message_editor;
|
||||
mod model_selector;
|
||||
mod model_selector_popover;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
444
crates/agent_ui/src/acp/entry_view_state.rs
Normal file
444
crates/agent_ui/src/acp/entry_view_state.rs
Normal file
@@ -0,0 +1,444 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use acp_thread::{AcpThread, AgentThreadEntry};
|
||||
use agent::{TextThreadStore, ThreadStore};
|
||||
use collections::HashMap;
|
||||
use editor::{Editor, EditorMode, MinimapVisibility};
|
||||
use gpui::{
|
||||
AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
|
||||
WeakEntity, Window,
|
||||
};
|
||||
use language::language_settings::SoftWrap;
|
||||
use project::Project;
|
||||
use settings::Settings as _;
|
||||
use terminal_view::TerminalView;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Context, TextSize};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
|
||||
|
||||
pub struct EntryViewState {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
text_thread_store: Entity<TextThreadStore>,
|
||||
entries: Vec<Entry>,
|
||||
}
|
||||
|
||||
impl EntryViewState {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
text_thread_store: Entity<TextThreadStore>,
|
||||
) -> Self {
|
||||
Self {
|
||||
workspace,
|
||||
project,
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
entries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn entry(&self, index: usize) -> Option<&Entry> {
|
||||
self.entries.get(index)
|
||||
}
|
||||
|
||||
pub fn sync_entry(
|
||||
&mut self,
|
||||
index: usize,
|
||||
thread: &Entity<AcpThread>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some(thread_entry) = thread.read(cx).entries().get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
match thread_entry {
|
||||
AgentThreadEntry::UserMessage(message) => {
|
||||
let has_id = message.id.is_some();
|
||||
let chunks = message.chunks.clone();
|
||||
let message_editor = cx.new(|cx| {
|
||||
let mut editor = MessageEditor::new(
|
||||
self.workspace.clone(),
|
||||
self.project.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.text_thread_store.clone(),
|
||||
editor::EditorMode::AutoHeight {
|
||||
min_lines: 1,
|
||||
max_lines: None,
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
if !has_id {
|
||||
editor.set_read_only(true, cx);
|
||||
}
|
||||
editor.set_message(chunks, window, cx);
|
||||
editor
|
||||
});
|
||||
cx.subscribe(&message_editor, move |_, editor, event, cx| {
|
||||
cx.emit(EntryViewEvent {
|
||||
entry_index: index,
|
||||
view_event: ViewEvent::MessageEditorEvent(editor, *event),
|
||||
})
|
||||
})
|
||||
.detach();
|
||||
self.set_entry(index, Entry::UserMessage(message_editor));
|
||||
}
|
||||
AgentThreadEntry::ToolCall(tool_call) => {
|
||||
let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
|
||||
let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
|
||||
|
||||
let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
|
||||
views
|
||||
} else {
|
||||
self.set_entry(index, Entry::empty());
|
||||
let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
|
||||
unreachable!()
|
||||
};
|
||||
views
|
||||
};
|
||||
|
||||
for terminal in terminals {
|
||||
views.entry(terminal.entity_id()).or_insert_with(|| {
|
||||
create_terminal(
|
||||
self.workspace.clone(),
|
||||
self.project.clone(),
|
||||
terminal.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any()
|
||||
});
|
||||
}
|
||||
|
||||
for diff in diffs {
|
||||
views
|
||||
.entry(diff.entity_id())
|
||||
.or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
|
||||
}
|
||||
}
|
||||
AgentThreadEntry::AssistantMessage(_) => {
|
||||
if index == self.entries.len() {
|
||||
self.entries.push(Entry::empty())
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn set_entry(&mut self, index: usize, entry: Entry) {
|
||||
if index == self.entries.len() {
|
||||
self.entries.push(entry);
|
||||
} else {
|
||||
self.entries[index] = entry;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, range: Range<usize>) {
|
||||
self.entries.drain(range);
|
||||
}
|
||||
|
||||
pub fn settings_changed(&mut self, cx: &mut App) {
|
||||
for entry in self.entries.iter() {
|
||||
match entry {
|
||||
Entry::UserMessage { .. } => {}
|
||||
Entry::Content(response_views) => {
|
||||
for view in response_views.values() {
|
||||
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
|
||||
diff_editor.update(cx, |diff_editor, cx| {
|
||||
diff_editor.set_text_style_refinement(
|
||||
diff_editor_text_style_refinement(cx),
|
||||
);
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<EntryViewEvent> for EntryViewState {}
|
||||
|
||||
pub struct EntryViewEvent {
|
||||
pub entry_index: usize,
|
||||
pub view_event: ViewEvent,
|
||||
}
|
||||
|
||||
pub enum ViewEvent {
|
||||
MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
|
||||
}
|
||||
|
||||
pub enum Entry {
|
||||
UserMessage(Entity<MessageEditor>),
|
||||
Content(HashMap<EntityId, AnyEntity>),
|
||||
}
|
||||
|
||||
impl Entry {
|
||||
pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
|
||||
match self {
|
||||
Self::UserMessage(editor) => Some(editor),
|
||||
Entry::Content(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
|
||||
self.content_map()?
|
||||
.get(&diff.entity_id())
|
||||
.cloned()
|
||||
.map(|entity| entity.downcast::<Editor>().unwrap())
|
||||
}
|
||||
|
||||
pub fn terminal(
|
||||
&self,
|
||||
terminal: &Entity<acp_thread::Terminal>,
|
||||
) -> Option<Entity<TerminalView>> {
|
||||
self.content_map()?
|
||||
.get(&terminal.entity_id())
|
||||
.cloned()
|
||||
.map(|entity| entity.downcast::<TerminalView>().unwrap())
|
||||
}
|
||||
|
||||
fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
|
||||
match self {
|
||||
Self::Content(map) => Some(map),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn empty() -> Self {
|
||||
Self::Content(HashMap::default())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn has_content(&self) -> bool {
|
||||
match self {
|
||||
Self::Content(map) => !map.is_empty(),
|
||||
Self::UserMessage(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_terminal(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
terminal: Entity<acp_thread::Terminal>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<TerminalView> {
|
||||
cx.new(|cx| {
|
||||
let mut view = TerminalView::new(
|
||||
terminal.read(cx).inner().clone(),
|
||||
workspace.clone(),
|
||||
None,
|
||||
project.downgrade(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
view.set_embedded_mode(Some(1000), cx);
|
||||
view
|
||||
})
|
||||
}
|
||||
|
||||
fn create_editor_diff(
|
||||
diff: Entity<acp_thread::Diff>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Entity<Editor> {
|
||||
cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
diff.read(cx).multibuffer().clone(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.set_show_vertical_scrollbar(false, cx);
|
||||
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::None, cx);
|
||||
editor.scroll_manager.set_forbid_vertical_scroll(true);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_read_only(true);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
|
||||
editor
|
||||
})
|
||||
}
|
||||
|
||||
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
|
||||
TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{path::Path, rc::Rc};
|
||||
|
||||
use acp_thread::{AgentConnection, StubAgentConnection};
|
||||
use agent::{TextThreadStore, ThreadStore};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
|
||||
use editor::{EditorSettings, RowInfo};
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext as _, SemanticVersion, TestAppContext};
|
||||
|
||||
use crate::acp::entry_view_state::EntryViewState;
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use pretty_assertions::assert_matches;
|
||||
use project::Project;
|
||||
use serde_json::json;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use theme::ThemeSettings;
|
||||
use util::path;
|
||||
use workspace::Workspace;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_diff_sync(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
json!({
|
||||
"hello.txt": "hi world"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
|
||||
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let tool_call = acp::ToolCall {
|
||||
id: acp::ToolCallId("tool".into()),
|
||||
title: "Tool call".into(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::InProgress,
|
||||
content: vec![acp::ToolCallContent::Diff {
|
||||
diff: acp::Diff {
|
||||
path: "/project/hello.txt".into(),
|
||||
old_text: Some("hi world".into()),
|
||||
new_text: "hello world".into(),
|
||||
},
|
||||
}],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
};
|
||||
let connection = Rc::new(StubAgentConnection::new());
|
||||
let thread = cx
|
||||
.update(|_, cx| {
|
||||
connection
|
||||
.clone()
|
||||
.new_thread(project.clone(), Path::new(path!("/project")), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
|
||||
|
||||
cx.update(|_, cx| {
|
||||
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
|
||||
});
|
||||
|
||||
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
|
||||
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
|
||||
|
||||
let view_state = cx.new(|_cx| {
|
||||
EntryViewState::new(
|
||||
workspace.downgrade(),
|
||||
project.clone(),
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
)
|
||||
});
|
||||
|
||||
view_state.update_in(cx, |view_state, window, cx| {
|
||||
view_state.sync_entry(0, &thread, window, cx)
|
||||
});
|
||||
|
||||
let diff = thread.read_with(cx, |thread, _cx| {
|
||||
thread
|
||||
.entries()
|
||||
.get(0)
|
||||
.unwrap()
|
||||
.diffs()
|
||||
.next()
|
||||
.unwrap()
|
||||
.clone()
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let diff_editor = view_state.read_with(cx, |view_state, _cx| {
|
||||
view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
|
||||
});
|
||||
assert_eq!(
|
||||
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
|
||||
"hi world\nhello world"
|
||||
);
|
||||
let row_infos = diff_editor.read_with(cx, |editor, cx| {
|
||||
let multibuffer = editor.buffer().read(cx);
|
||||
multibuffer
|
||||
.snapshot(cx)
|
||||
.row_infos(MultiBufferRow(0))
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
assert_matches!(
|
||||
row_infos.as_slice(),
|
||||
[
|
||||
RowInfo {
|
||||
multibuffer_row: Some(MultiBufferRow(0)),
|
||||
diff_status: Some(DiffHunkStatus {
|
||||
kind: DiffHunkStatusKind::Deleted,
|
||||
..
|
||||
}),
|
||||
..
|
||||
},
|
||||
RowInfo {
|
||||
multibuffer_row: Some(MultiBufferRow(1)),
|
||||
diff_status: Some(DiffHunkStatus {
|
||||
kind: DiffHunkStatusKind::Added,
|
||||
..
|
||||
}),
|
||||
..
|
||||
}
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AgentSettings::register(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
EditorSettings::register(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -465,7 +465,7 @@ impl AgentConfiguration {
|
||||
"modifier-send",
|
||||
"Use modifier to submit a message",
|
||||
Some(
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(),
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux or Windows) required to send messages.".into(),
|
||||
),
|
||||
use_modifier_to_send,
|
||||
move |state, _window, cx| {
|
||||
@@ -1035,7 +1035,6 @@ fn extension_only_provides_context_server(manifest: &ExtensionManifest) -> bool
|
||||
&& manifest.grammars.is_empty()
|
||||
&& manifest.language_servers.is_empty()
|
||||
&& manifest.slash_commands.is_empty()
|
||||
&& manifest.indexed_docs_providers.is_empty()
|
||||
&& manifest.snippets.is_none()
|
||||
&& manifest.debug_locators.is_empty()
|
||||
}
|
||||
|
||||
@@ -65,8 +65,8 @@ use theme::ThemeSettings;
|
||||
use time::UtcOffset;
|
||||
use ui::utils::WithRemSize;
|
||||
use ui::{
|
||||
Banner, ButtonLike, Callout, ContextMenu, ContextMenuEntry, ElevationIndex, KeyBinding,
|
||||
PopoverMenu, PopoverMenuHandle, ProgressBar, Tab, Tooltip, prelude::*,
|
||||
Banner, Callout, ContextMenu, ContextMenuEntry, ElevationIndex, KeyBinding, PopoverMenu,
|
||||
PopoverMenuHandle, ProgressBar, Tab, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{
|
||||
@@ -818,10 +818,10 @@ impl AgentPanel {
|
||||
ActiveView::Thread { thread, .. } => {
|
||||
thread.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
|
||||
}
|
||||
ActiveView::ExternalAgentThread { thread_view, .. } => {
|
||||
thread_view.update(cx, |thread_element, cx| thread_element.cancel(cx));
|
||||
}
|
||||
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => {}
|
||||
ActiveView::ExternalAgentThread { .. }
|
||||
| ActiveView::TextThread { .. }
|
||||
| ActiveView::History
|
||||
| ActiveView::Configuration => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -967,6 +967,9 @@ impl AgentPanel {
|
||||
agent: crate::ExternalAgent,
|
||||
}
|
||||
|
||||
let thread_store = self.thread_store.clone();
|
||||
let text_thread_store = self.context_store.clone();
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let server: Rc<dyn AgentServer> = match agent_choice {
|
||||
Some(agent) => {
|
||||
@@ -1001,7 +1004,15 @@ impl AgentPanel {
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
let thread_view = cx.new(|cx| {
|
||||
crate::acp::AcpThreadView::new(server, workspace.clone(), project, window, cx)
|
||||
crate::acp::AcpThreadView::new(
|
||||
server,
|
||||
workspace.clone(),
|
||||
project,
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
this.set_active_view(ActiveView::ExternalAgentThread { thread_view }, window, cx);
|
||||
@@ -1246,13 +1257,11 @@ impl AgentPanel {
|
||||
ThemeSettings::get_global(cx).agent_font_size(cx) + delta;
|
||||
let _ = settings
|
||||
.agent_font_size
|
||||
.insert(theme::clamp_font_size(agent_font_size).0);
|
||||
.insert(Some(theme::clamp_font_size(agent_font_size).into()));
|
||||
},
|
||||
);
|
||||
} else {
|
||||
theme::adjust_agent_font_size(cx, |size| {
|
||||
*size += delta;
|
||||
});
|
||||
theme::adjust_agent_font_size(cx, |size| size + delta);
|
||||
}
|
||||
}
|
||||
WhichFontSize::BufferFont => {
|
||||
@@ -1974,9 +1983,7 @@ impl AgentPanel {
|
||||
|
||||
PopoverMenu::new("agent-nav-menu")
|
||||
.trigger_with_tooltip(
|
||||
IconButton::new("agent-nav-menu", icon)
|
||||
.icon_size(IconSize::Small)
|
||||
.style(ui::ButtonStyle::Subtle),
|
||||
IconButton::new("agent-nav-menu", icon).icon_size(IconSize::Small),
|
||||
{
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
@@ -2113,9 +2120,10 @@ impl AgentPanel {
|
||||
.pl_1()
|
||||
.gap_1()
|
||||
.child(match &self.active_view {
|
||||
ActiveView::History | ActiveView::Configuration => {
|
||||
self.render_toolbar_back_button(cx).into_any_element()
|
||||
}
|
||||
ActiveView::History | ActiveView::Configuration => div()
|
||||
.pl(DynamicSpacing::Base04.rems(cx))
|
||||
.child(self.render_toolbar_back_button(cx))
|
||||
.into_any_element(),
|
||||
_ => self
|
||||
.render_recent_entries_menu(IconName::MenuAlt, cx)
|
||||
.into_any_element(),
|
||||
@@ -2153,33 +2161,7 @@ impl AgentPanel {
|
||||
|
||||
let new_thread_menu = PopoverMenu::new("new_thread_menu")
|
||||
.trigger_with_tooltip(
|
||||
ButtonLike::new("new_thread_menu_btn").child(
|
||||
h_flex()
|
||||
.group("agent-selector")
|
||||
.gap_1p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.relative()
|
||||
.size_4()
|
||||
.justify_center()
|
||||
.child(
|
||||
h_flex()
|
||||
.group_hover("agent-selector", |s| s.invisible())
|
||||
.child(
|
||||
Icon::new(self.selected_agent.icon())
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.absolute()
|
||||
.invisible()
|
||||
.group_hover("agent-selector", |s| s.visible())
|
||||
.child(Icon::new(IconName::Plus)),
|
||||
),
|
||||
)
|
||||
.child(Label::new(self.selected_agent.label())),
|
||||
),
|
||||
IconButton::new("new_thread_menu_btn", IconName::Plus).icon_size(IconSize::Small),
|
||||
{
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
@@ -2397,15 +2379,24 @@ impl AgentPanel {
|
||||
.size_full()
|
||||
.gap(DynamicSpacing::Base08.rems(cx))
|
||||
.child(match &self.active_view {
|
||||
ActiveView::History | ActiveView::Configuration => {
|
||||
self.render_toolbar_back_button(cx).into_any_element()
|
||||
}
|
||||
ActiveView::History | ActiveView::Configuration => div()
|
||||
.pl(DynamicSpacing::Base04.rems(cx))
|
||||
.child(self.render_toolbar_back_button(cx))
|
||||
.into_any_element(),
|
||||
_ => h_flex()
|
||||
.h_full()
|
||||
.px(DynamicSpacing::Base04.rems(cx))
|
||||
.border_r_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(new_thread_menu)
|
||||
.child(
|
||||
h_flex()
|
||||
.px_0p5()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(self.selected_agent.icon()).color(Color::Muted),
|
||||
)
|
||||
.child(Label::new(self.selected_agent.label())),
|
||||
)
|
||||
.into_any_element(),
|
||||
})
|
||||
.child(self.render_title_view(window, cx)),
|
||||
@@ -2423,6 +2414,7 @@ impl AgentPanel {
|
||||
.pr(DynamicSpacing::Base06.rems(cx))
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(new_thread_menu)
|
||||
.child(self.render_recent_entries_menu(IconName::HistoryRerun, cx))
|
||||
.child(self.render_panel_options_menu(window, cx)),
|
||||
),
|
||||
|
||||
@@ -5,7 +5,6 @@ mod agent_diff;
|
||||
mod agent_model_selector;
|
||||
mod agent_panel;
|
||||
mod buffer_codegen;
|
||||
mod burn_mode_tooltip;
|
||||
mod context_picker;
|
||||
mod context_server_configuration;
|
||||
mod context_strip;
|
||||
@@ -243,7 +242,6 @@ pub fn init(
|
||||
client.telemetry().clone(),
|
||||
cx,
|
||||
);
|
||||
indexed_docs::init(cx);
|
||||
cx.observe_new(move |workspace, window, cx| {
|
||||
ConfigureContextServerModal::register(workspace, language_registry.clone(), window, cx)
|
||||
})
|
||||
@@ -410,12 +408,6 @@ fn update_slash_commands_from_settings(cx: &mut App) {
|
||||
let slash_command_registry = SlashCommandRegistry::global(cx);
|
||||
let settings = SlashCommandSettings::get_global(cx);
|
||||
|
||||
if settings.docs.enabled {
|
||||
slash_command_registry.register_command(assistant_slash_commands::DocsSlashCommand, true);
|
||||
} else {
|
||||
slash_command_registry.unregister_command(assistant_slash_commands::DocsSlashCommand);
|
||||
}
|
||||
|
||||
if settings.cargo_workspace.enabled {
|
||||
slash_command_registry
|
||||
.register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true);
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
use gpui::{Context, FontWeight, IntoElement, Render, Window};
|
||||
use ui::{prelude::*, tooltip_container};
|
||||
|
||||
pub struct BurnModeTooltip {
|
||||
selected: bool,
|
||||
}
|
||||
|
||||
impl BurnModeTooltip {
|
||||
pub fn new() -> Self {
|
||||
Self { selected: false }
|
||||
}
|
||||
|
||||
pub fn selected(mut self, selected: bool) -> Self {
|
||||
self.selected = selected;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for BurnModeTooltip {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let (icon, color) = if self.selected {
|
||||
(IconName::ZedBurnModeOn, Color::Error)
|
||||
} else {
|
||||
(IconName::ZedBurnMode, Color::Default)
|
||||
};
|
||||
|
||||
let turned_on = h_flex()
|
||||
.h_4()
|
||||
.px_1()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().text_accent.opacity(0.1))
|
||||
.rounded_sm()
|
||||
.child(
|
||||
Label::new("ON")
|
||||
.size(LabelSize::XSmall)
|
||||
.weight(FontWeight::SEMIBOLD)
|
||||
.color(Color::Accent),
|
||||
);
|
||||
|
||||
let title = h_flex()
|
||||
.gap_1p5()
|
||||
.child(Icon::new(icon).size(IconSize::Small).color(color))
|
||||
.child(Label::new("Burn Mode"))
|
||||
.when(self.selected, |title| title.child(turned_on));
|
||||
|
||||
tooltip_container(window, cx, |this, _, _| {
|
||||
this
|
||||
.child(title)
|
||||
.child(
|
||||
div()
|
||||
.max_w_64()
|
||||
.child(
|
||||
Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
)
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,18 +1,19 @@
|
||||
mod completion_provider;
|
||||
mod fetch_context_picker;
|
||||
pub(crate) mod fetch_context_picker;
|
||||
pub(crate) mod file_context_picker;
|
||||
mod rules_context_picker;
|
||||
mod symbol_context_picker;
|
||||
mod thread_context_picker;
|
||||
pub(crate) mod rules_context_picker;
|
||||
pub(crate) mod symbol_context_picker;
|
||||
pub(crate) mod thread_context_picker;
|
||||
|
||||
use std::ops::Range;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashSet;
|
||||
pub use completion_provider::ContextPickerCompletionProvider;
|
||||
use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use editor::{Anchor, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use fetch_context_picker::FetchContextPicker;
|
||||
use file_context_picker::FileContextPicker;
|
||||
use file_context_picker::render_file_context_entry;
|
||||
@@ -45,7 +46,7 @@ use agent::{
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ContextPickerEntry {
|
||||
pub(crate) enum ContextPickerEntry {
|
||||
Mode(ContextPickerMode),
|
||||
Action(ContextPickerAction),
|
||||
}
|
||||
@@ -74,7 +75,7 @@ impl ContextPickerEntry {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ContextPickerMode {
|
||||
pub(crate) enum ContextPickerMode {
|
||||
File,
|
||||
Symbol,
|
||||
Fetch,
|
||||
@@ -83,7 +84,7 @@ enum ContextPickerMode {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ContextPickerAction {
|
||||
pub(crate) enum ContextPickerAction {
|
||||
AddSelections,
|
||||
}
|
||||
|
||||
@@ -227,7 +228,7 @@ impl ContextPicker {
|
||||
}
|
||||
|
||||
fn build_menu(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Entity<ContextMenu> {
|
||||
let context_picker = cx.entity().clone();
|
||||
let context_picker = cx.entity();
|
||||
|
||||
let menu = ContextMenu::build(window, cx, move |menu, _window, cx| {
|
||||
let recent = self.recent_entries(cx);
|
||||
@@ -531,7 +532,7 @@ impl ContextPicker {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
recent_context_picker_entries(
|
||||
recent_context_picker_entries_with_store(
|
||||
context_store,
|
||||
self.thread_store.clone(),
|
||||
self.text_thread_store.clone(),
|
||||
@@ -585,7 +586,8 @@ impl Render for ContextPicker {
|
||||
})
|
||||
}
|
||||
}
|
||||
enum RecentEntry {
|
||||
|
||||
pub(crate) enum RecentEntry {
|
||||
File {
|
||||
project_path: ProjectPath,
|
||||
path_prefix: Arc<str>,
|
||||
@@ -593,7 +595,7 @@ enum RecentEntry {
|
||||
Thread(ThreadContextEntry),
|
||||
}
|
||||
|
||||
fn available_context_picker_entries(
|
||||
pub(crate) fn available_context_picker_entries(
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
thread_store: &Option<WeakEntity<ThreadStore>>,
|
||||
workspace: &Entity<Workspace>,
|
||||
@@ -630,24 +632,56 @@ fn available_context_picker_entries(
|
||||
entries
|
||||
}
|
||||
|
||||
fn recent_context_picker_entries(
|
||||
fn recent_context_picker_entries_with_store(
|
||||
context_store: Entity<ContextStore>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
text_thread_store: Option<WeakEntity<TextThreadStore>>,
|
||||
workspace: Entity<Workspace>,
|
||||
exclude_path: Option<ProjectPath>,
|
||||
cx: &App,
|
||||
) -> Vec<RecentEntry> {
|
||||
let project = workspace.read(cx).project();
|
||||
|
||||
let mut exclude_paths = context_store.read(cx).file_paths(cx);
|
||||
exclude_paths.extend(exclude_path);
|
||||
|
||||
let exclude_paths = exclude_paths
|
||||
.into_iter()
|
||||
.filter_map(|project_path| project.read(cx).absolute_path(&project_path, cx))
|
||||
.collect();
|
||||
|
||||
let exclude_threads = context_store.read(cx).thread_ids();
|
||||
|
||||
recent_context_picker_entries(
|
||||
thread_store,
|
||||
text_thread_store,
|
||||
workspace,
|
||||
&exclude_paths,
|
||||
exclude_threads,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn recent_context_picker_entries(
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
text_thread_store: Option<WeakEntity<TextThreadStore>>,
|
||||
workspace: Entity<Workspace>,
|
||||
exclude_paths: &HashSet<PathBuf>,
|
||||
exclude_threads: &HashSet<ThreadId>,
|
||||
cx: &App,
|
||||
) -> Vec<RecentEntry> {
|
||||
let mut recent = Vec::with_capacity(6);
|
||||
let mut current_files = context_store.read(cx).file_paths(cx);
|
||||
current_files.extend(exclude_path);
|
||||
let workspace = workspace.read(cx);
|
||||
let project = workspace.project().read(cx);
|
||||
|
||||
recent.extend(
|
||||
workspace
|
||||
.recent_navigation_history_iter(cx)
|
||||
.filter(|(path, _)| !current_files.contains(path))
|
||||
.filter(|(_, abs_path)| {
|
||||
abs_path
|
||||
.as_ref()
|
||||
.map_or(true, |path| !exclude_paths.contains(path.as_path()))
|
||||
})
|
||||
.take(4)
|
||||
.filter_map(|(project_path, _)| {
|
||||
project
|
||||
@@ -659,8 +693,6 @@ fn recent_context_picker_entries(
|
||||
}),
|
||||
);
|
||||
|
||||
let current_threads = context_store.read(cx).thread_ids();
|
||||
|
||||
let active_thread_id = workspace
|
||||
.panel::<AgentPanel>(cx)
|
||||
.and_then(|panel| Some(panel.read(cx).active_thread(cx)?.read(cx).id()));
|
||||
@@ -672,7 +704,7 @@ fn recent_context_picker_entries(
|
||||
let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx)
|
||||
.filter(|(_, thread)| match thread {
|
||||
ThreadContextEntry::Thread { id, .. } => {
|
||||
Some(id) != active_thread_id && !current_threads.contains(id)
|
||||
Some(id) != active_thread_id && !exclude_threads.contains(id)
|
||||
}
|
||||
ThreadContextEntry::Context { .. } => true,
|
||||
})
|
||||
@@ -710,7 +742,7 @@ fn add_selections_as_context(
|
||||
})
|
||||
}
|
||||
|
||||
fn selection_ranges(
|
||||
pub(crate) fn selection_ranges(
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> Vec<(Entity<Buffer>, Range<text::Anchor>)> {
|
||||
@@ -805,42 +837,9 @@ fn render_fold_icon_button(
|
||||
) -> Arc<dyn Send + Sync + Fn(FoldId, Range<Anchor>, &mut App) -> AnyElement> {
|
||||
Arc::new({
|
||||
move |fold_id, fold_range, cx| {
|
||||
let is_in_text_selection = editor.upgrade().is_some_and(|editor| {
|
||||
editor.update(cx, |editor, cx| {
|
||||
let snapshot = editor
|
||||
.buffer()
|
||||
.update(cx, |multi_buffer, cx| multi_buffer.snapshot(cx));
|
||||
|
||||
let is_in_pending_selection = || {
|
||||
editor
|
||||
.selections
|
||||
.pending
|
||||
.as_ref()
|
||||
.is_some_and(|pending_selection| {
|
||||
pending_selection
|
||||
.selection
|
||||
.range()
|
||||
.includes(&fold_range, &snapshot)
|
||||
})
|
||||
};
|
||||
|
||||
let mut is_in_complete_selection = || {
|
||||
editor
|
||||
.selections
|
||||
.disjoint_in_range::<usize>(fold_range.clone(), cx)
|
||||
.into_iter()
|
||||
.any(|selection| {
|
||||
// This is needed to cover a corner case, if we just check for an existing
|
||||
// selection in the fold range, having a cursor at the start of the fold
|
||||
// marks it as selected. Non-empty selections don't cause this.
|
||||
let length = selection.end - selection.start;
|
||||
length > 0
|
||||
})
|
||||
};
|
||||
|
||||
is_in_pending_selection() || is_in_complete_selection()
|
||||
})
|
||||
});
|
||||
let is_in_text_selection = editor
|
||||
.update(cx, |editor, cx| editor.is_range_selected(&fold_range, cx))
|
||||
.unwrap_or_default();
|
||||
|
||||
ButtonLike::new(fold_id)
|
||||
.style(ButtonStyle::Filled)
|
||||
|
||||
@@ -35,7 +35,7 @@ use super::symbol_context_picker::search_symbols;
|
||||
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
|
||||
use super::{
|
||||
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
|
||||
available_context_picker_entries, recent_context_picker_entries, selection_ranges,
|
||||
available_context_picker_entries, recent_context_picker_entries_with_store, selection_ranges,
|
||||
};
|
||||
use crate::message_editor::ContextCreasesAddon;
|
||||
|
||||
@@ -787,7 +787,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
.and_then(|b| b.read(cx).file())
|
||||
.map(|file| ProjectPath::from_file(file.as_ref(), cx));
|
||||
|
||||
let recent_entries = recent_context_picker_entries(
|
||||
let recent_entries = recent_context_picker_entries_with_store(
|
||||
context_store.clone(),
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
|
||||
@@ -72,7 +72,7 @@ pub fn init(
|
||||
let Some(window) = window else {
|
||||
return;
|
||||
};
|
||||
let workspace = cx.entity().clone();
|
||||
let workspace = cx.entity();
|
||||
InlineAssistant::update_global(cx, |inline_assistant, cx| {
|
||||
inline_assistant.register_workspace(&workspace, window, cx)
|
||||
});
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::{cmp::Reverse, sync::Arc};
|
||||
|
||||
use cloud_llm_client::Plan;
|
||||
use collections::{HashSet, IndexMap};
|
||||
use feature_flags::ZedProFeatureFlag;
|
||||
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||
@@ -10,7 +11,6 @@ use language_model::{
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use proto::Plan;
|
||||
use ui::{ListItem, ListItemSpacing, prelude::*};
|
||||
|
||||
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
|
||||
@@ -536,7 +536,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
) -> Option<gpui::AnyElement> {
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
|
||||
let plan = proto::Plan::ZedPro;
|
||||
let plan = Plan::ZedPro;
|
||||
|
||||
Some(
|
||||
h_flex()
|
||||
@@ -557,7 +557,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
window
|
||||
.dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
|
||||
}),
|
||||
Plan::Free | Plan::ZedProTrial => Button::new(
|
||||
Plan::ZedFree | Plan::ZedProTrial => Button::new(
|
||||
"try-pro",
|
||||
if plan == Plan::ZedProTrial {
|
||||
"Upgrade to Pro"
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread;
|
||||
use crate::agent_model_selector::AgentModelSelector;
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use crate::ui::{
|
||||
MaxModeTooltip,
|
||||
BurnModeTooltip,
|
||||
preview::{AgentPreview, UsageCallout},
|
||||
};
|
||||
use agent::history_store::HistoryStore;
|
||||
@@ -14,7 +14,7 @@ use agent::{
|
||||
context::{AgentContextKey, ContextLoadResult, load_context},
|
||||
context_store::ContextStoreEvent,
|
||||
};
|
||||
use agent_settings::{AgentSettings, CompletionMode};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
|
||||
use ai_onboarding::ApiKeysWithProviders;
|
||||
use buffer_diff::BufferDiff;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
@@ -55,7 +55,7 @@ use zed_actions::agent::ToggleModelSelector;
|
||||
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::profile_selector::ProfileSelector;
|
||||
use crate::profile_selector::{ProfileProvider, ProfileSelector};
|
||||
use crate::{
|
||||
ActiveThread, AgentDiffPane, ChatWithFollow, ExpandMessageEditor, Follow, KeepAll,
|
||||
ModelUsageContext, NewThread, OpenAgentDiff, RejectAll, RemoveAllContext, ToggleBurnMode,
|
||||
@@ -152,6 +152,24 @@ pub(crate) fn create_editor(
|
||||
editor
|
||||
}
|
||||
|
||||
impl ProfileProvider for Entity<Thread> {
|
||||
fn profiles_supported(&self, cx: &App) -> bool {
|
||||
self.read(cx)
|
||||
.configured_model()
|
||||
.map_or(false, |model| model.model.supports_tools())
|
||||
}
|
||||
|
||||
fn profile_id(&self, cx: &App) -> AgentProfileId {
|
||||
self.read(cx).profile().id().clone()
|
||||
}
|
||||
|
||||
fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) {
|
||||
self.update(cx, |this, cx| {
|
||||
this.set_profile(profile_id, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl MessageEditor {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -221,8 +239,9 @@ impl MessageEditor {
|
||||
)
|
||||
});
|
||||
|
||||
let profile_selector =
|
||||
cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx));
|
||||
let profile_selector = cx.new(|cx| {
|
||||
ProfileSelector::new(fs, Arc::new(thread.clone()), editor.focus_handle(cx), cx)
|
||||
});
|
||||
|
||||
Self {
|
||||
editor: editor.clone(),
|
||||
@@ -605,7 +624,7 @@ impl MessageEditor {
|
||||
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
|
||||
}))
|
||||
.tooltip(move |_window, cx| {
|
||||
cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
|
||||
cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
|
||||
.into()
|
||||
})
|
||||
.into_any_element(),
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
use crate::{ManageProfiles, ToggleProfileSelector};
|
||||
use agent::{
|
||||
Thread,
|
||||
agent_profile::{AgentProfile, AvailableProfiles},
|
||||
};
|
||||
use agent::agent_profile::{AgentProfile, AvailableProfiles};
|
||||
use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles};
|
||||
use fs::Fs;
|
||||
use gpui::{Action, Empty, Entity, FocusHandle, Subscription, prelude::*};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use gpui::{Action, Entity, FocusHandle, Subscription, prelude::*};
|
||||
use settings::{Settings as _, SettingsStore, update_settings_file};
|
||||
use std::sync::Arc;
|
||||
use ui::{
|
||||
@@ -14,10 +10,22 @@ use ui::{
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
/// Trait for types that can provide and manage agent profiles
|
||||
pub trait ProfileProvider {
|
||||
/// Get the current profile ID
|
||||
fn profile_id(&self, cx: &App) -> AgentProfileId;
|
||||
|
||||
/// Set the profile ID
|
||||
fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App);
|
||||
|
||||
/// Check if profiles are supported in the current context (e.g. if the model that is selected has tool support)
|
||||
fn profiles_supported(&self, cx: &App) -> bool;
|
||||
}
|
||||
|
||||
pub struct ProfileSelector {
|
||||
profiles: AvailableProfiles,
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
provider: Arc<dyn ProfileProvider>,
|
||||
menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
focus_handle: FocusHandle,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
@@ -26,7 +34,7 @@ pub struct ProfileSelector {
|
||||
impl ProfileSelector {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
provider: Arc<dyn ProfileProvider>,
|
||||
focus_handle: FocusHandle,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -37,7 +45,7 @@ impl ProfileSelector {
|
||||
Self {
|
||||
profiles: AgentProfile::available_profiles(cx),
|
||||
fs,
|
||||
thread,
|
||||
provider,
|
||||
menu_handle: PopoverMenuHandle::default(),
|
||||
focus_handle,
|
||||
_subscriptions: vec![settings_subscription],
|
||||
@@ -113,10 +121,10 @@ impl ProfileSelector {
|
||||
builtin_profiles::MINIMAL => Some("Chat about anything with no tools."),
|
||||
_ => None,
|
||||
};
|
||||
let thread_profile_id = self.thread.read(cx).profile().id();
|
||||
let thread_profile_id = self.provider.profile_id(cx);
|
||||
|
||||
let entry = ContextMenuEntry::new(profile_name.clone())
|
||||
.toggleable(IconPosition::End, &profile_id == thread_profile_id);
|
||||
.toggleable(IconPosition::End, profile_id == thread_profile_id);
|
||||
|
||||
let entry = if let Some(doc_text) = documentation {
|
||||
entry.documentation_aside(documentation_side(settings.dock), move |_| {
|
||||
@@ -128,7 +136,7 @@ impl ProfileSelector {
|
||||
|
||||
entry.handler({
|
||||
let fs = self.fs.clone();
|
||||
let thread = self.thread.clone();
|
||||
let provider = self.provider.clone();
|
||||
let profile_id = profile_id.clone();
|
||||
move |_window, cx| {
|
||||
update_settings_file::<AgentSettings>(fs.clone(), cx, {
|
||||
@@ -138,9 +146,7 @@ impl ProfileSelector {
|
||||
}
|
||||
});
|
||||
|
||||
thread.update(cx, |this, cx| {
|
||||
this.set_profile(profile_id.clone(), cx);
|
||||
});
|
||||
provider.set_profile(profile_id.clone(), cx);
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -149,23 +155,15 @@ impl ProfileSelector {
|
||||
impl Render for ProfileSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
let profile_id = self.thread.read(cx).profile().id();
|
||||
let profile = settings.profiles.get(profile_id);
|
||||
let profile_id = self.provider.profile_id(cx);
|
||||
let profile = settings.profiles.get(&profile_id);
|
||||
|
||||
let selected_profile = profile
|
||||
.map(|profile| profile.name.clone())
|
||||
.unwrap_or_else(|| "Unknown".into());
|
||||
|
||||
let configured_model = self.thread.read(cx).configured_model().or_else(|| {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
model_registry.default_model()
|
||||
});
|
||||
let Some(configured_model) = configured_model else {
|
||||
return Empty.into_any_element();
|
||||
};
|
||||
|
||||
if configured_model.model.supports_tools() {
|
||||
let this = cx.entity().clone();
|
||||
if self.provider.profiles_supported(cx) {
|
||||
let this = cx.entity();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
let trigger_button = Button::new("profile-selector-model", selected_profile)
|
||||
.label_size(LabelSize::Small)
|
||||
|
||||
@@ -7,22 +7,11 @@ use settings::{Settings, SettingsSources};
|
||||
/// Settings for slash commands.
|
||||
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
|
||||
pub struct SlashCommandSettings {
|
||||
/// Settings for the `/docs` slash command.
|
||||
#[serde(default)]
|
||||
pub docs: DocsCommandSettings,
|
||||
/// Settings for the `/cargo-workspace` slash command.
|
||||
#[serde(default)]
|
||||
pub cargo_workspace: CargoWorkspaceCommandSettings,
|
||||
}
|
||||
|
||||
/// Settings for the `/docs` slash command.
|
||||
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
|
||||
pub struct DocsCommandSettings {
|
||||
/// Whether `/docs` is enabled.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
/// Settings for the `/cargo-workspace` slash command.
|
||||
#[derive(Deserialize, Serialize, Debug, Default, Clone, JsonSchema)]
|
||||
pub struct CargoWorkspaceCommandSettings {
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
use crate::{
|
||||
burn_mode_tooltip::BurnModeTooltip,
|
||||
language_model_selector::{LanguageModelSelector, language_model_selector},
|
||||
ui::BurnModeTooltip,
|
||||
};
|
||||
use agent_settings::{AgentSettings, CompletionMode};
|
||||
use anyhow::Result;
|
||||
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection, SlashCommandWorkingSet};
|
||||
use assistant_slash_commands::{
|
||||
DefaultSlashCommand, DocsSlashCommand, DocsSlashCommandArgs, FileSlashCommand,
|
||||
selections_creases,
|
||||
};
|
||||
use assistant_slash_commands::{DefaultSlashCommand, FileSlashCommand, selections_creases};
|
||||
use client::{proto, zed_urls};
|
||||
use collections::{BTreeSet, HashMap, HashSet, hash_map};
|
||||
use editor::{
|
||||
@@ -30,7 +27,6 @@ use gpui::{
|
||||
StatefulInteractiveElement, Styled, Subscription, Task, Transformation, WeakEntity, actions,
|
||||
div, img, percentage, point, prelude::*, pulsating_between, size,
|
||||
};
|
||||
use indexed_docs::IndexedDocsStore;
|
||||
use language::{
|
||||
BufferSnapshot, LspAdapterDelegate, ToOffset,
|
||||
language_settings::{SoftWrap, all_language_settings},
|
||||
@@ -77,7 +73,7 @@ use crate::{slash_command::SlashCommandCompletionProvider, slash_command_picker}
|
||||
use assistant_context::{
|
||||
AssistantContext, CacheStatus, Content, ContextEvent, ContextId, InvokedSlashCommandId,
|
||||
InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, MessageStatus,
|
||||
ParsedSlashCommand, PendingSlashCommandStatus, ThoughtProcessOutputSection,
|
||||
PendingSlashCommandStatus, ThoughtProcessOutputSection,
|
||||
};
|
||||
|
||||
actions!(
|
||||
@@ -701,19 +697,7 @@ impl TextThreadEditor {
|
||||
}
|
||||
};
|
||||
let render_trailer = {
|
||||
let command = command.clone();
|
||||
move |row, _unfold, _window: &mut Window, cx: &mut App| {
|
||||
// TODO: In the future we should investigate how we can expose
|
||||
// this as a hook on the `SlashCommand` trait so that we don't
|
||||
// need to special-case it here.
|
||||
if command.name == DocsSlashCommand::NAME {
|
||||
return render_docs_slash_command_trailer(
|
||||
row,
|
||||
command.clone(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
move |_row, _unfold, _window: &mut Window, _cx: &mut App| {
|
||||
Empty.into_any()
|
||||
}
|
||||
};
|
||||
@@ -2398,70 +2382,6 @@ fn render_pending_slash_command_gutter_decoration(
|
||||
icon.into_any_element()
|
||||
}
|
||||
|
||||
fn render_docs_slash_command_trailer(
|
||||
row: MultiBufferRow,
|
||||
command: ParsedSlashCommand,
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
if command.arguments.is_empty() {
|
||||
return Empty.into_any();
|
||||
}
|
||||
let args = DocsSlashCommandArgs::parse(&command.arguments);
|
||||
|
||||
let Some(store) = args
|
||||
.provider()
|
||||
.and_then(|provider| IndexedDocsStore::try_global(provider, cx).ok())
|
||||
else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
|
||||
let Some(package) = args.package() else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
|
||||
let mut children = Vec::new();
|
||||
|
||||
if store.is_indexing(&package) {
|
||||
children.push(
|
||||
div()
|
||||
.id(("crates-being-indexed", row.0))
|
||||
.child(Icon::new(IconName::ArrowCircle).with_animation(
|
||||
"arrow-circle",
|
||||
Animation::new(Duration::from_secs(4)).repeat(),
|
||||
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
||||
))
|
||||
.tooltip({
|
||||
let package = package.clone();
|
||||
Tooltip::text(format!("Indexing {package}…"))
|
||||
})
|
||||
.into_any_element(),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(latest_error) = store.latest_error_for_package(&package) {
|
||||
children.push(
|
||||
div()
|
||||
.id(("latest-error", row.0))
|
||||
.child(
|
||||
Icon::new(IconName::Warning)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Warning),
|
||||
)
|
||||
.tooltip(Tooltip::text(format!("Failed to index: {latest_error}")))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
let is_indexing = store.is_indexing(&package);
|
||||
let latest_error = store.latest_error_for_package(&package);
|
||||
|
||||
if !is_indexing && latest_error.is_none() {
|
||||
return Empty.into_any();
|
||||
}
|
||||
|
||||
h_flex().gap_2().children(children).into_any_element()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct CopyMetadata {
|
||||
creases: Vec<SelectedCreaseMetadata>,
|
||||
|
||||
@@ -2,11 +2,11 @@ use crate::ToggleBurnMode;
|
||||
use gpui::{Context, FontWeight, IntoElement, Render, Window};
|
||||
use ui::{KeyBinding, prelude::*, tooltip_container};
|
||||
|
||||
pub struct MaxModeTooltip {
|
||||
pub struct BurnModeTooltip {
|
||||
selected: bool,
|
||||
}
|
||||
|
||||
impl MaxModeTooltip {
|
||||
impl BurnModeTooltip {
|
||||
pub fn new() -> Self {
|
||||
Self { selected: false }
|
||||
}
|
||||
@@ -17,7 +17,7 @@ impl MaxModeTooltip {
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for MaxModeTooltip {
|
||||
impl Render for BurnModeTooltip {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let (icon, color) = if self.selected {
|
||||
(IconName::ZedBurnModeOn, Color::Error)
|
||||
|
||||
@@ -58,9 +58,7 @@ impl Assets {
|
||||
pub fn load_test_fonts(&self, cx: &App) {
|
||||
cx.text_system()
|
||||
.add_fonts(vec![
|
||||
self.load("fonts/plex-mono/ZedPlexMono-Regular.ttf")
|
||||
.unwrap()
|
||||
.unwrap(),
|
||||
self.load("fonts/lilex/Lilex-Regular.ttf").unwrap().unwrap(),
|
||||
])
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -11,6 +11,9 @@ workspace = true
|
||||
[lib]
|
||||
path = "src/assistant_context.rs"
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
|
||||
[dependencies]
|
||||
agent_settings.workspace = true
|
||||
anyhow.workspace = true
|
||||
|
||||
@@ -138,6 +138,27 @@ impl ContextStore {
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
contexts: Default::default(),
|
||||
contexts_metadata: Default::default(),
|
||||
context_server_slash_command_ids: Default::default(),
|
||||
host_contexts: Default::default(),
|
||||
fs: project.read(cx).fs().clone(),
|
||||
languages: project.read(cx).languages().clone(),
|
||||
slash_commands: Arc::default(),
|
||||
telemetry: project.read(cx).client().telemetry().clone(),
|
||||
_watch_updates: Task::ready(None),
|
||||
client: project.read(cx).client(),
|
||||
project,
|
||||
project_is_shared: false,
|
||||
client_subscription: None,
|
||||
_project_subscriptions: Default::default(),
|
||||
prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_advertise_contexts(
|
||||
this: Entity<Self>,
|
||||
envelope: TypedEnvelope<proto::AdvertiseContexts>,
|
||||
|
||||
@@ -27,7 +27,6 @@ globset.workspace = true
|
||||
gpui.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indexed_docs.workspace = true
|
||||
language.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
|
||||
@@ -3,7 +3,6 @@ mod context_server_command;
|
||||
mod default_command;
|
||||
mod delta_command;
|
||||
mod diagnostics_command;
|
||||
mod docs_command;
|
||||
mod fetch_command;
|
||||
mod file_command;
|
||||
mod now_command;
|
||||
@@ -18,7 +17,6 @@ pub use crate::context_server_command::*;
|
||||
pub use crate::default_command::*;
|
||||
pub use crate::delta_command::*;
|
||||
pub use crate::diagnostics_command::*;
|
||||
pub use crate::docs_command::*;
|
||||
pub use crate::fetch_command::*;
|
||||
pub use crate::file_command::*;
|
||||
pub use crate::now_command::*;
|
||||
|
||||
@@ -1,543 +0,0 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use assistant_slash_command::{
|
||||
ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
|
||||
SlashCommandResult,
|
||||
};
|
||||
use gpui::{App, BackgroundExecutor, Entity, Task, WeakEntity};
|
||||
use indexed_docs::{
|
||||
DocsDotRsProvider, IndexedDocsRegistry, IndexedDocsStore, LocalRustdocProvider, PackageName,
|
||||
ProviderId,
|
||||
};
|
||||
use language::{BufferSnapshot, LspAdapterDelegate};
|
||||
use project::{Project, ProjectPath};
|
||||
use ui::prelude::*;
|
||||
use util::{ResultExt, maybe};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub struct DocsSlashCommand;
|
||||
|
||||
impl DocsSlashCommand {
|
||||
pub const NAME: &'static str = "docs";
|
||||
|
||||
fn path_to_cargo_toml(project: Entity<Project>, cx: &mut App) -> Option<Arc<Path>> {
|
||||
let worktree = project.read(cx).worktrees(cx).next()?;
|
||||
let worktree = worktree.read(cx);
|
||||
let entry = worktree.entry_for_path("Cargo.toml")?;
|
||||
let path = ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path: entry.path.clone(),
|
||||
};
|
||||
Some(Arc::from(
|
||||
project.read(cx).absolute_path(&path, cx)?.as_path(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Ensures that the indexed doc providers for Rust are registered.
|
||||
///
|
||||
/// Ideally we would do this sooner, but we need to wait until we're able to
|
||||
/// access the workspace so we can read the project.
|
||||
fn ensure_rust_doc_providers_are_registered(
|
||||
&self,
|
||||
workspace: Option<WeakEntity<Workspace>>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let indexed_docs_registry = IndexedDocsRegistry::global(cx);
|
||||
if indexed_docs_registry
|
||||
.get_provider_store(LocalRustdocProvider::id())
|
||||
.is_none()
|
||||
{
|
||||
let index_provider_deps = maybe!({
|
||||
let workspace = workspace
|
||||
.as_ref()
|
||||
.context("no workspace")?
|
||||
.upgrade()
|
||||
.context("workspace dropped")?;
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let fs = project.read(cx).fs().clone();
|
||||
let cargo_workspace_root = Self::path_to_cargo_toml(project, cx)
|
||||
.and_then(|path| path.parent().map(|path| path.to_path_buf()))
|
||||
.context("no Cargo workspace root found")?;
|
||||
|
||||
anyhow::Ok((fs, cargo_workspace_root))
|
||||
});
|
||||
|
||||
if let Some((fs, cargo_workspace_root)) = index_provider_deps.log_err() {
|
||||
indexed_docs_registry.register_provider(Box::new(LocalRustdocProvider::new(
|
||||
fs,
|
||||
cargo_workspace_root,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if indexed_docs_registry
|
||||
.get_provider_store(DocsDotRsProvider::id())
|
||||
.is_none()
|
||||
{
|
||||
let http_client = maybe!({
|
||||
let workspace = workspace
|
||||
.as_ref()
|
||||
.context("no workspace")?
|
||||
.upgrade()
|
||||
.context("workspace was dropped")?;
|
||||
let project = workspace.read(cx).project().clone();
|
||||
anyhow::Ok(project.read(cx).client().http_client())
|
||||
});
|
||||
|
||||
if let Some(http_client) = http_client.log_err() {
|
||||
indexed_docs_registry
|
||||
.register_provider(Box::new(DocsDotRsProvider::new(http_client)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs just-in-time indexing for a given package, in case the slash command
|
||||
/// is run without any entries existing in the index.
|
||||
fn run_just_in_time_indexing(
|
||||
store: Arc<IndexedDocsStore>,
|
||||
key: String,
|
||||
package: PackageName,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Task<()> {
|
||||
executor.clone().spawn(async move {
|
||||
let (prefix, needs_full_index) = if let Some((prefix, _)) = key.split_once('*') {
|
||||
// If we have a wildcard in the search, we want to wait until
|
||||
// we've completely finished indexing so we get a full set of
|
||||
// results for the wildcard.
|
||||
(prefix.to_string(), true)
|
||||
} else {
|
||||
(key, false)
|
||||
};
|
||||
|
||||
// If we already have some entries, we assume that we've indexed the package before
|
||||
// and don't need to do it again.
|
||||
let has_any_entries = store
|
||||
.any_with_prefix(prefix.clone())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if has_any_entries {
|
||||
return ();
|
||||
};
|
||||
|
||||
let index_task = store.clone().index(package.clone());
|
||||
|
||||
if needs_full_index {
|
||||
_ = index_task.await;
|
||||
} else {
|
||||
loop {
|
||||
executor.timer(Duration::from_millis(200)).await;
|
||||
|
||||
if store
|
||||
.any_with_prefix(prefix.clone())
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
|| !store.is_indexing(&package)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl SlashCommand for DocsSlashCommand {
|
||||
fn name(&self) -> String {
|
||||
Self::NAME.into()
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"insert docs".into()
|
||||
}
|
||||
|
||||
fn menu_text(&self) -> String {
|
||||
"Insert Documentation".into()
|
||||
}
|
||||
|
||||
fn requires_argument(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn complete_argument(
|
||||
self: Arc<Self>,
|
||||
arguments: &[String],
|
||||
_cancel: Arc<AtomicBool>,
|
||||
workspace: Option<WeakEntity<Workspace>>,
|
||||
_: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Vec<ArgumentCompletion>>> {
|
||||
self.ensure_rust_doc_providers_are_registered(workspace, cx);
|
||||
|
||||
let indexed_docs_registry = IndexedDocsRegistry::global(cx);
|
||||
let args = DocsSlashCommandArgs::parse(arguments);
|
||||
let store = args
|
||||
.provider()
|
||||
.context("no docs provider specified")
|
||||
.and_then(|provider| IndexedDocsStore::try_global(provider, cx));
|
||||
cx.background_spawn(async move {
|
||||
fn build_completions(items: Vec<String>) -> Vec<ArgumentCompletion> {
|
||||
items
|
||||
.into_iter()
|
||||
.map(|item| ArgumentCompletion {
|
||||
label: item.clone().into(),
|
||||
new_text: item.to_string(),
|
||||
after_completion: assistant_slash_command::AfterCompletion::Run,
|
||||
replace_previous_arguments: false,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
match args {
|
||||
DocsSlashCommandArgs::NoProvider => {
|
||||
let providers = indexed_docs_registry.list_providers();
|
||||
if providers.is_empty() {
|
||||
return Ok(vec![ArgumentCompletion {
|
||||
label: "No available docs providers.".into(),
|
||||
new_text: String::new(),
|
||||
after_completion: false.into(),
|
||||
replace_previous_arguments: false,
|
||||
}]);
|
||||
}
|
||||
|
||||
Ok(providers
|
||||
.into_iter()
|
||||
.map(|provider| ArgumentCompletion {
|
||||
label: provider.to_string().into(),
|
||||
new_text: provider.to_string(),
|
||||
after_completion: false.into(),
|
||||
replace_previous_arguments: false,
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider,
|
||||
package,
|
||||
index,
|
||||
} => {
|
||||
let store = store?;
|
||||
|
||||
if index {
|
||||
// We don't need to hold onto this task, as the `IndexedDocsStore` will hold it
|
||||
// until it completes.
|
||||
drop(store.clone().index(package.as_str().into()));
|
||||
}
|
||||
|
||||
let suggested_packages = store.clone().suggest_packages().await?;
|
||||
let search_results = store.search(package).await;
|
||||
|
||||
let mut items = build_completions(search_results);
|
||||
let workspace_crate_completions = suggested_packages
|
||||
.into_iter()
|
||||
.filter(|package_name| {
|
||||
!items
|
||||
.iter()
|
||||
.any(|item| item.label.text() == package_name.as_ref())
|
||||
})
|
||||
.map(|package_name| ArgumentCompletion {
|
||||
label: format!("{package_name} (unindexed)").into(),
|
||||
new_text: format!("{package_name}"),
|
||||
after_completion: true.into(),
|
||||
replace_previous_arguments: false,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
items.extend(workspace_crate_completions);
|
||||
|
||||
if items.is_empty() {
|
||||
return Ok(vec![ArgumentCompletion {
|
||||
label: format!(
|
||||
"Enter a {package_term} name.",
|
||||
package_term = package_term(&provider)
|
||||
)
|
||||
.into(),
|
||||
new_text: provider.to_string(),
|
||||
after_completion: false.into(),
|
||||
replace_previous_arguments: false,
|
||||
}]);
|
||||
}
|
||||
|
||||
Ok(items)
|
||||
}
|
||||
DocsSlashCommandArgs::SearchItemDocs { item_path, .. } => {
|
||||
let store = store?;
|
||||
let items = store.search(item_path).await;
|
||||
Ok(build_completions(items))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
arguments: &[String],
|
||||
_context_slash_command_output_sections: &[SlashCommandOutputSection<language::Anchor>],
|
||||
_context_buffer: BufferSnapshot,
|
||||
_workspace: WeakEntity<Workspace>,
|
||||
_delegate: Option<Arc<dyn LspAdapterDelegate>>,
|
||||
_: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Task<SlashCommandResult> {
|
||||
if arguments.is_empty() {
|
||||
return Task::ready(Err(anyhow!("missing an argument")));
|
||||
};
|
||||
|
||||
let args = DocsSlashCommandArgs::parse(arguments);
|
||||
let executor = cx.background_executor().clone();
|
||||
let task = cx.background_spawn({
|
||||
let store = args
|
||||
.provider()
|
||||
.context("no docs provider specified")
|
||||
.and_then(|provider| IndexedDocsStore::try_global(provider, cx));
|
||||
async move {
|
||||
let (provider, key) = match args.clone() {
|
||||
DocsSlashCommandArgs::NoProvider => bail!("no docs provider specified"),
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider, package, ..
|
||||
} => (provider, package),
|
||||
DocsSlashCommandArgs::SearchItemDocs {
|
||||
provider,
|
||||
item_path,
|
||||
..
|
||||
} => (provider, item_path),
|
||||
};
|
||||
|
||||
if key.trim().is_empty() {
|
||||
bail!(
|
||||
"no {package_term} name provided",
|
||||
package_term = package_term(&provider)
|
||||
);
|
||||
}
|
||||
|
||||
let store = store?;
|
||||
|
||||
if let Some(package) = args.package() {
|
||||
Self::run_just_in_time_indexing(store.clone(), key.clone(), package, executor)
|
||||
.await;
|
||||
}
|
||||
|
||||
let (text, ranges) = if let Some((prefix, _)) = key.split_once('*') {
|
||||
let docs = store.load_many_by_prefix(prefix.to_string()).await?;
|
||||
|
||||
let mut text = String::new();
|
||||
let mut ranges = Vec::new();
|
||||
|
||||
for (key, docs) in docs {
|
||||
let prev_len = text.len();
|
||||
|
||||
text.push_str(&docs.0);
|
||||
text.push_str("\n");
|
||||
ranges.push((key, prev_len..text.len()));
|
||||
text.push_str("\n");
|
||||
}
|
||||
|
||||
(text, ranges)
|
||||
} else {
|
||||
let item_docs = store.load(key.clone()).await?;
|
||||
let text = item_docs.to_string();
|
||||
let range = 0..text.len();
|
||||
|
||||
(text, vec![(key, range)])
|
||||
};
|
||||
|
||||
anyhow::Ok((provider, text, ranges))
|
||||
}
|
||||
});
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let (provider, text, ranges) = task.await?;
|
||||
Ok(SlashCommandOutput {
|
||||
text,
|
||||
sections: ranges
|
||||
.into_iter()
|
||||
.map(|(key, range)| SlashCommandOutputSection {
|
||||
range,
|
||||
icon: IconName::FileDoc,
|
||||
label: format!("docs ({provider}): {key}",).into(),
|
||||
metadata: None,
|
||||
})
|
||||
.collect(),
|
||||
run_commands_in_text: false,
|
||||
}
|
||||
.to_event_stream())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn is_item_path_delimiter(char: char) -> bool {
|
||||
!char.is_alphanumeric() && char != '-' && char != '_'
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub enum DocsSlashCommandArgs {
|
||||
NoProvider,
|
||||
SearchPackageDocs {
|
||||
provider: ProviderId,
|
||||
package: String,
|
||||
index: bool,
|
||||
},
|
||||
SearchItemDocs {
|
||||
provider: ProviderId,
|
||||
package: String,
|
||||
item_path: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl DocsSlashCommandArgs {
|
||||
pub fn parse(arguments: &[String]) -> Self {
|
||||
let Some(provider) = arguments
|
||||
.get(0)
|
||||
.cloned()
|
||||
.filter(|arg| !arg.trim().is_empty())
|
||||
else {
|
||||
return Self::NoProvider;
|
||||
};
|
||||
let provider = ProviderId(provider.into());
|
||||
let Some(argument) = arguments.get(1) else {
|
||||
return Self::NoProvider;
|
||||
};
|
||||
|
||||
if let Some((package, rest)) = argument.split_once(is_item_path_delimiter) {
|
||||
if rest.trim().is_empty() {
|
||||
Self::SearchPackageDocs {
|
||||
provider,
|
||||
package: package.to_owned(),
|
||||
index: true,
|
||||
}
|
||||
} else {
|
||||
Self::SearchItemDocs {
|
||||
provider,
|
||||
package: package.to_owned(),
|
||||
item_path: argument.to_owned(),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Self::SearchPackageDocs {
|
||||
provider,
|
||||
package: argument.to_owned(),
|
||||
index: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn provider(&self) -> Option<ProviderId> {
|
||||
match self {
|
||||
Self::NoProvider => None,
|
||||
Self::SearchPackageDocs { provider, .. } | Self::SearchItemDocs { provider, .. } => {
|
||||
Some(provider.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn package(&self) -> Option<PackageName> {
|
||||
match self {
|
||||
Self::NoProvider => None,
|
||||
Self::SearchPackageDocs { package, .. } | Self::SearchItemDocs { package, .. } => {
|
||||
Some(package.as_str().into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the term used to refer to a package.
|
||||
fn package_term(provider: &ProviderId) -> &'static str {
|
||||
if provider == &DocsDotRsProvider::id() || provider == &LocalRustdocProvider::id() {
|
||||
return "crate";
|
||||
}
|
||||
|
||||
"package"
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_docs_slash_command_args() {
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["".to_string()]),
|
||||
DocsSlashCommandArgs::NoProvider
|
||||
);
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["rustdoc".to_string()]),
|
||||
DocsSlashCommandArgs::NoProvider
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "".to_string()]),
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider: ProviderId("rustdoc".into()),
|
||||
package: "".into(),
|
||||
index: false
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["gleam".to_string(), "".to_string()]),
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider: ProviderId("gleam".into()),
|
||||
package: "".into(),
|
||||
index: false
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "gpui".to_string()]),
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider: ProviderId("rustdoc".into()),
|
||||
package: "gpui".into(),
|
||||
index: false,
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib".to_string()]),
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider: ProviderId("gleam".into()),
|
||||
package: "gleam_stdlib".into(),
|
||||
index: false
|
||||
}
|
||||
);
|
||||
|
||||
// Adding an item path delimiter indicates we can start indexing.
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["rustdoc".to_string(), "gpui:".to_string()]),
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider: ProviderId("rustdoc".into()),
|
||||
package: "gpui".into(),
|
||||
index: true,
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&["gleam".to_string(), "gleam_stdlib/".to_string()]),
|
||||
DocsSlashCommandArgs::SearchPackageDocs {
|
||||
provider: ProviderId("gleam".into()),
|
||||
package: "gleam_stdlib".into(),
|
||||
index: true
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&[
|
||||
"rustdoc".to_string(),
|
||||
"gpui::foo::bar::Baz".to_string()
|
||||
]),
|
||||
DocsSlashCommandArgs::SearchItemDocs {
|
||||
provider: ProviderId("rustdoc".into()),
|
||||
package: "gpui".into(),
|
||||
item_path: "gpui::foo::bar::Baz".into()
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
DocsSlashCommandArgs::parse(&[
|
||||
"gleam".to_string(),
|
||||
"gleam_stdlib/gleam/int".to_string()
|
||||
]),
|
||||
DocsSlashCommandArgs::SearchItemDocs {
|
||||
provider: ProviderId("gleam".into()),
|
||||
package: "gleam_stdlib".into(),
|
||||
item_path: "gleam_stdlib/gleam/int".into()
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,6 @@ collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rodio = { workspace = true, features = ["wav", "playback", "tracing"] }
|
||||
rodio = { workspace = true, features = [ "wav", "playback", "tracing" ] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -18,7 +18,7 @@ fn main() {}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
mod windows_impl {
|
||||
use std::path::Path;
|
||||
use std::{borrow::Cow, path::Path};
|
||||
|
||||
use super::dialog::create_dialog_window;
|
||||
use super::updater::perform_update;
|
||||
@@ -37,9 +37,9 @@ mod windows_impl {
|
||||
pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1;
|
||||
pub(crate) const WM_TERMINATE: u32 = WM_USER + 2;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Default)]
|
||||
struct Args {
|
||||
launch: Option<bool>,
|
||||
launch: bool,
|
||||
}
|
||||
|
||||
pub(crate) fn run() -> Result<()> {
|
||||
@@ -56,9 +56,9 @@ mod windows_impl {
|
||||
log::info!("======= Starting Zed update =======");
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let hwnd = create_dialog_window(rx)?.0 as isize;
|
||||
let args = parse_args();
|
||||
let args = parse_args(std::env::args().skip(1));
|
||||
std::thread::spawn(move || {
|
||||
let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch.unwrap_or(true));
|
||||
let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch);
|
||||
tx.send(result).ok();
|
||||
unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok();
|
||||
});
|
||||
@@ -83,39 +83,27 @@ mod windows_impl {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut result = Args { launch: None };
|
||||
if let Some(candidate) = std::env::args().nth(1) {
|
||||
parse_single_arg(&candidate, &mut result);
|
||||
fn parse_args(input: impl IntoIterator<Item = String>) -> Args {
|
||||
let mut args: Args = Args { launch: true };
|
||||
|
||||
let mut input = input.into_iter();
|
||||
if let Some(arg) = input.next() {
|
||||
let launch_arg;
|
||||
|
||||
if arg == "--launch" {
|
||||
launch_arg = input.next().map(Cow::Owned);
|
||||
} else if let Some(rest) = arg.strip_prefix("--launch=") {
|
||||
launch_arg = Some(Cow::Borrowed(rest));
|
||||
} else {
|
||||
launch_arg = None;
|
||||
}
|
||||
|
||||
if launch_arg.as_deref() == Some("false") {
|
||||
args.launch = false;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn parse_single_arg(arg: &str, result: &mut Args) {
|
||||
let Some((key, value)) = arg.strip_prefix("--").and_then(|arg| arg.split_once('=')) else {
|
||||
log::error!(
|
||||
"Invalid argument format: '{}'. Expected format: --key=value",
|
||||
arg
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
match key {
|
||||
"launch" => parse_launch_arg(value, &mut result.launch),
|
||||
_ => log::error!("Unknown argument: --{}", key),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_launch_arg(value: &str, arg: &mut Option<bool>) {
|
||||
match value {
|
||||
"true" => *arg = Some(true),
|
||||
"false" => *arg = Some(false),
|
||||
_ => log::error!(
|
||||
"Invalid value for --launch: '{}'. Expected 'true' or 'false'",
|
||||
value
|
||||
),
|
||||
}
|
||||
args
|
||||
}
|
||||
|
||||
pub(crate) fn show_error(mut content: String) {
|
||||
@@ -135,44 +123,28 @@ mod windows_impl {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::windows_impl::{Args, parse_launch_arg, parse_single_arg};
|
||||
use crate::windows_impl::parse_args;
|
||||
|
||||
#[test]
|
||||
fn test_parse_launch_arg() {
|
||||
let mut arg = None;
|
||||
parse_launch_arg("true", &mut arg);
|
||||
assert_eq!(arg, Some(true));
|
||||
fn test_parse_args() {
|
||||
// launch can be specified via two separate arguments
|
||||
assert_eq!(parse_args(["--launch".into(), "true".into()]).launch, true);
|
||||
assert_eq!(
|
||||
parse_args(["--launch".into(), "false".into()]).launch,
|
||||
false
|
||||
);
|
||||
|
||||
let mut arg = None;
|
||||
parse_launch_arg("false", &mut arg);
|
||||
assert_eq!(arg, Some(false));
|
||||
// launch can be specified via one single argument
|
||||
assert_eq!(parse_args(["--launch=true".into()]).launch, true);
|
||||
assert_eq!(parse_args(["--launch=false".into()]).launch, false);
|
||||
|
||||
let mut arg = None;
|
||||
parse_launch_arg("invalid", &mut arg);
|
||||
assert_eq!(arg, None);
|
||||
}
|
||||
// launch defaults to true on no arguments
|
||||
assert_eq!(parse_args([]).launch, true);
|
||||
|
||||
#[test]
|
||||
fn test_parse_single_arg() {
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=true", &mut args);
|
||||
assert_eq!(args.launch, Some(true));
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=false", &mut args);
|
||||
assert_eq!(args.launch, Some(false));
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=invalid", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--unknown", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
// launch defaults to true on invalid arguments
|
||||
assert_eq!(parse_args(["--launch".into()]).launch, true);
|
||||
assert_eq!(parse_args(["--launch=".into()]).launch, true);
|
||||
assert_eq!(parse_args(["--launch=invalid".into()]).launch, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ use client::{
|
||||
};
|
||||
use collections::{BTreeMap, HashMap, HashSet};
|
||||
use fs::Fs;
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use futures::StreamExt;
|
||||
use gpui::{
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, ScreenCaptureSource,
|
||||
ScreenCaptureStream, Task, WeakEntity,
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FutureExt as _,
|
||||
ScreenCaptureSource, ScreenCaptureStream, Task, Timeout, WeakEntity,
|
||||
};
|
||||
use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
@@ -370,57 +370,53 @@ impl Room {
|
||||
})?;
|
||||
|
||||
// Wait for client to re-establish a connection to the server.
|
||||
{
|
||||
let mut reconnection_timeout =
|
||||
cx.background_executor().timer(RECONNECT_TIMEOUT).fuse();
|
||||
let client_reconnection = async {
|
||||
let mut remaining_attempts = 3;
|
||||
while remaining_attempts > 0 {
|
||||
if client_status.borrow().is_connected() {
|
||||
log::info!("client reconnected, attempting to rejoin room");
|
||||
let executor = cx.background_executor().clone();
|
||||
let client_reconnection = async {
|
||||
let mut remaining_attempts = 3;
|
||||
while remaining_attempts > 0 {
|
||||
if client_status.borrow().is_connected() {
|
||||
log::info!("client reconnected, attempting to rejoin room");
|
||||
|
||||
let Some(this) = this.upgrade() else { break };
|
||||
match this.update(cx, |this, cx| this.rejoin(cx)) {
|
||||
Ok(task) => {
|
||||
if task.await.log_err().is_some() {
|
||||
return true;
|
||||
} else {
|
||||
remaining_attempts -= 1;
|
||||
}
|
||||
let Some(this) = this.upgrade() else { break };
|
||||
match this.update(cx, |this, cx| this.rejoin(cx)) {
|
||||
Ok(task) => {
|
||||
if task.await.log_err().is_some() {
|
||||
return true;
|
||||
} else {
|
||||
remaining_attempts -= 1;
|
||||
}
|
||||
Err(_app_dropped) => return false,
|
||||
}
|
||||
} else if client_status.borrow().is_signed_out() {
|
||||
return false;
|
||||
Err(_app_dropped) => return false,
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"waiting for client status change, remaining attempts {}",
|
||||
remaining_attempts
|
||||
);
|
||||
client_status.next().await;
|
||||
} else if client_status.borrow().is_signed_out() {
|
||||
return false;
|
||||
}
|
||||
false
|
||||
|
||||
log::info!(
|
||||
"waiting for client status change, remaining attempts {}",
|
||||
remaining_attempts
|
||||
);
|
||||
client_status.next().await;
|
||||
}
|
||||
.fuse();
|
||||
futures::pin_mut!(client_reconnection);
|
||||
false
|
||||
};
|
||||
|
||||
futures::select_biased! {
|
||||
reconnected = client_reconnection => {
|
||||
if reconnected {
|
||||
log::info!("successfully reconnected to room");
|
||||
// If we successfully joined the room, go back around the loop
|
||||
// waiting for future connection status changes.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
_ = reconnection_timeout => {
|
||||
log::info!("room reconnection timeout expired");
|
||||
}
|
||||
match client_reconnection
|
||||
.with_timeout(RECONNECT_TIMEOUT, &executor)
|
||||
.await
|
||||
{
|
||||
Ok(true) => {
|
||||
log::info!("successfully reconnected to room");
|
||||
// If we successfully joined the room, go back around the loop
|
||||
// waiting for future connection status changes.
|
||||
continue;
|
||||
}
|
||||
Ok(false) => break,
|
||||
Err(Timeout) => {
|
||||
log::info!("room reconnection timeout expired");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::Duration;
|
||||
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
|
||||
use futures::{StreamExt, stream::BoxStream};
|
||||
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
|
||||
use http_client::{AsyncBody, Method, Request, http};
|
||||
use parking_lot::Mutex;
|
||||
use rpc::{
|
||||
ConnectionId, Peer, Receipt, TypedEnvelope,
|
||||
proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
|
||||
};
|
||||
use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct FakeServer {
|
||||
@@ -187,50 +183,27 @@ impl FakeServer {
|
||||
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
|
||||
self.executor.start_waiting();
|
||||
|
||||
loop {
|
||||
let message = self
|
||||
.state
|
||||
.lock()
|
||||
.incoming
|
||||
.as_mut()
|
||||
.expect("not connected")
|
||||
.next()
|
||||
.await
|
||||
.context("other half hung up")?;
|
||||
self.executor.finish_waiting();
|
||||
let type_name = message.payload_type_name();
|
||||
let message = message.into_any();
|
||||
let message = self
|
||||
.state
|
||||
.lock()
|
||||
.incoming
|
||||
.as_mut()
|
||||
.expect("not connected")
|
||||
.next()
|
||||
.await
|
||||
.context("other half hung up")?;
|
||||
self.executor.finish_waiting();
|
||||
let type_name = message.payload_type_name();
|
||||
let message = message.into_any();
|
||||
|
||||
if message.is::<TypedEnvelope<M>>() {
|
||||
return Ok(*message.downcast().unwrap());
|
||||
}
|
||||
|
||||
let accepted_tos_at = chrono::Utc::now()
|
||||
.checked_sub_signed(Duration::hours(5))
|
||||
.expect("failed to build accepted_tos_at")
|
||||
.timestamp() as u64;
|
||||
|
||||
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
|
||||
self.respond(
|
||||
message
|
||||
.downcast::<TypedEnvelope<GetPrivateUserInfo>>()
|
||||
.unwrap()
|
||||
.receipt(),
|
||||
GetPrivateUserInfoResponse {
|
||||
metrics_id: "the-metrics-id".into(),
|
||||
staff: false,
|
||||
flags: Default::default(),
|
||||
accepted_tos_at: Some(accepted_tos_at),
|
||||
},
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
panic!(
|
||||
"fake server received unexpected message type: {:?}",
|
||||
type_name
|
||||
);
|
||||
if message.is::<TypedEnvelope<M>>() {
|
||||
return Ok(*message.downcast().unwrap());
|
||||
}
|
||||
|
||||
panic!(
|
||||
"fake server received unexpected message type: {:?}",
|
||||
type_name
|
||||
);
|
||||
}
|
||||
|
||||
pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
|
||||
|
||||
@@ -177,7 +177,6 @@ impl UserStore {
|
||||
let (mut current_user_tx, current_user_rx) = watch::channel();
|
||||
let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded();
|
||||
let rpc_subscriptions = vec![
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_update_plan),
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_update_contacts),
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_update_invite_info),
|
||||
client.add_message_handler(cx.weak_entity(), Self::handle_show_contacts),
|
||||
@@ -343,26 +342,6 @@ impl UserStore {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_update_plan(
|
||||
this: Entity<Self>,
|
||||
_message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
let client = this
|
||||
.read_with(&cx, |this, _| this.client.upgrade())?
|
||||
.context("client was dropped")?;
|
||||
|
||||
let response = client
|
||||
.cloud_client()
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.context("failed to fetch authenticated user")?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.update_authenticated_user(response, cx);
|
||||
})
|
||||
}
|
||||
|
||||
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
match message {
|
||||
UpdateContacts::Wait(barrier) => {
|
||||
@@ -1019,19 +998,6 @@ impl RequestUsage {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_proto(amount: u32, limit: proto::UsageLimit) -> Option<Self> {
|
||||
let limit = match limit.variant? {
|
||||
proto::usage_limit::Variant::Limited(limited) => {
|
||||
UsageLimit::Limited(limited.limit as i32)
|
||||
}
|
||||
proto::usage_limit::Variant::Unlimited(_) => UsageLimit::Unlimited,
|
||||
};
|
||||
Some(RequestUsage {
|
||||
limit,
|
||||
amount: amount as i32,
|
||||
})
|
||||
}
|
||||
|
||||
fn from_headers(
|
||||
limit_name: &str,
|
||||
amount_name: &str,
|
||||
|
||||
@@ -19,7 +19,6 @@ test-support = ["sqlite"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-stripe.workspace = true
|
||||
async-trait.workspace = true
|
||||
async-tungstenite.workspace = true
|
||||
aws-config = { version = "1.1.5" }
|
||||
@@ -30,16 +29,13 @@ axum-extra = { version = "0.4", features = ["erased-json"] }
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
livekit_api.workspace = true
|
||||
log.workspace = true
|
||||
nanoid.workspace = true
|
||||
@@ -65,7 +61,6 @@ subtle.workspace = true
|
||||
supermaven_api.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
toml.workspace = true
|
||||
@@ -136,6 +131,3 @@ util.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
worktree = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
|
||||
[package.metadata.cargo-machete]
|
||||
ignored = ["async-stripe"]
|
||||
|
||||
@@ -219,12 +219,6 @@ spec:
|
||||
secretKeyRef:
|
||||
name: slack
|
||||
key: panics_webhook
|
||||
- name: STRIPE_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: stripe
|
||||
key: api_key
|
||||
optional: true
|
||||
- name: COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR
|
||||
value: "1000"
|
||||
- name: SUPERMAVEN_ADMIN_API_KEY
|
||||
|
||||
@@ -474,67 +474,6 @@ CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id
|
||||
|
||||
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
|
||||
|
||||
CREATE TABLE rate_buckets (
|
||||
user_id INT NOT NULL,
|
||||
rate_limit_name VARCHAR(255) NOT NULL,
|
||||
token_count INT NOT NULL,
|
||||
last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
|
||||
PRIMARY KEY (user_id, rate_limit_name),
|
||||
FOREIGN KEY (user_id) REFERENCES users (id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_preferences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
user_id INTEGER NOT NULL REFERENCES users (id),
|
||||
max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL,
|
||||
model_request_overages_enabled bool NOT NULL DEFAULT FALSE,
|
||||
model_request_overages_spend_limit_in_cents integer NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_customers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
user_id INTEGER NOT NULL REFERENCES users (id),
|
||||
has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
stripe_customer_id TEXT NOT NULL,
|
||||
trial_started_at TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_customers_on_stripe_customer_id" ON billing_customers (stripe_customer_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS billing_subscriptions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
billing_customer_id INTEGER NOT NULL REFERENCES billing_customers (id),
|
||||
stripe_subscription_id TEXT NOT NULL,
|
||||
stripe_subscription_status TEXT NOT NULL,
|
||||
stripe_cancel_at TIMESTAMP,
|
||||
stripe_cancellation_reason TEXT,
|
||||
kind TEXT,
|
||||
stripe_current_period_start BIGINT,
|
||||
stripe_current_period_end BIGINT
|
||||
);
|
||||
|
||||
CREATE INDEX "ix_billing_subscriptions_on_billing_customer_id" ON billing_subscriptions (billing_customer_id);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_subscriptions_on_stripe_subscription_id" ON billing_subscriptions (stripe_subscription_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS processed_stripe_events (
|
||||
stripe_event_id TEXT PRIMARY KEY,
|
||||
stripe_event_type TEXT NOT NULL,
|
||||
stripe_event_created_timestamp INTEGER NOT NULL,
|
||||
processed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX "ix_processed_stripe_events_on_stripe_event_created_timestamp" ON processed_stripe_events (stripe_event_created_timestamp);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS "breakpoints" (
|
||||
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
alter table users
|
||||
alter column admin set not null;
|
||||
@@ -0,0 +1,2 @@
|
||||
alter table billing_customers
|
||||
add column orb_customer_id text;
|
||||
@@ -0,0 +1 @@
|
||||
drop table rate_buckets;
|
||||
@@ -1,19 +1,11 @@
|
||||
pub mod billing;
|
||||
pub mod contributors;
|
||||
pub mod events;
|
||||
pub mod extensions;
|
||||
pub mod ips_file;
|
||||
pub mod slack;
|
||||
|
||||
use crate::db::Database;
|
||||
use crate::{
|
||||
AppState, Error, Result, auth,
|
||||
db::{User, UserId},
|
||||
rpc,
|
||||
};
|
||||
use ::rpc::proto;
|
||||
use crate::{AppState, Error, Result, auth, db::UserId, rpc};
|
||||
use anyhow::Context as _;
|
||||
use axum::extract;
|
||||
use axum::{
|
||||
Extension, Json, Router,
|
||||
body::Body,
|
||||
@@ -25,7 +17,6 @@ use axum::{
|
||||
routing::{get, post},
|
||||
};
|
||||
use axum_extra::response::ErasedJson;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use tower::ServiceBuilder;
|
||||
@@ -100,10 +91,7 @@ impl std::fmt::Display for SystemIdHeader {
|
||||
|
||||
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/users/look_up", get(look_up_user))
|
||||
.route("/users/:id/access_tokens", post(create_access_token))
|
||||
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
|
||||
.route("/users/:id/update_plan", post(update_plan))
|
||||
.route("/rpc_server_snapshot", get(get_rpc_server_snapshot))
|
||||
.merge(contributors::router())
|
||||
.layer(
|
||||
@@ -144,99 +132,6 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
||||
Ok::<_, Error>(next.run(req).await)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LookUpUserParams {
|
||||
identifier: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct LookUpUserResponse {
|
||||
user: Option<User>,
|
||||
}
|
||||
|
||||
async fn look_up_user(
|
||||
Query(params): Query<LookUpUserParams>,
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
) -> Result<Json<LookUpUserResponse>> {
|
||||
let user = resolve_identifier_to_user(&app.db, ¶ms.identifier).await?;
|
||||
let user = if let Some(user) = user {
|
||||
match user {
|
||||
UserOrId::User(user) => Some(user),
|
||||
UserOrId::Id(id) => app.db.get_user_by_id(id).await?,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Json(LookUpUserResponse { user }))
|
||||
}
|
||||
|
||||
enum UserOrId {
|
||||
User(User),
|
||||
Id(UserId),
|
||||
}
|
||||
|
||||
async fn resolve_identifier_to_user(
|
||||
db: &Arc<Database>,
|
||||
identifier: &str,
|
||||
) -> Result<Option<UserOrId>> {
|
||||
if let Some(identifier) = identifier.parse::<i32>().ok() {
|
||||
let user = db.get_user_by_id(UserId(identifier)).await?;
|
||||
|
||||
return Ok(user.map(UserOrId::User));
|
||||
}
|
||||
|
||||
if identifier.starts_with("cus_") {
|
||||
let billing_customer = db
|
||||
.get_billing_customer_by_stripe_customer_id(&identifier)
|
||||
.await?;
|
||||
|
||||
return Ok(billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id)));
|
||||
}
|
||||
|
||||
if identifier.starts_with("sub_") {
|
||||
let billing_subscription = db
|
||||
.get_billing_subscription_by_stripe_subscription_id(&identifier)
|
||||
.await?;
|
||||
|
||||
if let Some(billing_subscription) = billing_subscription {
|
||||
let billing_customer = db
|
||||
.get_billing_customer_by_id(billing_subscription.billing_customer_id)
|
||||
.await?;
|
||||
|
||||
return Ok(
|
||||
billing_customer.map(|billing_customer| UserOrId::Id(billing_customer.user_id))
|
||||
);
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
if identifier.contains('@') {
|
||||
let user = db.get_user_by_email(identifier).await?;
|
||||
|
||||
return Ok(user.map(UserOrId::User));
|
||||
}
|
||||
|
||||
if let Some(user) = db.get_user_by_github_login(identifier).await? {
|
||||
return Ok(Some(UserOrId::User(user)));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct CreateUserParams {
|
||||
github_user_id: i32,
|
||||
github_login: String,
|
||||
email_address: String,
|
||||
email_confirmation_code: Option<String>,
|
||||
#[serde(default)]
|
||||
admin: bool,
|
||||
#[serde(default)]
|
||||
invite_count: i32,
|
||||
}
|
||||
|
||||
async fn get_rpc_server_snapshot(
|
||||
Extension(rpc_server): Extension<Arc<rpc::Server>>,
|
||||
) -> Result<ErasedJson> {
|
||||
@@ -295,90 +190,3 @@ async fn create_access_token(
|
||||
encrypted_access_token,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshLlmTokensResponse {}
|
||||
|
||||
async fn refresh_llm_tokens(
|
||||
Path(user_id): Path<UserId>,
|
||||
Extension(rpc_server): Extension<Arc<rpc::Server>>,
|
||||
) -> Result<Json<RefreshLlmTokensResponse>> {
|
||||
rpc_server.refresh_llm_tokens_for_user(user_id).await;
|
||||
|
||||
Ok(Json(RefreshLlmTokensResponse {}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UpdatePlanBody {
|
||||
pub plan: cloud_llm_client::Plan,
|
||||
pub subscription_period: SubscriptionPeriod,
|
||||
pub usage: cloud_llm_client::CurrentUsage,
|
||||
pub trial_started_at: Option<DateTime<Utc>>,
|
||||
pub is_usage_based_billing_enabled: bool,
|
||||
pub is_account_too_young: bool,
|
||||
pub has_overdue_invoices: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
struct SubscriptionPeriod {
|
||||
pub started_at: DateTime<Utc>,
|
||||
pub ended_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UpdatePlanResponse {}
|
||||
|
||||
async fn update_plan(
|
||||
Path(user_id): Path<UserId>,
|
||||
Extension(rpc_server): Extension<Arc<rpc::Server>>,
|
||||
extract::Json(body): extract::Json<UpdatePlanBody>,
|
||||
) -> Result<Json<UpdatePlanResponse>> {
|
||||
let plan = match body.plan {
|
||||
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
|
||||
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
|
||||
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
let update_user_plan = proto::UpdateUserPlan {
|
||||
plan: plan.into(),
|
||||
trial_started_at: body
|
||||
.trial_started_at
|
||||
.map(|trial_started_at| trial_started_at.timestamp() as u64),
|
||||
is_usage_based_billing_enabled: Some(body.is_usage_based_billing_enabled),
|
||||
usage: Some(proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: body.usage.model_requests.used,
|
||||
model_requests_usage_limit: Some(usage_limit_to_proto(body.usage.model_requests.limit)),
|
||||
edit_predictions_usage_amount: body.usage.edit_predictions.used,
|
||||
edit_predictions_usage_limit: Some(usage_limit_to_proto(
|
||||
body.usage.edit_predictions.limit,
|
||||
)),
|
||||
}),
|
||||
subscription_period: Some(proto::SubscriptionPeriod {
|
||||
started_at: body.subscription_period.started_at.timestamp() as u64,
|
||||
ended_at: body.subscription_period.ended_at.timestamp() as u64,
|
||||
}),
|
||||
account_too_young: Some(body.is_account_too_young),
|
||||
has_overdue_invoices: Some(body.has_overdue_invoices),
|
||||
};
|
||||
|
||||
rpc_server
|
||||
.update_plan_for_user(user_id, update_user_plan)
|
||||
.await?;
|
||||
|
||||
Ok(Json(UpdatePlanResponse {}))
|
||||
}
|
||||
|
||||
fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
|
||||
proto::UsageLimit {
|
||||
variant: Some(match limit {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
use stripe::SubscriptionStatus;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::db::billing_subscription::StripeSubscriptionStatus;
|
||||
use crate::db::{CreateBillingCustomerParams, billing_customer};
|
||||
use crate::stripe_client::{StripeClient, StripeCustomerId};
|
||||
|
||||
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
|
||||
fn from(value: SubscriptionStatus) -> Self {
|
||||
match value {
|
||||
SubscriptionStatus::Incomplete => Self::Incomplete,
|
||||
SubscriptionStatus::IncompleteExpired => Self::IncompleteExpired,
|
||||
SubscriptionStatus::Trialing => Self::Trialing,
|
||||
SubscriptionStatus::Active => Self::Active,
|
||||
SubscriptionStatus::PastDue => Self::PastDue,
|
||||
SubscriptionStatus::Canceled => Self::Canceled,
|
||||
SubscriptionStatus::Unpaid => Self::Unpaid,
|
||||
SubscriptionStatus::Paused => Self::Paused,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finds or creates a billing customer using the provided customer.
|
||||
pub async fn find_or_create_billing_customer(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &dyn StripeClient,
|
||||
customer_id: &StripeCustomerId,
|
||||
) -> anyhow::Result<Option<billing_customer::Model>> {
|
||||
// If we already have a billing customer record associated with the Stripe customer,
|
||||
// there's nothing more we need to do.
|
||||
if let Some(billing_customer) = app
|
||||
.db
|
||||
.get_billing_customer_by_stripe_customer_id(customer_id.0.as_ref())
|
||||
.await?
|
||||
{
|
||||
return Ok(Some(billing_customer));
|
||||
}
|
||||
|
||||
let customer = stripe_client.get_customer(customer_id).await?;
|
||||
|
||||
let Some(email) = customer.email else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some(user) = app.db.get_user_by_email(&email).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let billing_customer = app
|
||||
.db
|
||||
.create_billing_customer(&CreateBillingCustomerParams {
|
||||
user_id: user.id,
|
||||
stripe_customer_id: customer.id.to_string(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
@@ -564,170 +564,10 @@ fn for_snowflake(
|
||||
country_code: Option<String>,
|
||||
checksum_matched: bool,
|
||||
) -> impl Iterator<Item = SnowflakeRow> {
|
||||
body.events.into_iter().filter_map(move |event| {
|
||||
body.events.into_iter().map(move |event| {
|
||||
let timestamp =
|
||||
first_event_at + Duration::milliseconds(event.milliseconds_since_first_event);
|
||||
// We will need to double check, but I believe all of the events that
|
||||
// are being transformed here are now migrated over to use the
|
||||
// telemetry::event! macro, as of this commit so this code can go away
|
||||
// when we feel enough users have upgraded past this point.
|
||||
let (event_type, mut event_properties) = match &event.event {
|
||||
Event::Editor(e) => (
|
||||
match e.operation.as_str() {
|
||||
"open" => "Editor Opened".to_string(),
|
||||
"save" => "Editor Saved".to_string(),
|
||||
_ => format!("Unknown Editor Event: {}", e.operation),
|
||||
},
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::EditPrediction(e) => (
|
||||
format!(
|
||||
"Edit Prediction {}",
|
||||
if e.suggestion_accepted {
|
||||
"Accepted"
|
||||
} else {
|
||||
"Discarded"
|
||||
}
|
||||
),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::EditPredictionRating(e) => (
|
||||
"Edit Prediction Rated".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Call(e) => {
|
||||
let event_type = match e.operation.trim() {
|
||||
"unshare project" => "Project Unshared".to_string(),
|
||||
"open channel notes" => "Channel Notes Opened".to_string(),
|
||||
"share project" => "Project Shared".to_string(),
|
||||
"join channel" => "Channel Joined".to_string(),
|
||||
"hang up" => "Call Ended".to_string(),
|
||||
"accept incoming" => "Incoming Call Accepted".to_string(),
|
||||
"invite" => "Participant Invited".to_string(),
|
||||
"disable microphone" => "Microphone Disabled".to_string(),
|
||||
"enable microphone" => "Microphone Enabled".to_string(),
|
||||
"enable screen share" => "Screen Share Enabled".to_string(),
|
||||
"disable screen share" => "Screen Share Disabled".to_string(),
|
||||
"decline incoming" => "Incoming Call Declined".to_string(),
|
||||
_ => format!("Unknown Call Event: {}", e.operation),
|
||||
};
|
||||
|
||||
(event_type, serde_json::to_value(e).unwrap())
|
||||
}
|
||||
Event::Assistant(e) => (
|
||||
match e.phase {
|
||||
telemetry_events::AssistantPhase::Response => "Assistant Responded".to_string(),
|
||||
telemetry_events::AssistantPhase::Invoked => "Assistant Invoked".to_string(),
|
||||
telemetry_events::AssistantPhase::Accepted => {
|
||||
"Assistant Response Accepted".to_string()
|
||||
}
|
||||
telemetry_events::AssistantPhase::Rejected => {
|
||||
"Assistant Response Rejected".to_string()
|
||||
}
|
||||
},
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Cpu(_) | Event::Memory(_) => return None,
|
||||
Event::App(e) => {
|
||||
let mut properties = json!({});
|
||||
let event_type = match e.operation.trim() {
|
||||
// App
|
||||
"open" => "App Opened".to_string(),
|
||||
"first open" => "App First Opened".to_string(),
|
||||
"first open for release channel" => {
|
||||
"App First Opened For Release Channel".to_string()
|
||||
}
|
||||
"close" => "App Closed".to_string(),
|
||||
|
||||
// Project
|
||||
"open project" => "Project Opened".to_string(),
|
||||
"open node project" => {
|
||||
properties["project_type"] = json!("node");
|
||||
"Project Opened".to_string()
|
||||
}
|
||||
"open pnpm project" => {
|
||||
properties["project_type"] = json!("pnpm");
|
||||
"Project Opened".to_string()
|
||||
}
|
||||
"open yarn project" => {
|
||||
properties["project_type"] = json!("yarn");
|
||||
"Project Opened".to_string()
|
||||
}
|
||||
|
||||
// SSH
|
||||
"create ssh server" => "SSH Server Created".to_string(),
|
||||
"create ssh project" => "SSH Project Created".to_string(),
|
||||
"open ssh project" => "SSH Project Opened".to_string(),
|
||||
|
||||
// Welcome Page
|
||||
"welcome page: change keymap" => "Welcome Keymap Changed".to_string(),
|
||||
"welcome page: change theme" => "Welcome Theme Changed".to_string(),
|
||||
"welcome page: close" => "Welcome Page Closed".to_string(),
|
||||
"welcome page: edit settings" => "Welcome Settings Edited".to_string(),
|
||||
"welcome page: install cli" => "Welcome CLI Installed".to_string(),
|
||||
"welcome page: open" => "Welcome Page Opened".to_string(),
|
||||
"welcome page: open extensions" => "Welcome Extensions Page Opened".to_string(),
|
||||
"welcome page: sign in to copilot" => "Welcome Copilot Signed In".to_string(),
|
||||
"welcome page: toggle diagnostic telemetry" => {
|
||||
"Welcome Diagnostic Telemetry Toggled".to_string()
|
||||
}
|
||||
"welcome page: toggle metric telemetry" => {
|
||||
"Welcome Metric Telemetry Toggled".to_string()
|
||||
}
|
||||
"welcome page: toggle vim" => "Welcome Vim Mode Toggled".to_string(),
|
||||
"welcome page: view docs" => "Welcome Documentation Viewed".to_string(),
|
||||
|
||||
// Extensions
|
||||
"extensions page: open" => "Extensions Page Opened".to_string(),
|
||||
"extensions: install extension" => "Extension Installed".to_string(),
|
||||
"extensions: uninstall extension" => "Extension Uninstalled".to_string(),
|
||||
|
||||
// Misc
|
||||
"markdown preview: open" => "Markdown Preview Opened".to_string(),
|
||||
"project diagnostics: open" => "Project Diagnostics Opened".to_string(),
|
||||
"project search: open" => "Project Search Opened".to_string(),
|
||||
"repl sessions: open" => "REPL Session Started".to_string(),
|
||||
|
||||
// Feature Upsell
|
||||
"feature upsell: toggle vim" => {
|
||||
properties["source"] = json!("Feature Upsell");
|
||||
"Vim Mode Toggled".to_string()
|
||||
}
|
||||
_ => e
|
||||
.operation
|
||||
.strip_prefix("feature upsell: viewed docs (")
|
||||
.and_then(|s| s.strip_suffix(')'))
|
||||
.map_or_else(
|
||||
|| format!("Unknown App Event: {}", e.operation),
|
||||
|docs_url| {
|
||||
properties["url"] = json!(docs_url);
|
||||
properties["source"] = json!("Feature Upsell");
|
||||
"Documentation Viewed".to_string()
|
||||
},
|
||||
),
|
||||
};
|
||||
(event_type, properties)
|
||||
}
|
||||
Event::Setting(e) => (
|
||||
"Settings Changed".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Extension(e) => (
|
||||
"Extension Loaded".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Edit(e) => (
|
||||
"Editor Edited".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Action(e) => (
|
||||
"Action Invoked".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Repl(e) => (
|
||||
"Kernel Status Changed".to_string(),
|
||||
serde_json::to_value(e).unwrap(),
|
||||
),
|
||||
Event::Flexible(e) => (
|
||||
e.event_type.clone(),
|
||||
serde_json::to_value(&e.event_properties).unwrap(),
|
||||
@@ -759,7 +599,7 @@ fn for_snowflake(
|
||||
})
|
||||
});
|
||||
|
||||
Some(SnowflakeRow {
|
||||
SnowflakeRow {
|
||||
time: timestamp,
|
||||
user_id: body.metrics_id.clone(),
|
||||
device_id: body.system_id.clone(),
|
||||
@@ -767,7 +607,7 @@ fn for_snowflake(
|
||||
event_properties,
|
||||
user_properties,
|
||||
insert_id: Some(Uuid::new_v4().to_string()),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -41,12 +41,7 @@ use worktree_settings_file::LocalSettingsKind;
|
||||
pub use tests::TestDb;
|
||||
|
||||
pub use ids::*;
|
||||
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
|
||||
pub use queries::billing_subscriptions::{
|
||||
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
|
||||
};
|
||||
pub use queries::contributors::ContributorSelector;
|
||||
pub use queries::processed_stripe_events::CreateProcessedStripeEventParams;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
pub use tables::user::Model as User;
|
||||
pub use tables::*;
|
||||
|
||||
@@ -70,9 +70,6 @@ macro_rules! id_type {
|
||||
}
|
||||
|
||||
id_type!(AccessTokenId);
|
||||
id_type!(BillingCustomerId);
|
||||
id_type!(BillingSubscriptionId);
|
||||
id_type!(BillingPreferencesId);
|
||||
id_type!(BufferId);
|
||||
id_type!(ChannelBufferCollaboratorId);
|
||||
id_type!(ChannelChatParticipantId);
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
use super::*;
|
||||
|
||||
pub mod access_tokens;
|
||||
pub mod billing_customers;
|
||||
pub mod billing_preferences;
|
||||
pub mod billing_subscriptions;
|
||||
pub mod buffers;
|
||||
pub mod channels;
|
||||
pub mod contacts;
|
||||
@@ -12,7 +9,6 @@ pub mod embeddings;
|
||||
pub mod extensions;
|
||||
pub mod messages;
|
||||
pub mod notifications;
|
||||
pub mod processed_stripe_events;
|
||||
pub mod projects;
|
||||
pub mod rooms;
|
||||
pub mod servers;
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingCustomerParams {
|
||||
pub user_id: UserId,
|
||||
pub stripe_customer_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingCustomerParams {
|
||||
pub user_id: ActiveValue<UserId>,
|
||||
pub stripe_customer_id: ActiveValue<String>,
|
||||
pub has_overdue_invoices: ActiveValue<bool>,
|
||||
pub trial_started_at: ActiveValue<Option<DateTime>>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Creates a new billing customer.
|
||||
pub async fn create_billing_customer(
|
||||
&self,
|
||||
params: &CreateBillingCustomerParams,
|
||||
) -> Result<billing_customer::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
|
||||
user_id: ActiveValue::set(params.user_id),
|
||||
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_with_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(customer)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Updates the specified billing customer.
|
||||
pub async fn update_billing_customer(
|
||||
&self,
|
||||
id: BillingCustomerId,
|
||||
params: &UpdateBillingCustomerParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_customer::Entity::update(billing_customer::ActiveModel {
|
||||
id: ActiveValue::set(id),
|
||||
user_id: params.user_id.clone(),
|
||||
stripe_customer_id: params.stripe_customer_id.clone(),
|
||||
has_overdue_invoices: params.has_overdue_invoices.clone(),
|
||||
trial_started_at: params.trial_started_at.clone(),
|
||||
created_at: ActiveValue::not_set(),
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_billing_customer_by_id(
|
||||
&self,
|
||||
id: BillingCustomerId,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::Id.eq(id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing customer for the user with the specified ID.
|
||||
pub async fn get_billing_customer_by_user_id(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing customer for the user with the specified Stripe customer ID.
|
||||
pub async fn get_billing_customer_by_stripe_customer_id(
|
||||
&self,
|
||||
stripe_customer_id: &str,
|
||||
) -> Result<Option<billing_customer::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_customer::Entity::find()
|
||||
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
impl Database {
|
||||
/// Returns the billing preferences for the given user, if they exist.
|
||||
pub async fn get_billing_preferences(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_preference::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_preference::Entity::find()
|
||||
.filter(billing_preference::Column::UserId.eq(user_id))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
use anyhow::Context as _;
|
||||
|
||||
use crate::db::billing_subscription::{
|
||||
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateBillingSubscriptionParams {
|
||||
pub billing_customer_id: BillingCustomerId,
|
||||
pub kind: Option<SubscriptionKind>,
|
||||
pub stripe_subscription_id: String,
|
||||
pub stripe_subscription_status: StripeSubscriptionStatus,
|
||||
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
|
||||
pub stripe_current_period_start: Option<i64>,
|
||||
pub stripe_current_period_end: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct UpdateBillingSubscriptionParams {
|
||||
pub billing_customer_id: ActiveValue<BillingCustomerId>,
|
||||
pub kind: ActiveValue<Option<SubscriptionKind>>,
|
||||
pub stripe_subscription_id: ActiveValue<String>,
|
||||
pub stripe_subscription_status: ActiveValue<StripeSubscriptionStatus>,
|
||||
pub stripe_cancel_at: ActiveValue<Option<DateTime>>,
|
||||
pub stripe_cancellation_reason: ActiveValue<Option<StripeCancellationReason>>,
|
||||
pub stripe_current_period_start: ActiveValue<Option<i64>>,
|
||||
pub stripe_current_period_end: ActiveValue<Option<i64>>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Creates a new billing subscription.
|
||||
pub async fn create_billing_subscription(
|
||||
&self,
|
||||
params: &CreateBillingSubscriptionParams,
|
||||
) -> Result<billing_subscription::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
||||
kind: ActiveValue::set(params.kind),
|
||||
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
|
||||
stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status),
|
||||
stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason),
|
||||
stripe_current_period_start: ActiveValue::set(params.stripe_current_period_start),
|
||||
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
.last_insert_id;
|
||||
|
||||
Ok(billing_subscription::Entity::find_by_id(id)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.context("failed to retrieve inserted billing subscription")?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Updates the specified billing subscription.
|
||||
pub async fn update_billing_subscription(
|
||||
&self,
|
||||
id: BillingSubscriptionId,
|
||||
params: &UpdateBillingSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_subscription::Entity::update(billing_subscription::ActiveModel {
|
||||
id: ActiveValue::set(id),
|
||||
billing_customer_id: params.billing_customer_id.clone(),
|
||||
kind: params.kind.clone(),
|
||||
stripe_subscription_id: params.stripe_subscription_id.clone(),
|
||||
stripe_subscription_status: params.stripe_subscription_status.clone(),
|
||||
stripe_cancel_at: params.stripe_cancel_at.clone(),
|
||||
stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
|
||||
stripe_current_period_start: params.stripe_current_period_start.clone(),
|
||||
stripe_current_period_end: params.stripe_current_period_end.clone(),
|
||||
created_at: ActiveValue::not_set(),
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the billing subscription with the specified Stripe subscription ID.
|
||||
pub async fn get_billing_subscription_by_stripe_subscription_id(
|
||||
&self,
|
||||
stripe_subscription_id: &str,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find()
|
||||
.filter(
|
||||
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_billing_subscription(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
.filter(
|
||||
Condition::all()
|
||||
.add(
|
||||
Condition::any()
|
||||
.add(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.add(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Trialing),
|
||||
),
|
||||
)
|
||||
.add(billing_subscription::Column::Kind.is_not_null()),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns whether the user has an active billing subscription.
|
||||
pub async fn has_active_billing_subscription(&self, user_id: UserId) -> Result<bool> {
|
||||
Ok(self.count_active_billing_subscriptions(user_id).await? > 0)
|
||||
}
|
||||
|
||||
/// Returns the count of the active billing subscriptions for the user with the specified ID.
|
||||
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
|
||||
self.transaction(|tx| async move {
|
||||
let count = billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(
|
||||
billing_customer::Column::UserId.eq(user_id).and(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active)
|
||||
.or(billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Trialing)),
|
||||
),
|
||||
)
|
||||
.count(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(count as usize)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CreateProcessedStripeEventParams {
|
||||
pub stripe_event_id: String,
|
||||
pub stripe_event_type: String,
|
||||
pub stripe_event_created_timestamp: i64,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Creates a new processed Stripe event.
|
||||
pub async fn create_processed_stripe_event(
|
||||
&self,
|
||||
params: &CreateProcessedStripeEventParams,
|
||||
) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
|
||||
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
|
||||
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
|
||||
stripe_event_created_timestamp: ActiveValue::set(
|
||||
params.stripe_event_created_timestamp,
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the processed Stripe event with the specified event ID.
|
||||
pub async fn get_processed_stripe_event_by_event_id(
|
||||
&self,
|
||||
event_id: &str,
|
||||
) -> Result<Option<processed_stripe_event::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(processed_stripe_event::Entity::find_by_id(event_id)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns the processed Stripe events with the specified event IDs.
|
||||
pub async fn get_processed_stripe_events_by_event_ids(
|
||||
&self,
|
||||
event_ids: &[&str],
|
||||
) -> Result<Vec<processed_stripe_event::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(processed_stripe_event::Entity::find()
|
||||
.filter(
|
||||
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),
|
||||
)
|
||||
.all(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns whether the Stripe event with the specified ID has already been processed.
|
||||
pub async fn already_processed_stripe_event(&self, event_id: &str) -> Result<bool> {
|
||||
Ok(self
|
||||
.get_processed_stripe_event_by_event_id(event_id)
|
||||
.await?
|
||||
.is_some())
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,4 @@
|
||||
pub mod access_token;
|
||||
pub mod billing_customer;
|
||||
pub mod billing_preference;
|
||||
pub mod billing_subscription;
|
||||
pub mod buffer;
|
||||
pub mod buffer_operation;
|
||||
pub mod buffer_snapshot;
|
||||
@@ -23,7 +20,6 @@ pub mod notification;
|
||||
pub mod notification_kind;
|
||||
pub mod observed_buffer_edits;
|
||||
pub mod observed_channel_messages;
|
||||
pub mod processed_stripe_event;
|
||||
pub mod project;
|
||||
pub mod project_collaborator;
|
||||
pub mod project_repository;
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
use crate::db::{BillingCustomerId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
/// A billing customer.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_customers")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingCustomerId,
|
||||
pub user_id: UserId,
|
||||
pub stripe_customer_id: String,
|
||||
pub has_overdue_invoices: bool,
|
||||
pub trial_started_at: Option<DateTime>,
|
||||
pub created_at: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
#[sea_orm(has_many = "super::billing_subscription::Entity")]
|
||||
BillingSubscription,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_subscription::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingSubscription.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,32 +0,0 @@
|
||||
use crate::db::{BillingPreferencesId, UserId};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_preferences")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingPreferencesId,
|
||||
pub created_at: DateTime,
|
||||
pub user_id: UserId,
|
||||
pub max_monthly_llm_usage_spending_in_cents: i32,
|
||||
pub model_request_overages_enabled: bool,
|
||||
pub model_request_overages_spend_limit_in_cents: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,176 +0,0 @@
|
||||
use crate::db::{BillingCustomerId, BillingSubscriptionId};
|
||||
use crate::stripe_client;
|
||||
use chrono::{Datelike as _, NaiveDate, Utc};
|
||||
use sea_orm::entity::prelude::*;
|
||||
use serde::Serialize;
|
||||
|
||||
/// A billing subscription.
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_subscriptions")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingSubscriptionId,
|
||||
pub billing_customer_id: BillingCustomerId,
|
||||
pub kind: Option<SubscriptionKind>,
|
||||
pub stripe_subscription_id: String,
|
||||
pub stripe_subscription_status: StripeSubscriptionStatus,
|
||||
pub stripe_cancel_at: Option<DateTime>,
|
||||
pub stripe_cancellation_reason: Option<StripeCancellationReason>,
|
||||
pub stripe_current_period_start: Option<i64>,
|
||||
pub stripe_current_period_end: Option<i64>,
|
||||
pub created_at: DateTime,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn current_period_start_at(&self) -> Option<DateTimeUtc> {
|
||||
let period_start = self.stripe_current_period_start?;
|
||||
chrono::DateTime::from_timestamp(period_start, 0)
|
||||
}
|
||||
|
||||
pub fn current_period_end_at(&self) -> Option<DateTimeUtc> {
|
||||
let period_end = self.stripe_current_period_end?;
|
||||
chrono::DateTime::from_timestamp(period_end, 0)
|
||||
}
|
||||
|
||||
pub fn current_period(
|
||||
subscription: Option<Self>,
|
||||
is_staff: bool,
|
||||
) -> Option<(DateTimeUtc, DateTimeUtc)> {
|
||||
if is_staff {
|
||||
let now = Utc::now();
|
||||
let year = now.year();
|
||||
let month = now.month();
|
||||
|
||||
let first_day_of_this_month =
|
||||
NaiveDate::from_ymd_opt(year, month, 1)?.and_hms_opt(0, 0, 0)?;
|
||||
|
||||
let next_month = if month == 12 { 1 } else { month + 1 };
|
||||
let next_month_year = if month == 12 { year + 1 } else { year };
|
||||
let first_day_of_next_month =
|
||||
NaiveDate::from_ymd_opt(next_month_year, next_month, 1)?.and_hms_opt(23, 59, 59)?;
|
||||
|
||||
let last_day_of_this_month = first_day_of_next_month - chrono::Days::new(1);
|
||||
|
||||
Some((
|
||||
first_day_of_this_month.and_utc(),
|
||||
last_day_of_this_month.and_utc(),
|
||||
))
|
||||
} else {
|
||||
let subscription = subscription?;
|
||||
let period_start_at = subscription.current_period_start_at()?;
|
||||
let period_end_at = subscription.current_period_end_at()?;
|
||||
|
||||
Some((period_start_at, period_end_at))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::billing_customer::Entity",
|
||||
from = "Column::BillingCustomerId",
|
||||
to = "super::billing_customer::Column::Id"
|
||||
)]
|
||||
BillingCustomer,
|
||||
}
|
||||
|
||||
impl Related<super::billing_customer::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingCustomer.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SubscriptionKind {
|
||||
#[sea_orm(string_value = "zed_pro")]
|
||||
ZedPro,
|
||||
#[sea_orm(string_value = "zed_pro_trial")]
|
||||
ZedProTrial,
|
||||
#[sea_orm(string_value = "zed_free")]
|
||||
ZedFree,
|
||||
}
|
||||
|
||||
impl From<SubscriptionKind> for cloud_llm_client::Plan {
|
||||
fn from(value: SubscriptionKind) -> Self {
|
||||
match value {
|
||||
SubscriptionKind::ZedPro => Self::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Self::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => Self::ZedFree,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The status of a Stripe subscription.
|
||||
///
|
||||
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-status)
|
||||
#[derive(
|
||||
Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash, Serialize,
|
||||
)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StripeSubscriptionStatus {
|
||||
#[default]
|
||||
#[sea_orm(string_value = "incomplete")]
|
||||
Incomplete,
|
||||
#[sea_orm(string_value = "incomplete_expired")]
|
||||
IncompleteExpired,
|
||||
#[sea_orm(string_value = "trialing")]
|
||||
Trialing,
|
||||
#[sea_orm(string_value = "active")]
|
||||
Active,
|
||||
#[sea_orm(string_value = "past_due")]
|
||||
PastDue,
|
||||
#[sea_orm(string_value = "canceled")]
|
||||
Canceled,
|
||||
#[sea_orm(string_value = "unpaid")]
|
||||
Unpaid,
|
||||
#[sea_orm(string_value = "paused")]
|
||||
Paused,
|
||||
}
|
||||
|
||||
impl StripeSubscriptionStatus {
|
||||
pub fn is_cancelable(&self) -> bool {
|
||||
match self {
|
||||
Self::Trialing | Self::Active | Self::PastDue => true,
|
||||
Self::Incomplete
|
||||
| Self::IncompleteExpired
|
||||
| Self::Canceled
|
||||
| Self::Unpaid
|
||||
| Self::Paused => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The cancellation reason for a Stripe subscription.
|
||||
///
|
||||
/// [Stripe docs](https://docs.stripe.com/api/subscriptions/object#subscription_object-cancellation_details-reason)
|
||||
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
|
||||
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StripeCancellationReason {
|
||||
#[sea_orm(string_value = "cancellation_requested")]
|
||||
CancellationRequested,
|
||||
#[sea_orm(string_value = "payment_disputed")]
|
||||
PaymentDisputed,
|
||||
#[sea_orm(string_value = "payment_failed")]
|
||||
PaymentFailed,
|
||||
}
|
||||
|
||||
impl From<stripe_client::StripeCancellationDetailsReason> for StripeCancellationReason {
|
||||
fn from(value: stripe_client::StripeCancellationDetailsReason) -> Self {
|
||||
match value {
|
||||
stripe_client::StripeCancellationDetailsReason::CancellationRequested => {
|
||||
Self::CancellationRequested
|
||||
}
|
||||
stripe_client::StripeCancellationDetailsReason::PaymentDisputed => {
|
||||
Self::PaymentDisputed
|
||||
}
|
||||
stripe_client::StripeCancellationDetailsReason::PaymentFailed => Self::PaymentFailed,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "processed_stripe_events")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub stripe_event_id: String,
|
||||
pub stripe_event_type: String,
|
||||
pub stripe_event_created_timestamp: i64,
|
||||
pub processed_at: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -29,8 +29,6 @@ pub struct Model {
|
||||
pub enum Relation {
|
||||
#[sea_orm(has_many = "super::access_token::Entity")]
|
||||
AccessToken,
|
||||
#[sea_orm(has_one = "super::billing_customer::Entity")]
|
||||
BillingCustomer,
|
||||
#[sea_orm(has_one = "super::room_participant::Entity")]
|
||||
RoomParticipant,
|
||||
#[sea_orm(has_many = "super::project::Entity")]
|
||||
@@ -68,12 +66,6 @@ impl Related<super::access_token::Entity> for Entity {
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_customer::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingCustomer.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::room_participant::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::RoomParticipant.def()
|
||||
|
||||
@@ -8,7 +8,6 @@ mod embedding_tests;
|
||||
mod extension_tests;
|
||||
mod feature_flag_tests;
|
||||
mod message_tests;
|
||||
mod processed_stripe_event_tests;
|
||||
mod user_tests;
|
||||
|
||||
use crate::migrations::run_database_migrations;
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::test_both_dbs;
|
||||
|
||||
use super::{CreateProcessedStripeEventParams, Database};
|
||||
|
||||
test_both_dbs!(
|
||||
test_already_processed_stripe_event,
|
||||
test_already_processed_stripe_event_postgres,
|
||||
test_already_processed_stripe_event_sqlite
|
||||
);
|
||||
|
||||
async fn test_already_processed_stripe_event(db: &Arc<Database>) {
|
||||
let unprocessed_event_id = "evt_1PiJOuRxOf7d5PNaw2zzWiyO".to_string();
|
||||
let processed_event_id = "evt_1PiIfMRxOf7d5PNakHrAUe8P".to_string();
|
||||
|
||||
db.create_processed_stripe_event(&CreateProcessedStripeEventParams {
|
||||
stripe_event_id: processed_event_id.clone(),
|
||||
stripe_event_type: "customer.created".into(),
|
||||
stripe_event_created_timestamp: 1722355968,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
db.already_processed_stripe_event(&processed_event_id)
|
||||
.await
|
||||
.unwrap(),
|
||||
"Expected {processed_event_id} to already be processed"
|
||||
);
|
||||
|
||||
assert!(
|
||||
!db.already_processed_stripe_event(&unprocessed_event_id)
|
||||
.await
|
||||
.unwrap(),
|
||||
"Expected {unprocessed_event_id} to be unprocessed"
|
||||
);
|
||||
}
|
||||
@@ -7,8 +7,6 @@ pub mod llm;
|
||||
pub mod migrations;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod stripe_billing;
|
||||
pub mod stripe_client;
|
||||
pub mod user_backfiller;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -22,21 +20,16 @@ use axum::{
|
||||
};
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
use llm::db::LlmDatabase;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::stripe_billing::StripeBilling;
|
||||
use crate::stripe_client::{RealStripeClient, StripeClient};
|
||||
|
||||
pub type Result<T, E = Error> = std::result::Result<T, E>;
|
||||
|
||||
pub enum Error {
|
||||
Http(StatusCode, String, HeaderMap),
|
||||
Database(sea_orm::error::DbErr),
|
||||
Internal(anyhow::Error),
|
||||
Stripe(stripe::StripeError),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for Error {
|
||||
@@ -51,12 +44,6 @@ impl From<sea_orm::error::DbErr> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<stripe::StripeError> for Error {
|
||||
fn from(error: stripe::StripeError) -> Self {
|
||||
Self::Stripe(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<axum::Error> for Error {
|
||||
fn from(error: axum::Error) -> Self {
|
||||
Self::Internal(error.into())
|
||||
@@ -104,14 +91,6 @@ impl IntoResponse for Error {
|
||||
);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
|
||||
}
|
||||
Error::Stripe(error) => {
|
||||
log::error!(
|
||||
"HTTP error {}: {:?}",
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
&error
|
||||
);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -122,7 +101,6 @@ impl std::fmt::Debug for Error {
|
||||
Error::Http(code, message, _headers) => (code, message).fmt(f),
|
||||
Error::Database(error) => error.fmt(f),
|
||||
Error::Internal(error) => error.fmt(f),
|
||||
Error::Stripe(error) => error.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,7 +111,6 @@ impl std::fmt::Display for Error {
|
||||
Error::Http(code, message, _) => write!(f, "{code}: {message}"),
|
||||
Error::Database(error) => error.fmt(f),
|
||||
Error::Internal(error) => error.fmt(f),
|
||||
Error::Stripe(error) => error.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -179,7 +156,6 @@ pub struct Config {
|
||||
pub zed_client_checksum_seed: Option<String>,
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub stripe_api_key: Option<String>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
@@ -234,7 +210,6 @@ impl Config {
|
||||
auto_join_channel_id: None,
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
@@ -266,14 +241,8 @@ impl ServiceMode {
|
||||
|
||||
pub struct AppState {
|
||||
pub db: Arc<Database>,
|
||||
pub llm_db: Option<Arc<LlmDatabase>>,
|
||||
pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
/// This is a real instance of the Stripe client; we're working to replace references to this with the
|
||||
/// [`StripeClient`] trait.
|
||||
pub real_stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_client: Option<Arc<dyn StripeClient>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub executor: Executor,
|
||||
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
|
||||
pub config: Config,
|
||||
@@ -286,20 +255,6 @@ impl AppState {
|
||||
let mut db = Database::new(db_options).await?;
|
||||
db.initialize_notification_kinds().await?;
|
||||
|
||||
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
|
||||
.llm_database_url
|
||||
.clone()
|
||||
.zip(config.llm_database_max_connections)
|
||||
{
|
||||
let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
|
||||
llm_db_options.max_connections(llm_database_max_connections);
|
||||
let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
|
||||
llm_db.initialize().await?;
|
||||
Some(Arc::new(llm_db))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let livekit_client = if let Some(((server, key), secret)) = config
|
||||
.livekit_server
|
||||
.as_ref()
|
||||
@@ -316,18 +271,10 @@ impl AppState {
|
||||
};
|
||||
|
||||
let db = Arc::new(db);
|
||||
let stripe_client = build_stripe_client(&config).map(Arc::new).log_err();
|
||||
let this = Self {
|
||||
db: db.clone(),
|
||||
llm_db,
|
||||
livekit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
real_stripe_client: stripe_client.clone(),
|
||||
stripe_client: stripe_client
|
||||
.map(|stripe_client| Arc::new(RealStripeClient::new(stripe_client)) as _),
|
||||
executor,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
@@ -340,14 +287,6 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
|
||||
let api_key = config
|
||||
.stripe_api_key
|
||||
.as_ref()
|
||||
.context("missing stripe_api_key")?;
|
||||
Ok(stripe::Client::new(api_key))
|
||||
}
|
||||
|
||||
async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
|
||||
let keys = aws_sdk_s3::config::Credentials::new(
|
||||
config
|
||||
|
||||
@@ -1,12 +1 @@
|
||||
pub mod db;
|
||||
mod token;
|
||||
|
||||
pub use token::*;
|
||||
|
||||
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
|
||||
|
||||
/// The name of the feature flag that bypasses the account age check.
|
||||
pub const BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG: &str = "bypass-account-age-check";
|
||||
|
||||
/// The minimum account age an account must have in order to use the LLM service.
|
||||
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
|
||||
|
||||
@@ -1,30 +1,9 @@
|
||||
mod ids;
|
||||
mod queries;
|
||||
mod seed;
|
||||
mod tables;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use collections::HashMap;
|
||||
pub use ids::*;
|
||||
pub use seed::*;
|
||||
pub use tables::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use tests::TestLlmDb;
|
||||
use usage_measure::UsageMeasure;
|
||||
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
pub use sea_orm::ConnectOptions;
|
||||
use sea_orm::prelude::*;
|
||||
use sea_orm::{
|
||||
ActiveValue, DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait,
|
||||
};
|
||||
use sea_orm::{DatabaseConnection, DatabaseTransaction, IsolationLevel, TransactionTrait};
|
||||
|
||||
use crate::Result;
|
||||
use crate::db::TransactionHandle;
|
||||
@@ -36,9 +15,6 @@ pub struct LlmDatabase {
|
||||
pool: DatabaseConnection,
|
||||
#[allow(unused)]
|
||||
executor: Executor,
|
||||
provider_ids: HashMap<LanguageModelProvider, ProviderId>,
|
||||
models: HashMap<(LanguageModelProvider, String), model::Model>,
|
||||
usage_measure_ids: HashMap<UsageMeasure, UsageMeasureId>,
|
||||
#[cfg(test)]
|
||||
runtime: Option<tokio::runtime::Runtime>,
|
||||
}
|
||||
@@ -51,59 +27,11 @@ impl LlmDatabase {
|
||||
options: options.clone(),
|
||||
pool: sea_orm::Database::connect(options).await?,
|
||||
executor,
|
||||
provider_ids: HashMap::default(),
|
||||
models: HashMap::default(),
|
||||
usage_measure_ids: HashMap::default(),
|
||||
#[cfg(test)]
|
||||
runtime: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn initialize(&mut self) -> Result<()> {
|
||||
self.initialize_providers().await?;
|
||||
self.initialize_models().await?;
|
||||
self.initialize_usage_measures().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the list of all known models, with their [`LanguageModelProvider`].
|
||||
pub fn all_models(&self) -> Vec<(LanguageModelProvider, model::Model)> {
|
||||
self.models
|
||||
.iter()
|
||||
.map(|((model_provider, _model_name), model)| (*model_provider, model.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
/// Returns the names of the known models for the given [`LanguageModelProvider`].
|
||||
pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
|
||||
self.models
|
||||
.keys()
|
||||
.filter_map(|(model_provider, model_name)| {
|
||||
if model_provider == &provider {
|
||||
Some(model_name)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
.get(&(provider, name.to_string()))
|
||||
.with_context(|| format!("unknown model {provider:?}:{name}"))?)
|
||||
}
|
||||
|
||||
pub fn model_by_id(&self, id: ModelId) -> Result<&model::Model> {
|
||||
Ok(self
|
||||
.models
|
||||
.values()
|
||||
.find(|model| model.id == id)
|
||||
.with_context(|| format!("no model for ID {id:?}"))?)
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ConnectOptions {
|
||||
&self.options
|
||||
}
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
use sea_orm::{DbErr, entity::prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::id_type;
|
||||
|
||||
id_type!(BillingEventId);
|
||||
id_type!(ModelId);
|
||||
id_type!(ProviderId);
|
||||
id_type!(RevokedAccessTokenId);
|
||||
id_type!(UsageId);
|
||||
id_type!(UsageMeasureId);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user