Compare commits
66 Commits
list-ui-fo
...
zeta2-cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a4ee4fed7 | ||
|
|
ea4bf46a36 | ||
|
|
05545abab6 | ||
|
|
a85608566d | ||
|
|
69af5261ea | ||
|
|
b9e2f61a38 | ||
|
|
38bbb497dd | ||
|
|
0cc7b4a93c | ||
|
|
cc32bfdfdf | ||
|
|
50de8ddc28 | ||
|
|
f770011d7f | ||
|
|
f2a6b57909 | ||
|
|
96b67ac70e | ||
|
|
64d362cbce | ||
|
|
5d561aa494 | ||
|
|
4ee2daeded | ||
|
|
c27d8e0c7a | ||
|
|
f6c5c68751 | ||
|
|
74e5b848ff | ||
|
|
ee399ebccf | ||
|
|
54c82f2732 | ||
|
|
e14a4ab90d | ||
|
|
0343b5ff06 | ||
|
|
26202e5af2 | ||
|
|
ee912366a3 | ||
|
|
673a98a277 | ||
|
|
5674445a61 | ||
|
|
53513cab23 | ||
|
|
e885a939ba | ||
|
|
a01a2ed0e0 | ||
|
|
af3bc45a26 | ||
|
|
173074f248 | ||
|
|
a7cb64c64d | ||
|
|
c6472fd7a8 | ||
|
|
c0710fa8ca | ||
|
|
f321d02207 | ||
|
|
1c09985fb3 | ||
|
|
d986077592 | ||
|
|
555b6ee4e5 | ||
|
|
6446963a0c | ||
|
|
ceb907e0dc | ||
|
|
3dbccc828e | ||
|
|
853e625259 | ||
|
|
0784bb8192 | ||
|
|
9046091164 | ||
|
|
6384966ab5 | ||
|
|
8b9c74726a | ||
|
|
63586ff2e4 | ||
|
|
35e5aa4e71 | ||
|
|
7ea94a32be | ||
|
|
6d6c3d648a | ||
|
|
53b2f37452 | ||
|
|
92b946e8e5 | ||
|
|
e9b4f59e0f | ||
|
|
989adde57b | ||
|
|
393d6787a3 | ||
|
|
4a582504d4 | ||
|
|
cfb2925169 | ||
|
|
14f4e867aa | ||
|
|
4d54ccf494 | ||
|
|
5b1c87b6a6 | ||
|
|
0fef17baa2 | ||
|
|
526196917b | ||
|
|
a598fbaa73 | ||
|
|
634ae72cad | ||
|
|
98edf1bf0b |
2
.rules
2
.rules
@@ -59,7 +59,7 @@ Trying to update an entity while it's already being updated must be avoided as t
|
||||
|
||||
When `read_with`, `update`, or `update_in` are used with an async context, the closure's return value is wrapped in an `anyhow::Result`.
|
||||
|
||||
`WeakEntity<T>` is a weak handle. It has `read_with`, `update`, and `update_in` methods that work the same, but always return an `anyhow::Result` so that they can fail if the entity no longer exists. This can be useful to avoid memory leaks - if entities have mutually recursive handles to eachother they will never be dropped.
|
||||
`WeakEntity<T>` is a weak handle. It has `read_with`, `update`, and `update_in` methods that work the same, but always return an `anyhow::Result` so that they can fail if the entity no longer exists. This can be useful to avoid memory leaks - if entities have mutually recursive handles to each other they will never be dropped.
|
||||
|
||||
## Concurrency
|
||||
|
||||
|
||||
615
Cargo.lock
generated
615
Cargo.lock
generated
@@ -39,7 +39,6 @@ dependencies = [
|
||||
"util",
|
||||
"uuid",
|
||||
"watch",
|
||||
"which 6.0.3",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -301,6 +300,7 @@ dependencies = [
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"http_client",
|
||||
"indoc",
|
||||
"language",
|
||||
"language_model",
|
||||
@@ -416,7 +416,6 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"shlex",
|
||||
"smol",
|
||||
"streaming_diff",
|
||||
"task",
|
||||
@@ -689,6 +688,9 @@ name = "arbitrary"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
|
||||
dependencies = [
|
||||
"derive_arbitrary",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arc-swap"
|
||||
@@ -1024,7 +1026,6 @@ dependencies = [
|
||||
"util",
|
||||
"watch",
|
||||
"web_search",
|
||||
"which 6.0.3",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zlog",
|
||||
@@ -2189,7 +2190,7 @@ dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"itertools 0.12.1",
|
||||
"itertools 0.11.0",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"log",
|
||||
@@ -2687,6 +2688,53 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-core"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"float8",
|
||||
"gemm 0.17.1",
|
||||
"half",
|
||||
"memmap2",
|
||||
"num-traits",
|
||||
"num_cpus",
|
||||
"rand 0.9.1",
|
||||
"rand_distr",
|
||||
"rayon",
|
||||
"safetensors",
|
||||
"thiserror 1.0.69",
|
||||
"ug",
|
||||
"yoke",
|
||||
"zip 1.1.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-nn"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383"
|
||||
dependencies = [
|
||||
"candle-core",
|
||||
"half",
|
||||
"libc",
|
||||
"num-traits",
|
||||
"rayon",
|
||||
"safetensors",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-onnx"
|
||||
version = "0.9.1"
|
||||
source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383"
|
||||
dependencies = [
|
||||
"candle-core",
|
||||
"candle-nn",
|
||||
"prost 0.12.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cap-fs-ext"
|
||||
version = "3.4.4"
|
||||
@@ -4637,6 +4685,20 @@ version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da692b8d1080ea3045efaab14434d40468c3d8657e42abddfffca87b428f4c1b"
|
||||
|
||||
[[package]]
|
||||
name = "denoise"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"candle-core",
|
||||
"candle-onnx",
|
||||
"log",
|
||||
"realfft",
|
||||
"rodio",
|
||||
"rustfft",
|
||||
"thiserror 2.0.12",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der"
|
||||
version = "0.6.1"
|
||||
@@ -4668,6 +4730,17 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_arbitrary"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "0.99.19"
|
||||
@@ -4823,7 +4896,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users 0.5.0",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4981,6 +5054,25 @@ version = "1.0.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005"
|
||||
|
||||
[[package]]
|
||||
name = "dyn-stack"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"reborrow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dyn-stack"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "490bd48eb68fffcfed519b4edbfd82c69cbe741d175b84f0e0cbe8c57cbe0bdd"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ec4rs"
|
||||
version = "1.2.0"
|
||||
@@ -5042,6 +5134,36 @@ dependencies = [
|
||||
"zeta",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "edit_prediction_context"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"arrayvec",
|
||||
"clap",
|
||||
"collections",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"log",
|
||||
"ordered-float 2.10.1",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"slotmap",
|
||||
"strum 0.27.1",
|
||||
"text",
|
||||
"tree-sitter",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "editor"
|
||||
version = "0.1.0"
|
||||
@@ -5225,6 +5347,18 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3d8a32ae18130a3c84dd492d4215c3d913c3b07c6b63c2eb3eb7ff1101ab7bf"
|
||||
|
||||
[[package]]
|
||||
name = "enum-as-inner"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "enumflags2"
|
||||
version = "0.7.11"
|
||||
@@ -5855,6 +5989,18 @@ version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d"
|
||||
|
||||
[[package]]
|
||||
name = "float8"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4203231de188ebbdfb85c11f3c20ca2b063945710de04e7b59268731e728b462"
|
||||
dependencies = [
|
||||
"half",
|
||||
"num-traits",
|
||||
"rand 0.9.1",
|
||||
"rand_distr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "float_next_after"
|
||||
version = "1.0.0"
|
||||
@@ -6309,6 +6455,243 @@ dependencies = [
|
||||
"thread_local",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm"
|
||||
version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32"
|
||||
dependencies = [
|
||||
"dyn-stack 0.10.0",
|
||||
"gemm-c32 0.17.1",
|
||||
"gemm-c64 0.17.1",
|
||||
"gemm-common 0.17.1",
|
||||
"gemm-f16 0.17.1",
|
||||
"gemm-f32 0.17.1",
|
||||
"gemm-f64 0.17.1",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 10.7.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm"
|
||||
version = "0.18.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451"
|
||||
dependencies = [
|
||||
"dyn-stack 0.13.0",
|
||||
"gemm-c32 0.18.2",
|
||||
"gemm-c64 0.18.2",
|
||||
"gemm-common 0.18.2",
|
||||
"gemm-f16 0.18.2",
|
||||
"gemm-f32 0.18.2",
|
||||
"gemm-f64 0.18.2",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 11.6.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-c32"
|
||||
version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0"
|
||||
dependencies = [
|
||||
"dyn-stack 0.10.0",
|
||||
"gemm-common 0.17.1",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 10.7.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-c32"
|
||||
version = "0.18.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847"
|
||||
dependencies = [
|
||||
"dyn-stack 0.13.0",
|
||||
"gemm-common 0.18.2",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 11.6.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-c64"
|
||||
version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a"
|
||||
dependencies = [
|
||||
"dyn-stack 0.10.0",
|
||||
"gemm-common 0.17.1",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 10.7.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-c64"
|
||||
version = "0.18.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf"
|
||||
dependencies = [
|
||||
"dyn-stack 0.13.0",
|
||||
"gemm-common 0.18.2",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 11.6.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-common"
|
||||
version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"dyn-stack 0.10.0",
|
||||
"half",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"pulp 0.18.22",
|
||||
"raw-cpuid 10.7.0",
|
||||
"rayon",
|
||||
"seq-macro",
|
||||
"sysctl 0.5.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-common"
|
||||
version = "0.18.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"dyn-stack 0.13.0",
|
||||
"half",
|
||||
"libm",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"pulp 0.21.5",
|
||||
"raw-cpuid 11.6.0",
|
||||
"rayon",
|
||||
"seq-macro",
|
||||
"sysctl 0.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-f16"
|
||||
version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4"
|
||||
dependencies = [
|
||||
"dyn-stack 0.10.0",
|
||||
"gemm-common 0.17.1",
|
||||
"gemm-f32 0.17.1",
|
||||
"half",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 10.7.0",
|
||||
"rayon",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-f16"
|
||||
version = "0.18.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109"
|
||||
dependencies = [
|
||||
"dyn-stack 0.13.0",
|
||||
"gemm-common 0.18.2",
|
||||
"gemm-f32 0.18.2",
|
||||
"half",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 11.6.0",
|
||||
"rayon",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-f32"
|
||||
version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113"
|
||||
dependencies = [
|
||||
"dyn-stack 0.10.0",
|
||||
"gemm-common 0.17.1",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 10.7.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-f32"
|
||||
version = "0.18.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864"
|
||||
dependencies = [
|
||||
"dyn-stack 0.13.0",
|
||||
"gemm-common 0.18.2",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 11.6.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-f64"
|
||||
version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0"
|
||||
dependencies = [
|
||||
"dyn-stack 0.10.0",
|
||||
"gemm-common 0.17.1",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 10.7.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm-f64"
|
||||
version = "0.18.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd"
|
||||
dependencies = [
|
||||
"dyn-stack 0.13.0",
|
||||
"gemm-common 0.18.2",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"raw-cpuid 11.6.0",
|
||||
"seq-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generator"
|
||||
version = "0.8.5"
|
||||
@@ -7583,9 +7966,12 @@ version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cfg-if",
|
||||
"crunchy",
|
||||
"num-traits",
|
||||
"rand 0.9.1",
|
||||
"rand_distr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -9179,6 +9565,7 @@ dependencies = [
|
||||
"credentials_provider",
|
||||
"deepseek",
|
||||
"editor",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"google_ai",
|
||||
"gpui",
|
||||
@@ -9212,6 +9599,7 @@ dependencies = [
|
||||
"vercel",
|
||||
"workspace-hack",
|
||||
"x_ai",
|
||||
"zed_env_vars",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -9305,6 +9693,7 @@ dependencies = [
|
||||
"pet-fs",
|
||||
"pet-poetry",
|
||||
"pet-reporter",
|
||||
"pet-virtualenv",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"regex",
|
||||
@@ -10174,6 +10563,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -10440,12 +10830,6 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
|
||||
|
||||
[[package]]
|
||||
name = "multimap"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
|
||||
|
||||
[[package]]
|
||||
name = "naga"
|
||||
version = "25.0.1"
|
||||
@@ -10819,6 +11203,7 @@ version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
@@ -12512,6 +12897,15 @@ dependencies = [
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "primal-check"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-crate"
|
||||
version = "3.3.0"
|
||||
@@ -12820,7 +13214,7 @@ dependencies = [
|
||||
"itertools 0.10.5",
|
||||
"lazy_static",
|
||||
"log",
|
||||
"multimap 0.8.3",
|
||||
"multimap",
|
||||
"petgraph",
|
||||
"prost 0.9.0",
|
||||
"prost-types 0.9.0",
|
||||
@@ -12837,9 +13231,9 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
|
||||
dependencies = [
|
||||
"bytes 1.10.1",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.12.1",
|
||||
"itertools 0.11.0",
|
||||
"log",
|
||||
"multimap 0.10.0",
|
||||
"multimap",
|
||||
"once_cell",
|
||||
"petgraph",
|
||||
"prettyplease",
|
||||
@@ -12870,7 +13264,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.12.1",
|
||||
"itertools 0.11.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
@@ -13011,6 +13405,32 @@ dependencies = [
|
||||
"wasmtime-math",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pulp"
|
||||
version = "0.18.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"libm",
|
||||
"num-complex",
|
||||
"reborrow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pulp"
|
||||
version = "0.21.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96b86df24f0a7ddd5e4b95c94fc9ed8a98f1ca94d3b01bdce2824097e7835907"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cfg-if",
|
||||
"libm",
|
||||
"num-complex",
|
||||
"reborrow",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "qoi"
|
||||
version = "0.4.1"
|
||||
@@ -13187,6 +13607,16 @@ dependencies = [
|
||||
"getrandom 0.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_distr"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rand 0.9.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "range-map"
|
||||
version = "0.2.0"
|
||||
@@ -13252,6 +13682,24 @@ dependencies = [
|
||||
"rgb",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "10.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-window-handle"
|
||||
version = "0.6.2"
|
||||
@@ -13300,6 +13748,21 @@ dependencies = [
|
||||
"font-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "realfft"
|
||||
version = "3.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f821338fddb99d089116342c46e9f1fbf3828dba077674613e734e01d6ea8677"
|
||||
dependencies = [
|
||||
"rustfft",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reborrow"
|
||||
version = "0.5.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
|
||||
|
||||
[[package]]
|
||||
name = "recent_projects"
|
||||
version = "0.1.0"
|
||||
@@ -14116,6 +14579,20 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustfft"
|
||||
version = "6.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6f140db74548f7c9d7cce60912c9ac414e74df5e718dc947d514b051b42f3f4"
|
||||
dependencies = [
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"primal-check",
|
||||
"strength_reduce",
|
||||
"transpose",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.44"
|
||||
@@ -14340,6 +14817,16 @@ version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
||||
|
||||
[[package]]
|
||||
name = "safetensors"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "salsa20"
|
||||
version = "0.10.2"
|
||||
@@ -14721,6 +15208,12 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "seq-macro"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.221"
|
||||
@@ -15675,6 +16168,12 @@ dependencies = [
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strength_reduce"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
|
||||
|
||||
[[package]]
|
||||
name = "strict-num"
|
||||
version = "0.1.1"
|
||||
@@ -16165,6 +16664,34 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sysctl"
|
||||
version = "0.5.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"byteorder",
|
||||
"enum-as-inner",
|
||||
"libc",
|
||||
"thiserror 1.0.69",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sysctl"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"byteorder",
|
||||
"enum-as-inner",
|
||||
"libc",
|
||||
"thiserror 1.0.69",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sysinfo"
|
||||
version = "0.31.4"
|
||||
@@ -17259,6 +17786,16 @@ dependencies = [
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "transpose"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"strength_reduce",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter"
|
||||
version = "0.25.6"
|
||||
@@ -17620,6 +18157,27 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ug"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "90b70b37e9074642bc5f60bb23247fd072a84314ca9e71cdf8527593406a0dd3"
|
||||
dependencies = [
|
||||
"gemm 0.18.2",
|
||||
"half",
|
||||
"libloading",
|
||||
"memmap2",
|
||||
"num",
|
||||
"num-traits",
|
||||
"num_cpus",
|
||||
"rayon",
|
||||
"safetensors",
|
||||
"serde",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
"yoke",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ui"
|
||||
version = "0.1.0"
|
||||
@@ -18888,7 +19446,7 @@ dependencies = [
|
||||
"reqwest 0.11.27",
|
||||
"scratch",
|
||||
"semver",
|
||||
"zip",
|
||||
"zip 0.6.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -20033,7 +20591,7 @@ dependencies = [
|
||||
"idna",
|
||||
"indexmap",
|
||||
"inout",
|
||||
"itertools 0.12.1",
|
||||
"itertools 0.11.0",
|
||||
"itertools 0.13.0",
|
||||
"jiff",
|
||||
"lazy_static",
|
||||
@@ -20047,6 +20605,7 @@ dependencies = [
|
||||
"lyon_path",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"memmap2",
|
||||
"mime_guess",
|
||||
"miniz_oxide",
|
||||
"mio 1.0.3",
|
||||
@@ -20055,8 +20614,10 @@ dependencies = [
|
||||
"nix 0.29.0",
|
||||
"nix 0.30.1",
|
||||
"nom 7.1.3",
|
||||
"num",
|
||||
"num-bigint",
|
||||
"num-bigint-dig",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
@@ -20072,6 +20633,7 @@ dependencies = [
|
||||
"phf_shared",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"prost 0.12.6",
|
||||
"prost 0.9.0",
|
||||
"prost-types 0.9.0",
|
||||
"quote",
|
||||
@@ -20079,6 +20641,7 @@ dependencies = [
|
||||
"rand 0.9.1",
|
||||
"rand_chacha 0.3.1",
|
||||
"rand_core 0.6.4",
|
||||
"rand_distr",
|
||||
"regalloc2",
|
||||
"regex",
|
||||
"regex-automata",
|
||||
@@ -20108,6 +20671,7 @@ dependencies = [
|
||||
"sqlx-macros-core",
|
||||
"sqlx-postgres",
|
||||
"sqlx-sqlite",
|
||||
"stable_deref_trait",
|
||||
"strum 0.26.3",
|
||||
"subtle",
|
||||
"syn 1.0.109",
|
||||
@@ -20141,6 +20705,7 @@ dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.60.2",
|
||||
"winnow",
|
||||
"zeroize",
|
||||
"zvariant",
|
||||
@@ -20677,6 +21242,7 @@ dependencies = [
|
||||
name = "zed_env_vars"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"gpui",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -20990,6 +21556,21 @@ dependencies = [
|
||||
"zstd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "1.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164"
|
||||
dependencies = [
|
||||
"arbitrary",
|
||||
"crc32fast",
|
||||
"crossbeam-utils",
|
||||
"displaydoc",
|
||||
"indexmap",
|
||||
"num_enum",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zlib-rs"
|
||||
version = "0.5.0"
|
||||
|
||||
@@ -52,10 +52,12 @@ members = [
|
||||
"crates/debugger_tools",
|
||||
"crates/debugger_ui",
|
||||
"crates/deepseek",
|
||||
"crates/denoise",
|
||||
"crates/diagnostics",
|
||||
"crates/docs_preprocessor",
|
||||
"crates/edit_prediction",
|
||||
"crates/edit_prediction_button",
|
||||
"crates/edit_prediction_context",
|
||||
"crates/editor",
|
||||
"crates/eval",
|
||||
"crates/explorer_command_injector",
|
||||
@@ -312,6 +314,7 @@ icons = { path = "crates/icons" }
|
||||
image_viewer = { path = "crates/image_viewer" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
edit_prediction_button = { path = "crates/edit_prediction_button" }
|
||||
edit_prediction_context = { path = "crates/edit_prediction_context" }
|
||||
inspector_ui = { path = "crates/inspector_ui" }
|
||||
install_cli = { path = "crates/install_cli" }
|
||||
jj = { path = "crates/jj" }
|
||||
@@ -582,6 +585,7 @@ pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", re
|
||||
pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
|
||||
portable-pty = "0.9.0"
|
||||
postage = { version = "0.5", features = ["futures-traits"] }
|
||||
pretty_assertions = { version = "1.3.0", features = ["unstable"] }
|
||||
@@ -630,6 +634,7 @@ sha2 = "0.10"
|
||||
shellexpand = "2.1.0"
|
||||
shlex = "1.3.0"
|
||||
simplelog = "0.12.2"
|
||||
slotmap = "1.0.6"
|
||||
smallvec = { version = "1.6", features = ["union"] }
|
||||
smol = "2.0"
|
||||
sqlformat = "0.2"
|
||||
|
||||
@@ -462,8 +462,8 @@
|
||||
"ctrl-k ctrl-w": "workspace::CloseAllItemsAndPanes",
|
||||
"back": "pane::GoBack",
|
||||
"ctrl-alt--": "pane::GoBack",
|
||||
"ctrl-alt-_": "pane::GoForward",
|
||||
"forward": "pane::GoForward",
|
||||
"ctrl-alt-_": "pane::GoForward",
|
||||
"ctrl-alt-g": "search::SelectNextMatch",
|
||||
"f3": "search::SelectNextMatch",
|
||||
"ctrl-alt-shift-g": "search::SelectPreviousMatch",
|
||||
|
||||
@@ -497,6 +497,8 @@
|
||||
"shift-alt-down": "editor::DuplicateLineDown",
|
||||
"shift-alt-right": "editor::SelectLargerSyntaxNode", // Expand selection
|
||||
"shift-alt-left": "editor::SelectSmallerSyntaxNode", // Shrink selection
|
||||
"ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand selection (VSCode version)
|
||||
"ctrl-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink selection (VSCode version)
|
||||
"ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection
|
||||
"ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word
|
||||
"ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand
|
||||
|
||||
@@ -914,7 +914,11 @@
|
||||
/// Whether to have terminal cards in the agent panel expanded, showing the whole command output.
|
||||
///
|
||||
/// Default: true
|
||||
"expand_terminal_card": true
|
||||
"expand_terminal_card": true,
|
||||
// Minimum number of lines to display in the agent message editor.
|
||||
//
|
||||
// Default: 4
|
||||
"message_editor_min_lines": 4
|
||||
},
|
||||
// The settings for slash commands.
|
||||
"slash_commands": {
|
||||
|
||||
@@ -45,7 +45,6 @@ url.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
watch.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -7,12 +7,12 @@ use agent_settings::AgentSettings;
|
||||
use collections::HashSet;
|
||||
pub use connection::*;
|
||||
pub use diff::*;
|
||||
use futures::future::Shared;
|
||||
use language::language_settings::FormatOnSave;
|
||||
pub use mention::*;
|
||||
use project::lsp_store::{FormatTrigger, LspFormatTarget};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings as _;
|
||||
use task::{Shell, ShellBuilder};
|
||||
pub use terminal::*;
|
||||
|
||||
use action_log::ActionLog;
|
||||
@@ -34,7 +34,7 @@ use std::rc::Rc;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
|
||||
use ui::App;
|
||||
use util::{ResultExt, get_system_shell};
|
||||
use util::{ResultExt, get_default_system_shell};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -786,7 +786,6 @@ pub struct AcpThread {
|
||||
token_usage: Option<TokenUsage>,
|
||||
prompt_capabilities: acp::PromptCapabilities,
|
||||
_observe_prompt_capabilities: Task<anyhow::Result<()>>,
|
||||
determine_shell: Shared<Task<String>>,
|
||||
terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
|
||||
}
|
||||
|
||||
@@ -873,20 +872,6 @@ impl AcpThread {
|
||||
}
|
||||
});
|
||||
|
||||
let determine_shell = cx
|
||||
.background_spawn(async move {
|
||||
if cfg!(windows) {
|
||||
return get_system_shell();
|
||||
}
|
||||
|
||||
if which::which("bash").is_ok() {
|
||||
"bash".into()
|
||||
} else {
|
||||
get_system_shell()
|
||||
}
|
||||
})
|
||||
.shared();
|
||||
|
||||
Self {
|
||||
action_log,
|
||||
shared_buffers: Default::default(),
|
||||
@@ -901,7 +886,6 @@ impl AcpThread {
|
||||
prompt_capabilities,
|
||||
_observe_prompt_capabilities: task,
|
||||
terminals: HashMap::default(),
|
||||
determine_shell,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1127,9 +1111,33 @@ impl AcpThread {
|
||||
let update = update.into();
|
||||
let languages = self.project.read(cx).languages().clone();
|
||||
|
||||
let ix = self
|
||||
.index_for_tool_call(update.id())
|
||||
.context("Tool call not found")?;
|
||||
let ix = match self.index_for_tool_call(update.id()) {
|
||||
Some(ix) => ix,
|
||||
None => {
|
||||
// Tool call not found - create a failed tool call entry
|
||||
let failed_tool_call = ToolCall {
|
||||
id: update.id().clone(),
|
||||
label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
|
||||
kind: acp::ToolKind::Fetch,
|
||||
content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
|
||||
acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "Tool call not found".to_string(),
|
||||
annotations: None,
|
||||
meta: None,
|
||||
}),
|
||||
&languages,
|
||||
cx,
|
||||
))],
|
||||
status: ToolCallStatus::Failed,
|
||||
locations: Vec::new(),
|
||||
resolved_locations: Vec::new(),
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
};
|
||||
self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
|
||||
unreachable!()
|
||||
};
|
||||
@@ -1940,28 +1948,13 @@ impl AcpThread {
|
||||
|
||||
pub fn create_terminal(
|
||||
&self,
|
||||
mut command: String,
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
extra_env: Vec<acp::EnvVariable>,
|
||||
cwd: Option<PathBuf>,
|
||||
output_byte_limit: Option<u64>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<Entity<Terminal>>> {
|
||||
for arg in args {
|
||||
command.push(' ');
|
||||
command.push_str(&arg);
|
||||
}
|
||||
|
||||
let shell_command = if cfg!(windows) {
|
||||
format!("$null | & {{{}}}", command.replace("\"", "'"))
|
||||
} else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
|
||||
// Make sure once we're *inside* the shell, we cd into `cwd`
|
||||
format!("(cd {cwd}; {}) </dev/null", command)
|
||||
} else {
|
||||
format!("({}) </dev/null", command)
|
||||
};
|
||||
let args = vec!["-c".into(), shell_command];
|
||||
|
||||
let env = match &cwd {
|
||||
Some(dir) => self.project.update(cx, |project, cx| {
|
||||
project.directory_environment(dir.as_path().into(), cx)
|
||||
@@ -1982,20 +1975,30 @@ impl AcpThread {
|
||||
|
||||
let project = self.project.clone();
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
let determine_shell = self.determine_shell.clone();
|
||||
|
||||
let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
|
||||
let terminal_task = cx.spawn({
|
||||
let terminal_id = terminal_id.clone();
|
||||
async move |_this, cx| {
|
||||
let program = determine_shell.await;
|
||||
let env = env.await;
|
||||
let (command, args) = ShellBuilder::new(
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project
|
||||
.remote_client()
|
||||
.and_then(|r| r.read(cx).default_system_shell())
|
||||
})?
|
||||
.as_deref(),
|
||||
&Shell::Program(get_default_system_shell()),
|
||||
)
|
||||
.redirect_stdin_to_dev_null()
|
||||
.build(Some(command), &args);
|
||||
let terminal = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_terminal_task(
|
||||
task::SpawnInTerminal {
|
||||
command: Some(program),
|
||||
args,
|
||||
command: Some(command.clone()),
|
||||
args: args.clone(),
|
||||
cwd: cwd.clone(),
|
||||
env,
|
||||
..Default::default()
|
||||
@@ -2008,7 +2011,7 @@ impl AcpThread {
|
||||
cx.new(|cx| {
|
||||
Terminal::new(
|
||||
terminal_id,
|
||||
command,
|
||||
&format!("{} {}", command, args.join(" ")),
|
||||
cwd,
|
||||
output_byte_limit.map(|l| l as usize),
|
||||
terminal,
|
||||
@@ -3181,4 +3184,65 @@ mod tests {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let connection = Rc::new(FakeAgentConnection::new());
|
||||
let thread = cx
|
||||
.update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Try to update a tool call that doesn't exist
|
||||
let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
|
||||
thread.update(cx, |thread, cx| {
|
||||
let result = thread.handle_session_update(
|
||||
acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
|
||||
id: nonexistent_id.clone(),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
..Default::default()
|
||||
},
|
||||
meta: None,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
|
||||
// The update should succeed (not return an error)
|
||||
assert!(result.is_ok());
|
||||
|
||||
// There should now be exactly one entry in the thread
|
||||
assert_eq!(thread.entries.len(), 1);
|
||||
|
||||
// The entry should be a failed tool call
|
||||
if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
|
||||
assert_eq!(tool_call.id, nonexistent_id);
|
||||
assert!(matches!(tool_call.status, ToolCallStatus::Failed));
|
||||
assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
|
||||
|
||||
// Check that the content contains the error message
|
||||
assert_eq!(tool_call.content.len(), 1);
|
||||
if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
|
||||
match content_block {
|
||||
ContentBlock::Markdown { markdown } => {
|
||||
let markdown_text = markdown.read(cx).source();
|
||||
assert!(markdown_text.contains("Tool call not found"));
|
||||
}
|
||||
ContentBlock::Empty => panic!("Expected markdown content, got empty"),
|
||||
ContentBlock::ResourceLink { .. } => {
|
||||
panic!("Expected markdown content, got resource link")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
|
||||
}
|
||||
} else {
|
||||
panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ pub struct TerminalOutput {
|
||||
impl Terminal {
|
||||
pub fn new(
|
||||
id: acp::TerminalId,
|
||||
command: String,
|
||||
command_label: &str,
|
||||
working_dir: Option<PathBuf>,
|
||||
output_byte_limit: Option<usize>,
|
||||
terminal: Entity<terminal::Terminal>,
|
||||
@@ -40,7 +40,7 @@ impl Terminal {
|
||||
id,
|
||||
command: cx.new(|cx| {
|
||||
Markdown::new(
|
||||
format!("```\n{}\n```", command).into(),
|
||||
format!("```\n{}\n```", command_label).into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage, VersionCheckType};
|
||||
use auto_update::{AutoUpdateStatus, AutoUpdater, DismissMessage, VersionCheckType};
|
||||
use editor::Editor;
|
||||
use extension_host::{ExtensionOperation, ExtensionStore};
|
||||
use futures::StreamExt;
|
||||
@@ -280,18 +280,13 @@ impl ActivityIndicator {
|
||||
});
|
||||
}
|
||||
|
||||
fn dismiss_error_message(
|
||||
&mut self,
|
||||
_: &DismissErrorMessage,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let error_dismissed = if let Some(updater) = &self.auto_updater {
|
||||
updater.update(cx, |updater, cx| updater.dismiss_error(cx))
|
||||
fn dismiss_message(&mut self, _: &DismissMessage, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let dismissed = if let Some(updater) = &self.auto_updater {
|
||||
updater.update(cx, |updater, cx| updater.dismiss(cx))
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if error_dismissed {
|
||||
if dismissed {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -513,7 +508,7 @@ impl ActivityIndicator {
|
||||
on_click: Some(Arc::new(move |this, window, cx| {
|
||||
this.statuses
|
||||
.retain(|status| !downloading.contains(&status.name));
|
||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||
this.dismiss_message(&DismissMessage, window, cx)
|
||||
})),
|
||||
tooltip_message: None,
|
||||
});
|
||||
@@ -542,7 +537,7 @@ impl ActivityIndicator {
|
||||
on_click: Some(Arc::new(move |this, window, cx| {
|
||||
this.statuses
|
||||
.retain(|status| !checking_for_update.contains(&status.name));
|
||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||
this.dismiss_message(&DismissMessage, window, cx)
|
||||
})),
|
||||
tooltip_message: None,
|
||||
});
|
||||
@@ -650,13 +645,14 @@ impl ActivityIndicator {
|
||||
.and_then(|updater| match &updater.read(cx).status() {
|
||||
AutoUpdateStatus::Checking => Some(Content {
|
||||
icon: Some(
|
||||
Icon::new(IconName::Download)
|
||||
Icon::new(IconName::LoadCircle)
|
||||
.size(IconSize::Small)
|
||||
.with_rotate_animation(3)
|
||||
.into_any_element(),
|
||||
),
|
||||
message: "Checking for Zed updates…".to_string(),
|
||||
on_click: Some(Arc::new(|this, window, cx| {
|
||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||
this.dismiss_message(&DismissMessage, window, cx)
|
||||
})),
|
||||
tooltip_message: None,
|
||||
}),
|
||||
@@ -668,19 +664,20 @@ impl ActivityIndicator {
|
||||
),
|
||||
message: "Downloading Zed update…".to_string(),
|
||||
on_click: Some(Arc::new(|this, window, cx| {
|
||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||
this.dismiss_message(&DismissMessage, window, cx)
|
||||
})),
|
||||
tooltip_message: Some(Self::version_tooltip_message(version)),
|
||||
}),
|
||||
AutoUpdateStatus::Installing { version } => Some(Content {
|
||||
icon: Some(
|
||||
Icon::new(IconName::Download)
|
||||
Icon::new(IconName::LoadCircle)
|
||||
.size(IconSize::Small)
|
||||
.with_rotate_animation(3)
|
||||
.into_any_element(),
|
||||
),
|
||||
message: "Installing Zed update…".to_string(),
|
||||
on_click: Some(Arc::new(|this, window, cx| {
|
||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||
this.dismiss_message(&DismissMessage, window, cx)
|
||||
})),
|
||||
tooltip_message: Some(Self::version_tooltip_message(version)),
|
||||
}),
|
||||
@@ -690,17 +687,18 @@ impl ActivityIndicator {
|
||||
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
|
||||
tooltip_message: Some(Self::version_tooltip_message(version)),
|
||||
}),
|
||||
AutoUpdateStatus::Errored => Some(Content {
|
||||
AutoUpdateStatus::Errored { error } => Some(Content {
|
||||
icon: Some(
|
||||
Icon::new(IconName::Warning)
|
||||
.size(IconSize::Small)
|
||||
.into_any_element(),
|
||||
),
|
||||
message: "Auto update failed".to_string(),
|
||||
message: "Failed to update Zed".to_string(),
|
||||
on_click: Some(Arc::new(|this, window, cx| {
|
||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||
window.dispatch_action(Box::new(workspace::OpenLog), cx);
|
||||
this.dismiss_message(&DismissMessage, window, cx);
|
||||
})),
|
||||
tooltip_message: None,
|
||||
tooltip_message: Some(format!("{error}")),
|
||||
}),
|
||||
AutoUpdateStatus::Idle => None,
|
||||
})
|
||||
@@ -738,7 +736,7 @@ impl ActivityIndicator {
|
||||
})),
|
||||
message,
|
||||
on_click: Some(Arc::new(|this, window, cx| {
|
||||
this.dismiss_error_message(&Default::default(), window, cx)
|
||||
this.dismiss_message(&Default::default(), window, cx)
|
||||
})),
|
||||
tooltip_message: None,
|
||||
})
|
||||
@@ -777,7 +775,7 @@ impl Render for ActivityIndicator {
|
||||
let result = h_flex()
|
||||
.id("activity-indicator")
|
||||
.on_action(cx.listener(Self::show_error_message))
|
||||
.on_action(cx.listener(Self::dismiss_error_message));
|
||||
.on_action(cx.listener(Self::dismiss_message));
|
||||
let Some(content) = self.content_to_render(cx) else {
|
||||
return result;
|
||||
};
|
||||
|
||||
@@ -30,6 +30,7 @@ fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
gpui_tokio = { workspace = true, optional = true }
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
|
||||
@@ -7,15 +7,19 @@ mod gemini;
|
||||
pub mod e2e_tests;
|
||||
|
||||
pub use claude::*;
|
||||
use client::ProxySettings;
|
||||
use collections::HashMap;
|
||||
pub use custom::*;
|
||||
use fs::Fs;
|
||||
pub use gemini::*;
|
||||
use http_client::read_no_proxy_from_env;
|
||||
use project::agent_server_store::AgentServerStore;
|
||||
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use std::{any::Any, path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
pub use acp::AcpConnection;
|
||||
@@ -77,3 +81,25 @@ impl dyn AgentServer {
|
||||
self.into_any().downcast().ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the default proxy environment variables to pass through to the agent
|
||||
pub fn load_proxy_env(cx: &mut App) -> HashMap<String, String> {
|
||||
let proxy_url = cx
|
||||
.read_global(|settings: &SettingsStore, _| settings.get::<ProxySettings>(None).proxy_url());
|
||||
let mut env = HashMap::default();
|
||||
|
||||
if let Some(proxy_url) = &proxy_url {
|
||||
let env_var = if proxy_url.scheme() == "https" {
|
||||
"HTTPS_PROXY"
|
||||
} else {
|
||||
"HTTP_PROXY"
|
||||
};
|
||||
env.insert(env_var.to_owned(), proxy_url.to_string());
|
||||
}
|
||||
|
||||
if let Some(no_proxy) = read_no_proxy_from_env() {
|
||||
env.insert("NO_PROXY".to_owned(), no_proxy);
|
||||
}
|
||||
|
||||
env
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use anyhow::{Context as _, Result};
|
||||
use gpui::{App, AppContext as _, SharedString, Task};
|
||||
use project::agent_server_store::{AllAgentServersSettings, CLAUDE_CODE_NAME};
|
||||
|
||||
use crate::{AgentServer, AgentServerDelegate};
|
||||
use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -60,6 +60,7 @@ impl AgentServer for ClaudeCode {
|
||||
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string());
|
||||
let is_remote = delegate.project.read(cx).is_via_remote_server();
|
||||
let store = delegate.store.downgrade();
|
||||
let extra_env = load_proxy_env(cx);
|
||||
let default_mode = self.default_mode(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
@@ -70,7 +71,7 @@ impl AgentServer for ClaudeCode {
|
||||
.context("Claude Code is not registered")?;
|
||||
anyhow::Ok(agent.get_command(
|
||||
root_dir.as_deref(),
|
||||
Default::default(),
|
||||
extra_env,
|
||||
delegate.status_tx,
|
||||
delegate.new_version_available,
|
||||
&mut cx.to_async(),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::AgentServerDelegate;
|
||||
use crate::{AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
@@ -65,6 +65,7 @@ impl crate::AgentServer for CustomAgentServer {
|
||||
let is_remote = delegate.project.read(cx).is_via_remote_server();
|
||||
let default_mode = self.default_mode(cx);
|
||||
let store = delegate.store.downgrade();
|
||||
let extra_env = load_proxy_env(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (command, root_dir, login) = store
|
||||
@@ -76,7 +77,7 @@ impl crate::AgentServer for CustomAgentServer {
|
||||
})?;
|
||||
anyhow::Ok(agent.get_command(
|
||||
root_dir.as_deref(),
|
||||
Default::default(),
|
||||
extra_env,
|
||||
delegate.status_tx,
|
||||
delegate.new_version_available,
|
||||
&mut cx.to_async(),
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
use std::rc::Rc;
|
||||
use std::{any::Any, path::Path};
|
||||
|
||||
use crate::{AgentServer, AgentServerDelegate};
|
||||
use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::ProxySettings;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AppContext, SharedString, Task};
|
||||
use gpui::{App, SharedString, Task};
|
||||
use language_models::provider::google::GoogleLanguageModelProvider;
|
||||
use project::agent_server_store::GEMINI_NAME;
|
||||
use settings::SettingsStore;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Gemini;
|
||||
@@ -37,17 +34,20 @@ impl AgentServer for Gemini {
|
||||
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string());
|
||||
let is_remote = delegate.project.read(cx).is_via_remote_server();
|
||||
let store = delegate.store.downgrade();
|
||||
let proxy_url = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<ProxySettings>(None).proxy.clone()
|
||||
});
|
||||
let mut extra_env = load_proxy_env(cx);
|
||||
let default_mode = self.default_mode(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let mut extra_env = HashMap::default();
|
||||
if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() {
|
||||
extra_env.insert("GEMINI_API_KEY".into(), api_key.key);
|
||||
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
|
||||
|
||||
if let Some(api_key) = cx
|
||||
.update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
|
||||
.await
|
||||
.ok()
|
||||
{
|
||||
extra_env.insert("GEMINI_API_KEY".into(), api_key);
|
||||
}
|
||||
let (mut command, root_dir, login) = store
|
||||
let (command, root_dir, login) = store
|
||||
.update(cx, |store, cx| {
|
||||
let agent = store
|
||||
.get_external_agent(&GEMINI_NAME.into())
|
||||
@@ -62,14 +62,6 @@ impl AgentServer for Gemini {
|
||||
})??
|
||||
.await?;
|
||||
|
||||
// Add proxy flag if proxy settings are configured in Zed and not in the args
|
||||
if let Some(proxy_url_value) = &proxy_url
|
||||
&& !command.args.iter().any(|arg| arg.contains("--proxy"))
|
||||
{
|
||||
command.args.push("--proxy".into());
|
||||
command.args.push(proxy_url_value.clone());
|
||||
}
|
||||
|
||||
let connection = crate::acp::connect(
|
||||
name,
|
||||
command,
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
use agent_client_protocol as acp;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::AgentServerCommand;
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, SharedString};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsKey, SettingsSources, SettingsUi};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
AllAgentServersSettings::register(cx);
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, SettingsUi, SettingsKey)]
|
||||
#[settings_key(key = "agent_servers")]
|
||||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<BuiltinAgentServerSettings>,
|
||||
pub claude: Option<BuiltinAgentServerSettings>,
|
||||
|
||||
/// Custom agent servers configured by the user
|
||||
#[serde(flatten)]
|
||||
pub custom: HashMap<SharedString, CustomAgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)]
|
||||
pub struct BuiltinAgentServerSettings {
|
||||
/// Absolute path to a binary to be used when launching this agent.
|
||||
///
|
||||
/// This can be used to run a specific binary without automatic downloads or searching `$PATH`.
|
||||
#[serde(rename = "command")]
|
||||
pub path: Option<PathBuf>,
|
||||
/// If a binary is specified in `command`, it will be passed these arguments.
|
||||
pub args: Option<Vec<String>>,
|
||||
/// If a binary is specified in `command`, it will be passed these environment variables.
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
/// Whether to skip searching `$PATH` for an agent server binary when
|
||||
/// launching this agent.
|
||||
///
|
||||
/// This has no effect if a `command` is specified. Otherwise, when this is
|
||||
/// `false`, Zed will search `$PATH` for an agent server binary and, if one
|
||||
/// is found, use it for threads with this agent. If no agent binary is
|
||||
/// found on `$PATH`, Zed will automatically install and use its own binary.
|
||||
/// When this is `true`, Zed will not search `$PATH`, and will always use
|
||||
/// its own binary.
|
||||
///
|
||||
/// Default: true
|
||||
pub ignore_system_version: Option<bool>,
|
||||
/// The default mode for new threads.
|
||||
///
|
||||
/// Note: Not all agents support modes.
|
||||
///
|
||||
/// Default: None
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default_mode: Option<acp::SessionModeId>,
|
||||
}
|
||||
|
||||
impl BuiltinAgentServerSettings {
|
||||
pub(crate) fn custom_command(self) -> Option<AgentServerCommand> {
|
||||
self.path.map(|path| AgentServerCommand {
|
||||
path,
|
||||
args: self.args.unwrap_or_default(),
|
||||
env: self.env,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AgentServerCommand> for BuiltinAgentServerSettings {
|
||||
fn from(value: AgentServerCommand) -> Self {
|
||||
BuiltinAgentServerSettings {
|
||||
path: Some(value.path),
|
||||
args: Some(value.args),
|
||||
env: value.env,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)]
|
||||
pub struct CustomAgentServerSettings {
|
||||
#[serde(flatten)]
|
||||
pub command: AgentServerCommand,
|
||||
/// The default mode for new threads.
|
||||
///
|
||||
/// Note: Not all agents support modes.
|
||||
///
|
||||
/// Default: None
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default_mode: Option<acp::SessionModeId>,
|
||||
}
|
||||
|
||||
impl settings::Settings for AllAgentServersSettings {
|
||||
type FileContent = Self;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for AllAgentServersSettings {
|
||||
gemini,
|
||||
claude,
|
||||
custom,
|
||||
} in sources.defaults_and_customizations()
|
||||
{
|
||||
if gemini.is_some() {
|
||||
settings.gemini = gemini.clone();
|
||||
}
|
||||
if claude.is_some() {
|
||||
settings.claude = claude.clone();
|
||||
}
|
||||
|
||||
// Merge custom agents
|
||||
for (name, config) in custom {
|
||||
// Skip built-in agent names to avoid conflicts
|
||||
if name != "gemini" && name != "claude" {
|
||||
settings.custom.insert(name.clone(), config.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
@@ -75,6 +75,7 @@ pub struct AgentSettings {
|
||||
pub expand_edit_card: bool,
|
||||
pub expand_terminal_card: bool,
|
||||
pub use_modifier_to_send: bool,
|
||||
pub message_editor_min_lines: usize,
|
||||
}
|
||||
|
||||
impl AgentSettings {
|
||||
@@ -107,6 +108,10 @@ impl AgentSettings {
|
||||
model,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn set_message_editor_max_lines(&self) -> usize {
|
||||
self.message_editor_min_lines * 2
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
@@ -320,6 +325,10 @@ pub struct AgentSettingsContent {
|
||||
///
|
||||
/// Default: false
|
||||
use_modifier_to_send: Option<bool>,
|
||||
/// Minimum number of lines of height the agent message editor should have.
|
||||
///
|
||||
/// Default: 4
|
||||
message_editor_min_lines: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
|
||||
@@ -355,21 +364,30 @@ impl JsonSchema for LanguageModelProviderSetting {
|
||||
}
|
||||
|
||||
fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
|
||||
// list the builtin providers as a subset so that we still auto complete them in the settings
|
||||
json_schema!({
|
||||
"enum": [
|
||||
"amazon-bedrock",
|
||||
"anthropic",
|
||||
"copilot_chat",
|
||||
"deepseek",
|
||||
"google",
|
||||
"lmstudio",
|
||||
"mistral",
|
||||
"ollama",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"vercel",
|
||||
"x_ai",
|
||||
"zed.dev"
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"amazon-bedrock",
|
||||
"anthropic",
|
||||
"copilot_chat",
|
||||
"deepseek",
|
||||
"google",
|
||||
"lmstudio",
|
||||
"mistral",
|
||||
"ollama",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"vercel",
|
||||
"x_ai",
|
||||
"zed.dev"
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
@@ -472,6 +490,10 @@ impl Settings for AgentSettings {
|
||||
&mut settings.use_modifier_to_send,
|
||||
value.use_modifier_to_send,
|
||||
);
|
||||
merge(
|
||||
&mut settings.message_editor_min_lines,
|
||||
value.message_editor_min_lines,
|
||||
);
|
||||
|
||||
settings
|
||||
.model_parameters
|
||||
|
||||
@@ -80,7 +80,6 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
shlex.workspace = true
|
||||
smol.workspace = true
|
||||
streaming_diff.workspace = true
|
||||
task.workspace = true
|
||||
|
||||
@@ -1099,11 +1099,16 @@ impl MessageEditor {
|
||||
}
|
||||
|
||||
pub fn insert_selections(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let buffer = self.editor.read(cx).buffer().clone();
|
||||
let Some(buffer) = buffer.read(cx).as_singleton() else {
|
||||
let editor = self.editor.read(cx);
|
||||
let editor_buffer = editor.buffer().read(cx);
|
||||
let Some(buffer) = editor_buffer.as_singleton() else {
|
||||
return;
|
||||
};
|
||||
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
|
||||
let cursor_anchor = editor.selections.newest_anchor().head();
|
||||
let cursor_offset = cursor_anchor.to_offset(&editor_buffer.snapshot(cx));
|
||||
let anchor = buffer.update(cx, |buffer, _cx| {
|
||||
buffer.anchor_before(cursor_offset.min(buffer.len()))
|
||||
});
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
@@ -1117,13 +1122,7 @@ impl MessageEditor {
|
||||
return;
|
||||
};
|
||||
self.editor.update(cx, |message_editor, cx| {
|
||||
message_editor.edit(
|
||||
[(
|
||||
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
|
||||
completion.new_text,
|
||||
)],
|
||||
cx,
|
||||
);
|
||||
message_editor.edit([(cursor_anchor..cursor_anchor, completion.new_text)], cx);
|
||||
});
|
||||
if let Some(confirm) = completion.confirm {
|
||||
confirm(CompletionIntent::Complete, window, cx);
|
||||
|
||||
@@ -107,13 +107,15 @@ impl ModeSelector {
|
||||
.text_sm()
|
||||
.text_color(Color::Muted.color(cx))
|
||||
.child("Hold")
|
||||
.child(div().pt_0p5().children(ui::render_modifiers(
|
||||
&gpui::Modifiers::secondary_key(),
|
||||
PlatformStyle::platform(),
|
||||
None,
|
||||
Some(ui::TextSize::Default.rems(cx).into()),
|
||||
true,
|
||||
)))
|
||||
.child(h_flex().flex_shrink_0().children(
|
||||
ui::render_modifiers(
|
||||
&gpui::Modifiers::secondary_key(),
|
||||
PlatformStyle::platform(),
|
||||
None,
|
||||
Some(ui::TextSize::Default.rems(cx).into()),
|
||||
true,
|
||||
),
|
||||
))
|
||||
.child(div().map(|this| {
|
||||
if is_default {
|
||||
this.child("to also unset as default")
|
||||
|
||||
@@ -500,20 +500,24 @@ impl Render for AcpThreadHistory {
|
||||
),
|
||||
)
|
||||
} else {
|
||||
view.pr_5()
|
||||
.child(
|
||||
uniform_list(
|
||||
"thread-history",
|
||||
self.visible_items.len(),
|
||||
cx.processor(|this, range: Range<usize>, window, cx| {
|
||||
this.render_list_items(range, window, cx)
|
||||
}),
|
||||
)
|
||||
.p_1()
|
||||
.track_scroll(self.scroll_handle.clone())
|
||||
.flex_grow(),
|
||||
view.child(
|
||||
uniform_list(
|
||||
"thread-history",
|
||||
self.visible_items.len(),
|
||||
cx.processor(|this, range: Range<usize>, window, cx| {
|
||||
this.render_list_items(range, window, cx)
|
||||
}),
|
||||
)
|
||||
.vertical_scrollbar_for(self.scroll_handle.clone(), window, cx)
|
||||
.p_1()
|
||||
.pr_4()
|
||||
.track_scroll(self.scroll_handle.clone())
|
||||
.flex_grow(),
|
||||
)
|
||||
.vertical_scrollbar_for(
|
||||
self.scroll_handle.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use agent_client_protocol::{self as acp, PromptCapabilities};
|
||||
use agent_servers::{AgentServer, AgentServerDelegate};
|
||||
use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
|
||||
use agent2::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer};
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use arrayvec::ArrayVec;
|
||||
use audio::{Audio, Sound};
|
||||
use buffer_diff::BufferDiff;
|
||||
@@ -71,9 +71,6 @@ use crate::{
|
||||
RejectOnce, ToggleBurnMode, ToggleProfileSelector,
|
||||
};
|
||||
|
||||
pub const MIN_EDITOR_LINES: usize = 4;
|
||||
pub const MAX_EDITOR_LINES: usize = 8;
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
enum ThreadFeedback {
|
||||
Positive,
|
||||
@@ -357,8 +354,8 @@ impl AcpThreadView {
|
||||
agent.name(),
|
||||
&placeholder,
|
||||
editor::EditorMode::AutoHeight {
|
||||
min_lines: MIN_EDITOR_LINES,
|
||||
max_lines: Some(MAX_EDITOR_LINES),
|
||||
min_lines: AgentSettings::get_global(cx).message_editor_min_lines,
|
||||
max_lines: Some(AgentSettings::get_global(cx).set_message_editor_max_lines()),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
@@ -857,10 +854,11 @@ impl AcpThreadView {
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
let agent_settings = AgentSettings::get_global(cx);
|
||||
editor.set_mode(
|
||||
EditorMode::AutoHeight {
|
||||
min_lines: MIN_EDITOR_LINES,
|
||||
max_lines: Some(MAX_EDITOR_LINES),
|
||||
min_lines: agent_settings.message_editor_min_lines,
|
||||
max_lines: Some(agent_settings.set_message_editor_max_lines()),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
@@ -1584,19 +1582,6 @@ impl AcpThreadView {
|
||||
|
||||
window.spawn(cx, async move |cx| {
|
||||
let mut task = login.clone();
|
||||
task.command = task
|
||||
.command
|
||||
.map(|command| anyhow::Ok(shlex::try_quote(&command)?.to_string()))
|
||||
.transpose()?;
|
||||
task.args = task
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
Ok(shlex::try_quote(arg)
|
||||
.context("Failed to quote argument")?
|
||||
.to_string())
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
task.full_label = task.label.clone();
|
||||
task.id = task::TaskId(format!("external-agent-{}-login", task.label));
|
||||
task.command_label = task.label.clone();
|
||||
@@ -3197,10 +3182,14 @@ impl AcpThreadView {
|
||||
};
|
||||
|
||||
Button::new(SharedString::from(method_id.clone()), name)
|
||||
.when(ix == 0, |el| {
|
||||
el.style(ButtonStyle::Tinted(ui::TintColor::Warning))
|
||||
})
|
||||
.label_size(LabelSize::Small)
|
||||
.map(|this| {
|
||||
if ix == 0 {
|
||||
this.style(ButtonStyle::Tinted(TintColor::Warning))
|
||||
} else {
|
||||
this.style(ButtonStyle::Outlined)
|
||||
}
|
||||
})
|
||||
.on_click({
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
telemetry::event!(
|
||||
@@ -5680,6 +5669,23 @@ pub(crate) mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_spawn_external_agent_login_handles_spaces(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Verify paths with spaces aren't pre-quoted
|
||||
let path_with_spaces = "/Users/test/Library/Application Support/Zed/cli.js";
|
||||
let login_task = task::SpawnInTerminal {
|
||||
command: Some("node".to_string()),
|
||||
args: vec![path_with_spaces.to_string(), "/login".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Args should be passed as-is, not pre-quoted
|
||||
assert!(!login_task.args[0].starts_with('"'));
|
||||
assert!(!login_task.args[0].starts_with('\''));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
@@ -274,13 +274,28 @@ impl AgentConfiguration {
|
||||
*is_expanded = !*is_expanded;
|
||||
}
|
||||
})),
|
||||
)
|
||||
.when(provider.is_authenticated(cx), |parent| {
|
||||
),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
.px_2()
|
||||
.gap_1()
|
||||
.when(is_expanded, |parent| match configuration_view {
|
||||
Some(configuration_view) => parent.child(configuration_view),
|
||||
None => parent.child(Label::new(format!(
|
||||
"No configuration view for {provider_name}",
|
||||
))),
|
||||
})
|
||||
.when(is_expanded && provider.is_authenticated(cx), |parent| {
|
||||
parent.child(
|
||||
Button::new(
|
||||
SharedString::from(format!("new-thread-{provider_id}")),
|
||||
"Start New Thread",
|
||||
)
|
||||
.full_width()
|
||||
.style(ButtonStyle::Filled)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon(IconName::Thread)
|
||||
.icon_size(IconSize::Small)
|
||||
@@ -297,17 +312,6 @@ impl AgentConfiguration {
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.w_full()
|
||||
.px_2()
|
||||
.when(is_expanded, |parent| match configuration_view {
|
||||
Some(configuration_view) => parent.child(configuration_view),
|
||||
None => parent.child(Label::new(format!(
|
||||
"No configuration view for {provider_name}",
|
||||
))),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_provider_configuration_section(
|
||||
@@ -561,11 +565,28 @@ impl AgentConfiguration {
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.children(
|
||||
context_server_ids.into_iter().map(|context_server_id| {
|
||||
self.render_context_server(context_server_id, window, cx)
|
||||
}),
|
||||
)
|
||||
.map(|parent| {
|
||||
if context_server_ids.is_empty() {
|
||||
parent.child(
|
||||
h_flex()
|
||||
.p_4()
|
||||
.justify_center()
|
||||
.border_1()
|
||||
.border_dashed()
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
.rounded_sm()
|
||||
.child(
|
||||
Label::new("No MCP servers added yet.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
parent.children(context_server_ids.into_iter().map(|context_server_id| {
|
||||
self.render_context_server(context_server_id, window, cx)
|
||||
}))
|
||||
}
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
@@ -818,6 +839,8 @@ impl AgentConfiguration {
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.flex_1()
|
||||
.min_w_0()
|
||||
.child(
|
||||
Disclosure::new(
|
||||
"tool-list-disclosure",
|
||||
@@ -841,17 +864,19 @@ impl AgentConfiguration {
|
||||
.id(SharedString::from(format!("tooltip-{}", item_id)))
|
||||
.h_full()
|
||||
.w_3()
|
||||
.mx_1()
|
||||
.ml_1()
|
||||
.mr_1p5()
|
||||
.justify_center()
|
||||
.tooltip(Tooltip::text(tooltip_text))
|
||||
.child(status_indicator),
|
||||
)
|
||||
.child(Label::new(item_id).ml_0p5())
|
||||
.child(Label::new(item_id).truncate())
|
||||
.child(
|
||||
div()
|
||||
.id("extension-source")
|
||||
.mt_0p5()
|
||||
.mx_1()
|
||||
.flex_none()
|
||||
.tooltip(Tooltip::text(source_tooltip))
|
||||
.child(
|
||||
Icon::new(source_icon)
|
||||
@@ -873,7 +898,8 @@ impl AgentConfiguration {
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.gap_0p5()
|
||||
.flex_none()
|
||||
.child(context_server_configuration_menu)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
@@ -1123,6 +1149,7 @@ impl AgentConfiguration {
|
||||
SharedString::from(format!("start_acp_thread-{name}")),
|
||||
"Start New Thread",
|
||||
)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Thread)
|
||||
.icon_position(IconPosition::Start)
|
||||
|
||||
@@ -63,7 +63,6 @@ ui.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
web_search.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
assistant_tool::init(cx);
|
||||
|
||||
let registry = ToolRegistry::global(cx);
|
||||
registry.register_tool(TerminalTool::new(cx));
|
||||
registry.register_tool(TerminalTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(DeletePathTool);
|
||||
|
||||
@@ -160,7 +160,7 @@ mod tests {
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
// This output is marlformed, so we're doing our best effort
|
||||
// This output is malformed, so we're doing our best effort
|
||||
"Hello world\n```\n\nThe end\n".to_string()
|
||||
);
|
||||
}
|
||||
@@ -182,7 +182,7 @@ mod tests {
|
||||
&mut parser,
|
||||
&mut rng
|
||||
),
|
||||
// This output is marlformed, so we're doing our best effort
|
||||
// This output is malformed, so we're doing our best effort
|
||||
"```\nHello world\n```\n".to_string()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -916,7 +916,7 @@ impl Loader {
|
||||
if !found_non_static {
|
||||
found_non_static = true;
|
||||
eprintln!(
|
||||
"Warning: Found non-static non-tree-sitter functions in the external scannner"
|
||||
"Warning: Found non-static non-tree-sitter functions in the external scanner"
|
||||
);
|
||||
}
|
||||
eprintln!(" `{function_name}`");
|
||||
|
||||
@@ -6,7 +6,7 @@ use action_log::ActionLog;
|
||||
use agent_settings;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
use futures::FutureExt as _;
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, TextStyleRefinement,
|
||||
WeakEntity, Window,
|
||||
@@ -26,11 +26,12 @@ use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use task::{Shell, ShellBuilder};
|
||||
use terminal_view::TerminalView;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{CommonAnimationExt, Disclosure, Tooltip, prelude::*};
|
||||
use util::{
|
||||
ResultExt, get_system_shell, markdown::MarkdownInlineCode, size::format_file_size,
|
||||
ResultExt, get_default_system_shell, markdown::MarkdownInlineCode, size::format_file_size,
|
||||
time::duration_alt_display,
|
||||
};
|
||||
use workspace::Workspace;
|
||||
@@ -45,29 +46,10 @@ pub struct TerminalToolInput {
|
||||
cd: String,
|
||||
}
|
||||
|
||||
pub struct TerminalTool {
|
||||
determine_shell: Shared<Task<String>>,
|
||||
}
|
||||
pub struct TerminalTool;
|
||||
|
||||
impl TerminalTool {
|
||||
pub const NAME: &str = "terminal";
|
||||
|
||||
pub(crate) fn new(cx: &mut App) -> Self {
|
||||
let determine_shell = cx.background_spawn(async move {
|
||||
if cfg!(windows) {
|
||||
return get_system_shell();
|
||||
}
|
||||
|
||||
if which::which("bash").is_ok() {
|
||||
"bash".into()
|
||||
} else {
|
||||
get_system_shell()
|
||||
}
|
||||
});
|
||||
Self {
|
||||
determine_shell: determine_shell.shared(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Tool for TerminalTool {
|
||||
@@ -135,19 +117,6 @@ impl Tool for TerminalTool {
|
||||
Ok(dir) => dir,
|
||||
Err(err) => return Task::ready(Err(err)).into(),
|
||||
};
|
||||
let program = self.determine_shell.clone();
|
||||
let command = if cfg!(windows) {
|
||||
format!("$null | & {{{}}}", input.command.replace("\"", "'"))
|
||||
} else if let Some(cwd) = working_dir
|
||||
.as_ref()
|
||||
.and_then(|cwd| cwd.as_os_str().to_str())
|
||||
{
|
||||
// Make sure once we're *inside* the shell, we cd into `cwd`
|
||||
format!("(cd {cwd}; {}) </dev/null", input.command)
|
||||
} else {
|
||||
format!("({}) </dev/null", input.command)
|
||||
};
|
||||
let args = vec!["-c".into(), command];
|
||||
|
||||
let cwd = working_dir.clone();
|
||||
let env = match &working_dir {
|
||||
@@ -156,6 +125,11 @@ impl Tool for TerminalTool {
|
||||
}),
|
||||
None => Task::ready(None).shared(),
|
||||
};
|
||||
let remote_shell = project.update(cx, |project, cx| {
|
||||
project
|
||||
.remote_client()
|
||||
.and_then(|r| r.read(cx).default_system_shell())
|
||||
});
|
||||
|
||||
let env = cx.spawn(async move |_| {
|
||||
let mut env = env.await.unwrap_or_default();
|
||||
@@ -171,8 +145,13 @@ impl Tool for TerminalTool {
|
||||
let task = cx.background_spawn(async move {
|
||||
let env = env.await;
|
||||
let pty_system = native_pty_system();
|
||||
let program = program.await;
|
||||
let mut cmd = CommandBuilder::new(program);
|
||||
let (command, args) = ShellBuilder::new(
|
||||
remote_shell.as_deref(),
|
||||
&Shell::Program(get_default_system_shell()),
|
||||
)
|
||||
.redirect_stdin_to_dev_null()
|
||||
.build(Some(input.command.clone()), &[]);
|
||||
let mut cmd = CommandBuilder::new(command);
|
||||
cmd.args(args);
|
||||
for (k, v) in env {
|
||||
cmd.env(k, v);
|
||||
@@ -208,16 +187,22 @@ impl Tool for TerminalTool {
|
||||
};
|
||||
};
|
||||
|
||||
let command = input.command.clone();
|
||||
let terminal = cx.spawn({
|
||||
let project = project.downgrade();
|
||||
async move |cx| {
|
||||
let program = program.await;
|
||||
let (command, args) = ShellBuilder::new(
|
||||
remote_shell.as_deref(),
|
||||
&Shell::Program(get_default_system_shell()),
|
||||
)
|
||||
.redirect_stdin_to_dev_null()
|
||||
.build(Some(input.command), &[]);
|
||||
let env = env.await;
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_terminal_task(
|
||||
task::SpawnInTerminal {
|
||||
command: Some(program),
|
||||
command: Some(command),
|
||||
args,
|
||||
cwd,
|
||||
env,
|
||||
@@ -230,14 +215,8 @@ impl Tool for TerminalTool {
|
||||
}
|
||||
});
|
||||
|
||||
let command_markdown = cx.new(|cx| {
|
||||
Markdown::new(
|
||||
format!("```bash\n{}\n```", input.command).into(),
|
||||
None,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let command_markdown =
|
||||
cx.new(|cx| Markdown::new(format!("```bash\n{}\n```", command).into(), None, None, cx));
|
||||
|
||||
let card = cx.new(|cx| {
|
||||
TerminalToolCard::new(
|
||||
@@ -288,7 +267,7 @@ impl Tool for TerminalTool {
|
||||
let previous_len = content.len();
|
||||
let (processed_content, finished_with_empty_output) = process_content(
|
||||
&content,
|
||||
&input.command,
|
||||
&command,
|
||||
exit_status.map(portable_pty::ExitStatus::from),
|
||||
);
|
||||
|
||||
@@ -740,7 +719,6 @@ mod tests {
|
||||
if cfg!(windows) {
|
||||
return;
|
||||
}
|
||||
|
||||
init_test(&executor, cx);
|
||||
|
||||
let fs = Arc::new(RealFs::new(None, executor));
|
||||
@@ -763,7 +741,7 @@ mod tests {
|
||||
};
|
||||
let result = cx.update(|cx| {
|
||||
TerminalTool::run(
|
||||
Arc::new(TerminalTool::new(cx)),
|
||||
Arc::new(TerminalTool),
|
||||
serde_json::to_value(input).unwrap(),
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
@@ -783,7 +761,6 @@ mod tests {
|
||||
if cfg!(windows) {
|
||||
return;
|
||||
}
|
||||
|
||||
init_test(&executor, cx);
|
||||
|
||||
let fs = Arc::new(RealFs::new(None, executor));
|
||||
@@ -798,7 +775,7 @@ mod tests {
|
||||
|
||||
let check = |input, expected, cx: &mut App| {
|
||||
let headless_result = TerminalTool::run(
|
||||
Arc::new(TerminalTool::new(cx)),
|
||||
Arc::new(TerminalTool),
|
||||
serde_json::to_value(input).unwrap(),
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
|
||||
@@ -211,7 +211,7 @@ impl Audio {
|
||||
agc_source.set_enabled(LIVE_SETTINGS.control_input_volume.load(Ordering::Relaxed));
|
||||
})
|
||||
.replayable(REPLAY_DURATION)
|
||||
.expect("REPLAY_DURATION is longer then 100ms");
|
||||
.expect("REPLAY_DURATION is longer than 100ms");
|
||||
|
||||
cx.update_default_global(|this: &mut Self, _cx| {
|
||||
let output_mixer = this
|
||||
|
||||
@@ -57,7 +57,7 @@ impl<S: Source> RodioExt for S {
|
||||
/// replay is being read
|
||||
///
|
||||
/// # Errors
|
||||
/// If duration is smaller then 100ms
|
||||
/// If duration is smaller than 100ms
|
||||
fn replayable(
|
||||
self,
|
||||
duration: Duration,
|
||||
@@ -151,7 +151,7 @@ impl<S: Source> Source for TakeSamples<S> {
|
||||
struct ReplayQueue {
|
||||
inner: ArrayQueue<Vec<Sample>>,
|
||||
normal_chunk_len: usize,
|
||||
/// The last chunk in the queue may be smaller then
|
||||
/// The last chunk in the queue may be smaller than
|
||||
/// the normal chunk size. This is always equal to the
|
||||
/// size of the last element in the queue.
|
||||
/// (so normally chunk_size)
|
||||
@@ -535,7 +535,7 @@ mod tests {
|
||||
|
||||
let (mut replay, mut source) = input
|
||||
.replayable(Duration::from_secs(3))
|
||||
.expect("longer then 100ms");
|
||||
.expect("longer than 100ms");
|
||||
|
||||
source.by_ref().take(3).count();
|
||||
let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
|
||||
@@ -552,7 +552,7 @@ mod tests {
|
||||
|
||||
let (mut replay, mut source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer then 100ms");
|
||||
.expect("longer than 100ms");
|
||||
|
||||
source.by_ref().take(5).count(); // get all items but do not end the source
|
||||
let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
|
||||
@@ -567,7 +567,7 @@ mod tests {
|
||||
|
||||
let (replay, mut source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer then 100ms");
|
||||
.expect("longer than 100ms");
|
||||
|
||||
// exhaust but do not yet end source
|
||||
source.by_ref().take(40_000).count();
|
||||
@@ -586,7 +586,7 @@ mod tests {
|
||||
let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
|
||||
let (mut replay, source) = input
|
||||
.replayable(Duration::from_secs(2))
|
||||
.expect("longer then 100ms");
|
||||
.expect("longer than 100ms");
|
||||
assert_eq!(replay.by_ref().samples_ready(), 0);
|
||||
|
||||
source.take(8000).count(); // half a second
|
||||
|
||||
@@ -32,3 +32,6 @@ workspace-hack.workspace = true
|
||||
|
||||
[target.'cfg(not(target_os = "windows"))'.dependencies]
|
||||
which.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
|
||||
@@ -34,7 +34,7 @@ actions!(
|
||||
/// Checks for available updates.
|
||||
Check,
|
||||
/// Dismisses the update error message.
|
||||
DismissErrorMessage,
|
||||
DismissMessage,
|
||||
/// Opens the release notes for the current version in a browser.
|
||||
ViewReleaseNotes,
|
||||
]
|
||||
@@ -55,14 +55,14 @@ pub enum VersionCheckType {
|
||||
Semantic(SemanticVersion),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
#[derive(Clone)]
|
||||
pub enum AutoUpdateStatus {
|
||||
Idle,
|
||||
Checking,
|
||||
Downloading { version: VersionCheckType },
|
||||
Installing { version: VersionCheckType },
|
||||
Updated { version: VersionCheckType },
|
||||
Errored,
|
||||
Errored { error: Arc<anyhow::Error> },
|
||||
}
|
||||
|
||||
impl AutoUpdateStatus {
|
||||
@@ -383,7 +383,9 @@ impl AutoUpdater {
|
||||
}
|
||||
UpdateCheckType::Manual => {
|
||||
log::error!("auto-update failed: error:{:?}", error);
|
||||
AutoUpdateStatus::Errored
|
||||
AutoUpdateStatus::Errored {
|
||||
error: Arc::new(error),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -402,8 +404,8 @@ impl AutoUpdater {
|
||||
self.status.clone()
|
||||
}
|
||||
|
||||
pub fn dismiss_error(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
if self.status == AutoUpdateStatus::Idle {
|
||||
pub fn dismiss(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
if let AutoUpdateStatus::Idle = self.status {
|
||||
return false;
|
||||
}
|
||||
self.status = AutoUpdateStatus::Idle;
|
||||
@@ -992,8 +994,27 @@ pub fn finalize_auto_update_on_quit() {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::TestAppContext;
|
||||
use settings::default_settings;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_auto_update_defaults_to_true(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let mut store = SettingsStore::new(cx);
|
||||
store
|
||||
.set_default_settings(&default_settings(), cx)
|
||||
.expect("Unable to set default settings");
|
||||
store
|
||||
.set_user_settings("{}", cx)
|
||||
.expect("Unable to set user settings");
|
||||
cx.set_global(store);
|
||||
AutoUpdateSetting::register(cx);
|
||||
assert!(AutoUpdateSetting::get_global(cx).0);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stable_does_not_update_when_fetched_version_is_not_higher() {
|
||||
let release_channel = ReleaseChannel::Stable;
|
||||
|
||||
@@ -22,7 +22,7 @@ use futures::{
|
||||
channel::oneshot, future::BoxFuture,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{HttpClient, HttpClientWithUrl, http};
|
||||
use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
|
||||
use parking_lot::RwLock;
|
||||
use postage::watch;
|
||||
use proxy::connect_proxy_stream;
|
||||
@@ -132,6 +132,20 @@ pub struct ProxySettings {
|
||||
pub proxy: Option<String>,
|
||||
}
|
||||
|
||||
impl ProxySettings {
|
||||
pub fn proxy_url(&self) -> Option<Url> {
|
||||
self.proxy
|
||||
.as_ref()
|
||||
.and_then(|input| {
|
||||
input
|
||||
.parse::<Url>()
|
||||
.inspect_err(|e| log::error!("Error parsing proxy settings: {}", e))
|
||||
.ok()
|
||||
})
|
||||
.or_else(read_proxy_from_env)
|
||||
}
|
||||
}
|
||||
|
||||
impl Settings for ProxySettings {
|
||||
type FileContent = ProxySettingsContent;
|
||||
|
||||
|
||||
@@ -754,6 +754,10 @@ impl UserStore {
|
||||
}
|
||||
|
||||
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
|
||||
if self.plan().is_some_and(|plan| plan.is_v2()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.model_request_usage
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
|
||||
/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
|
||||
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
|
||||
|
||||
/// The name of the header used to indicate the the minimum required Zed version.
|
||||
/// The name of the header used to indicate the minimum required Zed version.
|
||||
///
|
||||
/// This can be used to force a Zed upgrade in order to continue communicating
|
||||
/// with the LLM service.
|
||||
@@ -321,8 +321,8 @@ pub struct LanguageModel {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ListModelsResponse {
|
||||
pub models: Vec<LanguageModel>,
|
||||
pub default_model: LanguageModelId,
|
||||
pub default_fast_model: LanguageModelId,
|
||||
pub default_model: Option<LanguageModelId>,
|
||||
pub default_fast_model: Option<LanguageModelId>,
|
||||
pub recommended_models: Vec<LanguageModelId>,
|
||||
}
|
||||
|
||||
|
||||
@@ -226,12 +226,6 @@ spec:
|
||||
secretKeyRef:
|
||||
name: supermaven
|
||||
key: api_key
|
||||
- name: USER_BACKFILLER_GITHUB_ACCESS_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: user-backfiller
|
||||
key: github_access_token
|
||||
optional: true
|
||||
- name: INVITE_LINK_PREFIX
|
||||
value: ${INVITE_LINK_PREFIX}
|
||||
- name: RUST_BACKTRACE
|
||||
|
||||
@@ -7,7 +7,6 @@ pub mod llm;
|
||||
pub mod migrations;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod user_backfiller;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
@@ -157,7 +156,6 @@ pub struct Config {
|
||||
pub slack_panics_webhook: Option<String>,
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -211,7 +209,6 @@ impl Config {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
kinesis_access_key: None,
|
||||
kinesis_secret_key: None,
|
||||
|
||||
@@ -11,7 +11,6 @@ use collab::ServiceMode;
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::user_backfiller::spawn_user_backfiller;
|
||||
use collab::{
|
||||
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
|
||||
executor::Executor, rpc::ResultExt,
|
||||
@@ -114,7 +113,6 @@ async fn main() -> Result<()> {
|
||||
|
||||
if mode.is_api() {
|
||||
fetch_extensions_from_blob_store_periodically(state.clone());
|
||||
spawn_user_backfiller(state.clone());
|
||||
|
||||
app = app
|
||||
.merge(collab::api::events::router())
|
||||
|
||||
@@ -604,7 +604,6 @@ impl TestServer {
|
||||
migrations_path: None,
|
||||
seed_path: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
kinesis_stream: None,
|
||||
kinesis_access_key: None,
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::db::Database;
|
||||
use crate::executor::Executor;
|
||||
use crate::{AppState, Config};
|
||||
|
||||
pub fn spawn_user_backfiller(app_state: Arc<AppState>) {
|
||||
let Some(user_backfiller_github_access_token) =
|
||||
app_state.config.user_backfiller_github_access_token.clone()
|
||||
else {
|
||||
log::info!("no USER_BACKFILLER_GITHUB_ACCESS_TOKEN set; not spawning user backfiller");
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app_state.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
let user_backfiller = UserBackfiller::new(
|
||||
app_state.config.clone(),
|
||||
user_backfiller_github_access_token,
|
||||
app_state.db.clone(),
|
||||
executor,
|
||||
);
|
||||
|
||||
log::info!("backfilling users");
|
||||
|
||||
user_backfiller
|
||||
.backfill_github_user_created_at()
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const GITHUB_REQUESTS_PER_HOUR_LIMIT: usize = 5_000;
|
||||
const SLEEP_DURATION_BETWEEN_USERS: std::time::Duration = std::time::Duration::from_millis(
|
||||
(GITHUB_REQUESTS_PER_HOUR_LIMIT as f64 / 60. / 60. * 1000.) as u64,
|
||||
);
|
||||
|
||||
struct UserBackfiller {
|
||||
config: Config,
|
||||
github_access_token: Arc<str>,
|
||||
db: Arc<Database>,
|
||||
http_client: reqwest::Client,
|
||||
executor: Executor,
|
||||
}
|
||||
|
||||
impl UserBackfiller {
|
||||
fn new(
|
||||
config: Config,
|
||||
github_access_token: Arc<str>,
|
||||
db: Arc<Database>,
|
||||
executor: Executor,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
github_access_token,
|
||||
db,
|
||||
http_client: reqwest::Client::new(),
|
||||
executor,
|
||||
}
|
||||
}
|
||||
|
||||
async fn backfill_github_user_created_at(&self) -> Result<()> {
|
||||
let initial_channel_id = self.config.auto_join_channel_id;
|
||||
|
||||
let users_missing_github_user_created_at =
|
||||
self.db.get_users_missing_github_user_created_at().await?;
|
||||
|
||||
for user in users_missing_github_user_created_at {
|
||||
match self
|
||||
.fetch_github_user(&format!(
|
||||
"https://api.github.com/user/{}",
|
||||
user.github_user_id
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(github_user) => {
|
||||
self.db
|
||||
.update_or_create_user_by_github_account(
|
||||
&user.github_login,
|
||||
github_user.id,
|
||||
user.email_address.as_deref(),
|
||||
user.name.as_deref(),
|
||||
github_user.created_at,
|
||||
initial_channel_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
log::info!("backfilled user: {}", user.github_login);
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("failed to fetch GitHub user {}: {err}", user.github_login);
|
||||
}
|
||||
}
|
||||
|
||||
self.executor.sleep(SLEEP_DURATION_BETWEEN_USERS).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn fetch_github_user(&self, url: &str) -> Result<GithubUser> {
|
||||
let response = self
|
||||
.http_client
|
||||
.get(url)
|
||||
.header(
|
||||
"authorization",
|
||||
format!("Bearer {}", self.github_access_token),
|
||||
)
|
||||
.header("user-agent", "zed")
|
||||
.send()
|
||||
.await
|
||||
.with_context(|| format!("failed to fetch '{url}'"))?;
|
||||
|
||||
let rate_limit_remaining = response
|
||||
.headers()
|
||||
.get("x-ratelimit-remaining")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|value| value.parse::<i32>().ok());
|
||||
let rate_limit_reset = response
|
||||
.headers()
|
||||
.get("x-ratelimit-reset")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|value| value.parse::<i64>().ok())
|
||||
.and_then(|value| DateTime::from_timestamp(value, 0));
|
||||
|
||||
if rate_limit_remaining == Some(0)
|
||||
&& let Some(reset_at) = rate_limit_reset
|
||||
{
|
||||
let now = Utc::now();
|
||||
if reset_at > now {
|
||||
let sleep_duration = reset_at - now;
|
||||
log::info!(
|
||||
"rate limit reached. Sleeping for {} seconds",
|
||||
sleep_duration.num_seconds()
|
||||
);
|
||||
self.executor.sleep(sleep_duration.to_std().unwrap()).await;
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
.error_for_status()
|
||||
.context("fetching GitHub user")?
|
||||
.json()
|
||||
.await
|
||||
.with_context(|| format!("failed to deserialize GitHub user from '{url}'"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct GithubUser {
|
||||
id: i32,
|
||||
created_at: DateTime<Utc>,
|
||||
#[expect(
|
||||
unused,
|
||||
reason = "This field was found to be unused with serde library bump; it's left as is due to insufficient context on PO's side, but it *may* be fine to remove"
|
||||
)]
|
||||
name: Option<String>,
|
||||
}
|
||||
21
crates/denoise/Cargo.toml
Normal file
21
crates/denoise/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "denoise"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
candle-core = { version = "0.9.1", git ="https://github.com/zed-industries/candle", branch = "9.1-patched" }
|
||||
candle-onnx = { version = "0.9.1", git ="https://github.com/zed-industries/candle", branch = "9.1-patched" }
|
||||
log.workspace = true
|
||||
|
||||
rodio = { workspace = true, features = ["wav_output"] }
|
||||
|
||||
rustfft = { version = "6.2.0", features = ["avx"] }
|
||||
realfft = "3.4.0"
|
||||
thiserror.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
1
crates/denoise/LICENSE-GPL
Symbolic link
1
crates/denoise/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
LICENSE-GPL
|
||||
20
crates/denoise/README.md
Normal file
20
crates/denoise/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
Real time streaming audio denoising using a [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551).
|
||||
|
||||
Trivial to build as it uses the native rust Candle crate for inference. Easy to integrate into any Rodio pipeline.
|
||||
|
||||
```rust
|
||||
# use rodio::{nz, source::UniformSourceIterator, wav_to_file};
|
||||
let file = std::fs::File::open("clips_airconditioning.wav")?;
|
||||
let decoder = rodio::Decoder::try_from(file)?;
|
||||
let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000));
|
||||
|
||||
let mut denoised = denoise::Denoiser::try_new(resampled)?;
|
||||
wav_to_file(&mut denoised, "denoised.wav")?;
|
||||
Result::Ok<(), Box<dyn std::error::Error>>
|
||||
```
|
||||
|
||||
## Acknowledgements & License
|
||||
|
||||
The trained models in this repo are optimized versions of the models in the [breizhn/DTLN](https://github.com/breizhn/DTLN?tab=readme-ov-file#model-conversion-and-real-time-processing-with-onnx). These are licensed under MIT.
|
||||
|
||||
The FFT code was adapted from Datadog's [dtln-rs Repo](https://github.com/DataDog/dtln-rs/tree/main) also licensed under MIT.
|
||||
11
crates/denoise/examples/denoise.rs
Normal file
11
crates/denoise/examples/denoise.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
use rodio::{nz, source::UniformSourceIterator, wav_to_file};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let file = std::fs::File::open("airconditioning.wav")?;
|
||||
let decoder = rodio::Decoder::try_from(file)?;
|
||||
let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000));
|
||||
|
||||
let mut denoised = denoise::Denoiser::try_new(resampled)?;
|
||||
wav_to_file(&mut denoised, "denoised.wav")?;
|
||||
Ok(())
|
||||
}
|
||||
23
crates/denoise/examples/enable_disable.rs
Normal file
23
crates/denoise/examples/enable_disable.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use rodio::Source;
|
||||
use rodio::wav_to_file;
|
||||
use rodio::{nz, source::UniformSourceIterator};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let file = std::fs::File::open("clips_airconditioning.wav")?;
|
||||
let decoder = rodio::Decoder::try_from(file)?;
|
||||
let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000));
|
||||
|
||||
let mut enabled = true;
|
||||
let denoised = denoise::Denoiser::try_new(resampled)?.periodic_access(
|
||||
Duration::from_secs(2),
|
||||
|denoised| {
|
||||
enabled = !enabled;
|
||||
denoised.set_enabled(enabled);
|
||||
},
|
||||
);
|
||||
|
||||
wav_to_file(denoised, "processed.wav")?;
|
||||
Ok(())
|
||||
}
|
||||
BIN
crates/denoise/models/model_1_converted_simplified.onnx
Normal file
BIN
crates/denoise/models/model_1_converted_simplified.onnx
Normal file
Binary file not shown.
BIN
crates/denoise/models/model_2_converted_simplified.onnx
Normal file
BIN
crates/denoise/models/model_2_converted_simplified.onnx
Normal file
Binary file not shown.
204
crates/denoise/src/engine.rs
Normal file
204
crates/denoise/src/engine.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
/// use something like https://netron.app/ to inspect the models and understand
|
||||
/// the flow
|
||||
use std::collections::HashMap;
|
||||
|
||||
use candle_core::{Device, IndexOp, Tensor};
|
||||
use candle_onnx::onnx::ModelProto;
|
||||
use candle_onnx::prost::Message;
|
||||
use realfft::RealFftPlanner;
|
||||
use rustfft::num_complex::Complex;
|
||||
|
||||
/// DTLN denoise engine: two ONNX models run in series, with FFT/IFFT glue in
/// between, processing overlapping blocks of `BLOCK_LEN` samples advanced by
/// `BLOCK_SHIFT` samples per `feed` call.
pub struct Engine {
    /// First DTLN model: estimates a magnitude mask from the spectrum.
    spectral_model: ModelProto,
    /// Second DTLN model: post-processes the reconstructed time signal.
    signal_model: ModelProto,

    /// Planner reused for the forward FFT and inverse FFT each block.
    fft_planner: RealFftPlanner<f32>,
    /// Scratch space sized for the planned FFT (see `get_scratch_len`).
    fft_scratch: Vec<Complex<f32>>,
    /// Output of the real-to-complex FFT of `in_buffer`.
    spectrum: [Complex<f32>; FFT_OUT_SIZE],
    /// Output of the complex-to-real IFFT (masked spectrum back to time).
    signal: [f32; BLOCK_LEN],

    /// Magnitudes of `spectrum`, fed to the spectral model.
    in_magnitude: [f32; FFT_OUT_SIZE],
    /// Phases of `spectrum`, reused when reconstructing the masked spectrum.
    in_phase: [f32; FFT_OUT_SIZE],

    /// Recurrent (LSTM) state carried between calls for the spectral model.
    spectral_memory: Tensor,
    /// Recurrent (LSTM) state carried between calls for the signal model.
    signal_memory: Tensor,

    /// Sliding window of input samples; newest `BLOCK_SHIFT` at the end.
    in_buffer: [f32; BLOCK_LEN],
    /// Overlap-add accumulator; fully-denoised samples exit at the front.
    out_buffer: [f32; BLOCK_LEN],
}

// 32 ms @ 16khz per DTLN docs: https://github.com/breizhn/DTLN
pub const BLOCK_LEN: usize = 512;
// 8 ms @ 16khz per DTLN docs.
pub const BLOCK_SHIFT: usize = 128;
// A real FFT of N samples yields N/2 + 1 complex bins.
pub const FFT_OUT_SIZE: usize = BLOCK_LEN / 2 + 1;
|
||||
|
||||
impl Engine {
|
||||
pub fn new() -> Self {
|
||||
let mut fft_planner = RealFftPlanner::new();
|
||||
let fft_planned = fft_planner.plan_fft_forward(BLOCK_LEN);
|
||||
let scratch_len = fft_planned.get_scratch_len();
|
||||
Self {
|
||||
// Models are 1.5MB and 2.5MB respectively. Its worth the binary
|
||||
// size increase not to have to distribute the models separately.
|
||||
spectral_model: ModelProto::decode(
|
||||
include_bytes!("../models/model_1_converted_simplified.onnx").as_slice(),
|
||||
)
|
||||
.expect("The model should decode"),
|
||||
signal_model: ModelProto::decode(
|
||||
include_bytes!("../models/model_2_converted_simplified.onnx").as_slice(),
|
||||
)
|
||||
.expect("The model should decode"),
|
||||
fft_planner,
|
||||
fft_scratch: vec![Complex::ZERO; scratch_len],
|
||||
spectrum: [Complex::ZERO; FFT_OUT_SIZE],
|
||||
signal: [0f32; BLOCK_LEN],
|
||||
|
||||
in_magnitude: [0f32; FFT_OUT_SIZE],
|
||||
in_phase: [0f32; FFT_OUT_SIZE],
|
||||
|
||||
spectral_memory: Tensor::from_slice::<_, f32>(
|
||||
&[0f32; 512],
|
||||
(1, 2, BLOCK_SHIFT, 2),
|
||||
&Device::Cpu,
|
||||
)
|
||||
.expect("Tensor has the correct dimensions"),
|
||||
signal_memory: Tensor::from_slice::<_, f32>(
|
||||
&[0f32; 512],
|
||||
(1, 2, BLOCK_SHIFT, 2),
|
||||
&Device::Cpu,
|
||||
)
|
||||
.expect("Tensor has the correct dimensions"),
|
||||
out_buffer: [0f32; BLOCK_LEN],
|
||||
in_buffer: [0f32; BLOCK_LEN],
|
||||
}
|
||||
}
|
||||
|
||||
    /// Add a chunk of samples and get the denoised chunk 4 feeds later.
    ///
    /// `samples` must contain exactly `BLOCK_SHIFT` samples. Because the
    /// engine works on overlapping `BLOCK_LEN` windows, the returned
    /// sub-block corresponds to input fed `BLOCK_LEN / BLOCK_SHIFT` (= 4)
    /// calls earlier.
    pub fn feed(&mut self, samples: &[f32]) -> [f32; BLOCK_SHIFT] {
        /// The name of the output node of the onnx network
        /// [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551).
        const MEMORY_OUTPUT: &'static str = "Identity_1";

        debug_assert_eq!(samples.len(), BLOCK_SHIFT);

        // place new samples at the end of the `in_buffer`
        // (shift the window left by BLOCK_SHIFT, append the new sub-block)
        self.in_buffer.copy_within(BLOCK_SHIFT.., 0);
        self.in_buffer[(BLOCK_LEN - BLOCK_SHIFT)..].copy_from_slice(&samples);

        // run inference: spectral model first, its recurrent state is carried
        // over to the next call via `spectral_memory`
        let inputs = self.spectral_inputs();
        let mut spectral_outputs = candle_onnx::simple_eval(&self.spectral_model, inputs)
            .expect("The embedded file must be valid");
        self.spectral_memory = spectral_outputs
            .remove(MEMORY_OUTPUT)
            .expect("The model has an output named Identity_1");
        // then the signal model, consuming the spectral model's mask output
        let inputs = self.signal_inputs(spectral_outputs);
        let mut signal_outputs = candle_onnx::simple_eval(&self.signal_model, inputs)
            .expect("The embedded file must be valid");
        self.signal_memory = signal_outputs
            .remove(MEMORY_OUTPUT)
            .expect("The model has an output named Identity_1");
        let model_output = model_outputs(signal_outputs);

        // place processed samples at the start of the `out_buffer`
        // shift the rest left, fill the end with zeros. Zeros are needed as
        // the out buffer is part of the input of the network
        self.out_buffer.copy_within(BLOCK_SHIFT.., 0);
        self.out_buffer[BLOCK_LEN - BLOCK_SHIFT..].fill(0f32);
        // overlap-add: accumulate the model's BLOCK_LEN output onto the buffer
        for (a, b) in self.out_buffer.iter_mut().zip(model_output) {
            *a += b;
        }

        // samples at the front of the `out_buffer` are now denoised
        // (they have received contributions from all 4 overlapping windows)
        self.out_buffer[..BLOCK_SHIFT]
            .try_into()
            .expect("len is correct")
    }
|
||||
|
||||
    /// Runs the forward FFT over `in_buffer`, caches magnitude and phase, and
    /// assembles the input map for the spectral (first) DTLN model.
    fn spectral_inputs(&mut self) -> HashMap<String, Tensor> {
        // Prepare FFT input (planner caches the plan, so this is cheap)
        let fft = self.fft_planner.plan_fft_forward(BLOCK_LEN);

        // Perform real-to-complex FFT
        // (process_with_scratch mutates its input, so work on a copy)
        let mut fft_in = self.in_buffer;
        fft.process_with_scratch(&mut fft_in, &mut self.spectrum, &mut self.fft_scratch)
            .expect("The fft should run, there is enough scratch space");

        // Generate magnitude and phase; phase is kept so `signal_inputs` can
        // reconstruct the complex spectrum after masking
        for ((magnitude, phase), complex) in self
            .in_magnitude
            .iter_mut()
            .zip(self.in_phase.iter_mut())
            .zip(self.spectrum)
        {
            *magnitude = complex.norm();
            *phase = complex.arg();
        }

        const SPECTRUM_INPUT: &str = "input_2";
        const MEMORY_INPUT: &str = "input_3";
        // NOTE(review): the magnitude tensor (shape (1, 1, FFT_OUT_SIZE)) is
        // bound below to MEMORY_INPUT ("input_3") while the recurrent state
        // (shape (1, 2, BLOCK_SHIFT, 2)) is bound to SPECTRUM_INPUT
        // ("input_2"), and this local is named `memory_input` even though it
        // holds the magnitudes. Either the constant names or the bindings
        // appear swapped — confirm against the ONNX model's declared input
        // names/shapes (e.g. via netron, see module docs).
        let memory_input =
            Tensor::from_slice::<_, f32>(&self.in_magnitude, (1, 1, FFT_OUT_SIZE), &Device::Cpu)
                .expect("the in magnitude has enough elements to fill the Tensor");

        let inputs = HashMap::from([
            (MEMORY_INPUT.to_string(), memory_input),
            (SPECTRUM_INPUT.to_string(), self.spectral_memory.clone()),
        ]);
        inputs
    }
|
||||
|
||||
    /// Applies the spectral model's magnitude mask, reconstructs the time
    /// signal via IFFT, and assembles the input map for the signal (second)
    /// DTLN model.
    fn signal_inputs(&mut self, outputs: HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        // The spectral model's non-memory output is a per-bin magnitude mask
        let magnitude_weight = model_outputs(outputs);

        // Apply mask and reconstruct complex spectrum from the cached
        // magnitude/phase of the forward FFT.
        // (The Complex::I initializer is fully overwritten by the loop below.)
        let mut spectrum = [Complex::I; FFT_OUT_SIZE];
        for i in 0..FFT_OUT_SIZE {
            let magnitude = self.in_magnitude[i] * magnitude_weight[i];
            let phase = self.in_phase[i];
            let real = magnitude * phase.cos();
            let imag = magnitude * phase.sin();
            spectrum[i] = Complex::new(real, imag);
        }

        // Handle DC component (i = 0): a real FFT requires it to be purely
        // real, so overwrite the value set by the loop above
        let magnitude = self.in_magnitude[0] * magnitude_weight[0];
        spectrum[0] = Complex::new(magnitude, 0.0);

        // Handle Nyquist component (i = N/2): same purely-real requirement
        let magnitude = self.in_magnitude[FFT_OUT_SIZE - 1] * magnitude_weight[FFT_OUT_SIZE - 1];
        spectrum[FFT_OUT_SIZE - 1] = Complex::new(magnitude, 0.0);

        // Perform complex-to-real IFFT
        let ifft = self.fft_planner.plan_fft_inverse(BLOCK_LEN);
        ifft.process_with_scratch(&mut spectrum, &mut self.signal, &mut self.fft_scratch)
            .expect("The fft should run, there is enough scratch space");

        // Normalize the IFFT output (realfft's inverse is unnormalized,
        // scaled by N)
        for real in &mut self.signal {
            *real /= BLOCK_LEN as f32;
        }

        const SIGNAL_INPUT: &str = "input_4";
        const SIGNAL_MEMORY: &str = "input_5";
        let signal_input =
            Tensor::from_slice::<_, f32>(&self.signal, (1, 1, BLOCK_LEN), &Device::Cpu).unwrap();

        HashMap::from([
            (SIGNAL_INPUT.to_string(), signal_input),
            (SIGNAL_MEMORY.to_string(), self.signal_memory.clone()),
        ])
    }
|
||||
}
|
||||
|
||||
// Both models put their outputs in the same location
|
||||
fn model_outputs(mut outputs: HashMap<String, Tensor>) -> Vec<f32> {
|
||||
const NON_MEMORY_OUTPUT: &str = "Identity";
|
||||
outputs
|
||||
.remove(NON_MEMORY_OUTPUT)
|
||||
.expect("The model has this output")
|
||||
.i((0, 0))
|
||||
.and_then(|tensor| tensor.to_vec1())
|
||||
.expect("The tensor has the correct dimensions")
|
||||
}
|
||||
270
crates/denoise/src/lib.rs
Normal file
270
crates/denoise/src/lib.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
mod engine;
|
||||
|
||||
use core::fmt;
|
||||
use std::{collections::VecDeque, sync::mpsc, thread};
|
||||
|
||||
pub use engine::Engine;
|
||||
use rodio::{ChannelCount, Sample, SampleRate, Source, nz};
|
||||
|
||||
use crate::engine::BLOCK_SHIFT;
|
||||
|
||||
const SUPPORTED_SAMPLE_RATE: SampleRate = nz!(16_000);
|
||||
const SUPPORTED_CHANNEL_COUNT: ChannelCount = nz!(1);
|
||||
|
||||
/// A rodio `Source` adaptor that denoises a mono 16 kHz stream using the DTLN
/// engine running on a background thread.
pub struct Denoiser<S: Source> {
    /// The wrapped upstream source.
    inner: S,
    /// Sends raw sub-blocks of `BLOCK_SHIFT` samples to the engine thread.
    input_tx: mpsc::Sender<[Sample; BLOCK_SHIFT]>,
    /// Receives denoised sub-blocks back from the engine thread.
    denoised_rx: mpsc::Receiver<[Sample; BLOCK_SHIFT]>,
    /// The sub-block currently being handed out sample-by-sample.
    ready: [Sample; BLOCK_SHIFT],
    /// Index of the last sample returned from `ready`.
    next: usize,
    /// Where we are in the enable/disable/startup state machine.
    state: IterState,
    // When disabled instead of reading denoised sub-blocks from the engine through
    // `denoised_rx` we read unprocessed from this queue. This maintains the same
    // latency so we can 'trivially' re-enable
    queued: Queue,
}
|
||||
|
||||
impl<S: Source> fmt::Debug for Denoiser<S> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Denoiser")
|
||||
.field("state", &self.state)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
struct Queue(VecDeque<[Sample; BLOCK_SHIFT]>);
|
||||
|
||||
impl Queue {
|
||||
fn new() -> Self {
|
||||
Self(VecDeque::new())
|
||||
}
|
||||
fn push(&mut self, block: [Sample; BLOCK_SHIFT]) {
|
||||
self.0.push_back(block);
|
||||
self.0.resize(4, [0f32; BLOCK_SHIFT]);
|
||||
}
|
||||
fn pop(&mut self) -> [Sample; BLOCK_SHIFT] {
|
||||
debug_assert!(self.0.len() == 4);
|
||||
self.0.pop_front().expect(
|
||||
"There is no State where the queue is popped while there are less then 4 entries",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// State machine driving `Denoiser::prepare_next_ready`.
#[derive(Debug, Clone, Copy)]
pub enum IterState {
    /// Steady state: output comes from the denoise engine.
    Enabled,
    /// Re-enabling after `Disabled`: blocks are fed to the engine again while
    /// output still comes from the raw queue until the engine has caught up.
    StartingMidAudio { fed_to_denoiser: usize },
    /// Output comes from the raw (undenoised) queue; the engine is idle.
    Disabled,
    /// First call: prime the engine's 4-block pipeline before emitting audio.
    Startup { enabled: bool },
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum DenoiserError {
|
||||
#[error("This denoiser only works on sources with samplerate 16000")]
|
||||
UnsupportedSampleRate,
|
||||
#[error("This denoiser only works on mono sources (1 channel)")]
|
||||
UnsupportedChannelCount,
|
||||
}
|
||||
|
||||
// todo dvdsk needs constant source upstream in rodio
impl<S: Source> Denoiser<S> {
    /// Wraps `source`, spawning a background thread that runs the DTLN
    /// engine.
    ///
    /// # Errors
    /// Returns an error unless the source is mono at exactly 16 kHz.
    pub fn try_new(source: S) -> Result<Self, DenoiserError> {
        if source.sample_rate() != SUPPORTED_SAMPLE_RATE {
            return Err(DenoiserError::UnsupportedSampleRate);
        }
        if source.channels() != SUPPORTED_CHANNEL_COUNT {
            return Err(DenoiserError::UnsupportedChannelCount);
        }

        let (input_tx, input_rx) = mpsc::channel();
        let (denoised_tx, denoised_rx) = mpsc::channel();

        // Inference runs off the audio thread; the channels carry fixed-size
        // sub-blocks in each direction.
        thread::spawn(move || {
            run_neural_denoiser(denoised_tx, input_rx);
        });

        Ok(Self {
            inner: source,
            input_tx,
            denoised_rx,
            ready: [0.0; BLOCK_SHIFT],
            state: IterState::Startup { enabled: true },
            // Start past the end of `ready` so the first `next()` call
            // triggers `prepare_next_ready`.
            next: BLOCK_SHIFT,
            queued: Queue::new(),
        })
    }

    /// Enables or disables denoising. Disabling keeps the same output latency
    /// (via the raw-block queue) so re-enabling does not skip audio.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.state = match (enabled, self.state) {
            (false, IterState::StartingMidAudio { .. }) | (false, IterState::Enabled) => {
                IterState::Disabled
            }
            (false, IterState::Startup { enabled: true }) => IterState::Startup { enabled: false },
            (true, IterState::Disabled) => IterState::StartingMidAudio { fed_to_denoiser: 0 },
            // All other combinations are no-ops (already in the target state).
            (_, state) => state,
        };
    }

    /// Sends one raw sub-block to the engine thread.
    ///
    /// NOTE(review): this panics if the engine thread died, while the sends
    /// in `prepare_next_ready` map the same failure to `DenoiseEngineCrashed`
    /// — consider unifying the error handling.
    fn feed(&self, sub_block: [f32; BLOCK_SHIFT]) {
        self.input_tx.send(sub_block).unwrap();
    }
}
|
||||
|
||||
fn run_neural_denoiser(
|
||||
denoised_tx: mpsc::Sender<[f32; BLOCK_SHIFT]>,
|
||||
input_rx: mpsc::Receiver<[f32; BLOCK_SHIFT]>,
|
||||
) {
|
||||
let mut engine = Engine::new();
|
||||
loop {
|
||||
let Ok(sub_block) = input_rx.recv() else {
|
||||
// tx must have dropped, stop thread
|
||||
break;
|
||||
};
|
||||
|
||||
let denoised_sub_block = engine.feed(&sub_block);
|
||||
if denoised_tx.send(denoised_sub_block).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The denoiser does not change span length, channel count, sample rate, or
/// duration, so every query delegates to the wrapped source.
impl<S: Source> Source for Denoiser<S> {
    fn current_span_len(&self) -> Option<usize> {
        self.inner.current_span_len()
    }

    fn channels(&self) -> rodio::ChannelCount {
        self.inner.channels()
    }

    fn sample_rate(&self) -> rodio::SampleRate {
        self.inner.sample_rate()
    }

    fn total_duration(&self) -> Option<std::time::Duration> {
        self.inner.total_duration()
    }
}
|
||||
|
||||
impl<S: Source> Iterator for Denoiser<S> {
    type Item = Sample;

    /// Hands out the current `ready` sub-block one sample at a time; refills
    /// it via `prepare_next_ready` every `BLOCK_SHIFT` samples.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.next += 1;
        // Fast path: still inside the current sub-block.
        if self.next < self.ready.len() {
            let sample = self.ready[self.next];
            return Some(sample);
        }

        // This is a separate function to prevent it from being inlined
        // as this code only runs once every 128 samples.
        // It resets `self.next` to 0 and returns `ready[0]`.
        self.prepare_next_ready()
            .inspect_err(|_| {
                log::error!("Denoise engine crashed");
            })
            .ok()
            .flatten()
    }
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Could not send or receive from denoise thread. It must have crashed")]
|
||||
struct DenoiseEngineCrashed;
|
||||
|
||||
impl<S: Source> Denoiser<S> {
    /// Refills `self.ready` with the next sub-block according to the current
    /// state, reads one more sub-block from the source to keep the pipeline
    /// full, and returns the first sample of the new `ready` block.
    ///
    /// Returns `Ok(None)` when the inner source is exhausted mid-block and
    /// `Err(DenoiseEngineCrashed)` when the engine thread is gone.
    #[cold]
    fn prepare_next_ready(&mut self) -> Result<Option<f32>, DenoiseEngineCrashed> {
        self.state = match self.state {
            IterState::Startup { enabled } => {
                // guaranteed to be coming from silence
                // Prime the engine with 4 sub-blocks (one full BLOCK_LEN
                // window) before taking any output.
                for _ in 0..3 {
                    let Some(sub_block) = read_sub_block(&mut self.inner) else {
                        return Ok(None);
                    };
                    self.queued.push(sub_block);
                    self.input_tx
                        .send(sub_block)
                        .map_err(|_| DenoiseEngineCrashed)?;
                }
                let Some(sub_block) = read_sub_block(&mut self.inner) else {
                    return Ok(None);
                };
                self.queued.push(sub_block);
                self.input_tx
                    .send(sub_block)
                    .map_err(|_| DenoiseEngineCrashed)?;
                // throw out old blocks that are denoised silence
                let _ = self.denoised_rx.iter().take(3).count();
                self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;

                // Feed one block ahead so `Enabled` can always recv-then-send.
                let Some(sub_block) = read_sub_block(&mut self.inner) else {
                    return Ok(None);
                };
                self.queued.push(sub_block);
                self.feed(sub_block);

                if enabled {
                    IterState::Enabled
                } else {
                    IterState::Disabled
                }
            }
            IterState::Enabled => {
                // Steady state: one denoised block out, one raw block in.
                self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
                let Some(sub_block) = read_sub_block(&mut self.inner) else {
                    return Ok(None);
                };
                self.queued.push(sub_block);
                self.input_tx
                    .send(sub_block)
                    .map_err(|_| DenoiseEngineCrashed)?;
                IterState::Enabled
            }
            IterState::Disabled => {
                // Need to maintain the same 512 samples delay such that
                // we can re-enable at any point.
                self.ready = self.queued.pop();
                let Some(sub_block) = read_sub_block(&mut self.inner) else {
                    return Ok(None);
                };
                self.queued.push(sub_block);
                IterState::Disabled
            }
            IterState::StartingMidAudio {
                fed_to_denoiser: mut sub_blocks_fed,
            } => {
                // Keep outputting raw queued audio while the engine's window
                // refills with real (non-stale) input.
                self.ready = self.queued.pop();
                let Some(sub_block) = read_sub_block(&mut self.inner) else {
                    return Ok(None);
                };
                self.queued.push(sub_block);
                self.input_tx
                    .send(sub_block)
                    .map_err(|_| DenoiseEngineCrashed)?;
                sub_blocks_fed += 1;
                if sub_blocks_fed > 4 {
                    // throw out partially denoised blocks,
                    // next will be correctly denoised
                    // NOTE(review): 5 blocks have been fed by this point but
                    // only 3 results are discarded, which seems to leave
                    // extra results buffered in `denoised_rx` — verify the
                    // block accounting across the disable/enable transition.
                    let _ = self.denoised_rx.iter().take(3).count();
                    IterState::Enabled
                } else {
                    IterState::StartingMidAudio {
                        fed_to_denoiser: sub_blocks_fed,
                    }
                }
            }
        };

        // Restart sample hand-out at the front of the new block.
        self.next = 0;
        Ok(Some(self.ready[0]))
    }
}
|
||||
|
||||
/// Pulls exactly `BLOCK_SHIFT` samples from `s` into a fixed array.
///
/// Returns `None` as soon as the source runs dry, discarding any
/// partially-filled block.
fn read_sub_block(s: &mut impl Source) -> Option<[f32; BLOCK_SHIFT]> {
    let mut block = [0f32; BLOCK_SHIFT];
    for slot in block.iter_mut() {
        *slot = s.next()?;
    }
    Some(block)
}
|
||||
46
crates/edit_prediction_context/Cargo.toml
Normal file
46
crates/edit_prediction_context/Cargo.toml
Normal file
@@ -0,0 +1,46 @@
|
||||
[package]
|
||||
name = "edit_prediction_context"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/edit_prediction_context.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
arrayvec.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
ordered-float.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
slotmap.workspace = true
|
||||
strum.workspace = true
|
||||
text.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clap.workspace = true
|
||||
futures.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
indoc.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
pretty_assertions.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
serde_json.workspace = true
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
text = { workspace = true, features = ["test-support"] }
|
||||
util = { workspace = true, features = ["test-support"] }
|
||||
zlog.workspace = true
|
||||
1
crates/edit_prediction_context/LICENSE-GPL
Symbolic link
1
crates/edit_prediction_context/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
193
crates/edit_prediction_context/src/declaration.rs
Normal file
193
crates/edit_prediction_context/src/declaration.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use language::LanguageId;
|
||||
use project::ProjectEntryId;
|
||||
use std::borrow::Cow;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
use text::{Bias, BufferId, Rope};
|
||||
|
||||
use crate::outline::OutlineDeclaration;
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
|
||||
pub struct Identifier {
|
||||
pub name: Arc<str>,
|
||||
pub language_id: LanguageId,
|
||||
}
|
||||
|
||||
slotmap::new_key_type! {
|
||||
pub struct DeclarationId;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Declaration {
|
||||
File {
|
||||
project_entry_id: ProjectEntryId,
|
||||
declaration: FileDeclaration,
|
||||
},
|
||||
Buffer {
|
||||
project_entry_id: ProjectEntryId,
|
||||
buffer_id: BufferId,
|
||||
rope: Rope,
|
||||
declaration: BufferDeclaration,
|
||||
},
|
||||
}
|
||||
|
||||
const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
|
||||
|
||||
impl Declaration {
|
||||
pub fn identifier(&self) -> &Identifier {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => &declaration.identifier,
|
||||
Declaration::Buffer { declaration, .. } => &declaration.identifier,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
|
||||
match self {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => Some(*project_entry_id),
|
||||
Declaration::Buffer {
|
||||
project_entry_id, ..
|
||||
} => Some(*project_entry_id),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text.as_ref().into(),
|
||||
declaration.text_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.item_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.item_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
|
||||
match self {
|
||||
Declaration::File { declaration, .. } => (
|
||||
declaration.text[declaration.signature_range_in_text.clone()].into(),
|
||||
declaration.signature_is_truncated,
|
||||
),
|
||||
Declaration::Buffer {
|
||||
rope, declaration, ..
|
||||
} => (
|
||||
rope.chunks_in_range(declaration.signature_range.clone())
|
||||
.collect::<Cow<str>>(),
|
||||
declaration.signature_range_is_truncated,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Expands `range` so it starts at the beginning of its first line and ends
/// at the beginning of the line after its last line, then truncates it to at
/// most `limit` bytes. Returns the adjusted range and whether truncation
/// occurred.
fn expand_range_to_line_boundaries_and_truncate(
    range: &Range<usize>,
    limit: usize,
    rope: &Rope,
) -> (Range<usize>, bool) {
    let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
    point_range.start.column = 0;
    // Move the end to the start of the following line.
    // NOTE(review): when `range.end` is on the last line this row may be past
    // the end of the rope — presumably `point_to_offset` clips; confirm.
    point_range.end.row += 1;
    point_range.end.column = 0;

    let mut item_range =
        rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
    let is_truncated = item_range.len() > limit;
    if is_truncated {
        item_range.end = item_range.start + limit;
    }
    // Ensure the (possibly truncated) end lands on a valid boundary.
    item_range.end = rope.clip_offset(item_range.end, Bias::Left);
    (item_range, is_truncated)
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
/// offset range of the declaration in the file, expanded to line boundaries and truncated
|
||||
pub item_range_in_file: Range<usize>,
|
||||
/// text of `item_range_in_file`
|
||||
pub text: Arc<str>,
|
||||
/// whether `text` was truncated
|
||||
pub text_is_truncated: bool,
|
||||
/// offset range of the signature within `text`
|
||||
pub signature_range_in_text: Range<usize>,
|
||||
/// whether `signature` was truncated
|
||||
pub signature_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl FileDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
|
||||
let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
|
||||
// TODO: consider logging if unexpected
|
||||
let signature_start = declaration
|
||||
.signature_range
|
||||
.start
|
||||
.saturating_sub(item_range_in_file.start);
|
||||
let mut signature_end = declaration
|
||||
.signature_range
|
||||
.end
|
||||
.saturating_sub(item_range_in_file.start);
|
||||
let signature_is_truncated = signature_end > item_range_in_file.len();
|
||||
if signature_is_truncated {
|
||||
signature_end = item_range_in_file.len();
|
||||
}
|
||||
|
||||
FileDeclaration {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
signature_range_in_text: signature_start..signature_end,
|
||||
signature_is_truncated,
|
||||
text: rope
|
||||
.chunks_in_range(item_range_in_file.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
text_is_truncated,
|
||||
item_range_in_file,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BufferDeclaration {
|
||||
pub parent: Option<DeclarationId>,
|
||||
pub identifier: Identifier,
|
||||
pub item_range: Range<usize>,
|
||||
pub item_range_is_truncated: bool,
|
||||
pub signature_range: Range<usize>,
|
||||
pub signature_range_is_truncated: bool,
|
||||
}
|
||||
|
||||
impl BufferDeclaration {
|
||||
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
|
||||
let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.item_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
let (signature_range, signature_range_is_truncated) =
|
||||
expand_range_to_line_boundaries_and_truncate(
|
||||
&declaration.signature_range,
|
||||
ITEM_TEXT_TRUNCATION_LENGTH,
|
||||
rope,
|
||||
);
|
||||
Self {
|
||||
parent: None,
|
||||
identifier: declaration.identifier,
|
||||
item_range,
|
||||
item_range_is_truncated,
|
||||
signature_range,
|
||||
signature_range_is_truncated,
|
||||
}
|
||||
}
|
||||
}
|
||||
324
crates/edit_prediction_context/src/declaration_scoring.rs
Normal file
324
crates/edit_prediction_context/src/declaration_scoring.rs
Normal file
@@ -0,0 +1,324 @@
|
||||
use itertools::Itertools as _;
|
||||
use language::BufferSnapshot;
|
||||
use ordered_float::OrderedFloat;
|
||||
use serde::Serialize;
|
||||
use std::{collections::HashMap, ops::Range};
|
||||
use strum::EnumIter;
|
||||
use text::{OffsetRangeExt, Point, ToPoint};
|
||||
|
||||
use crate::{
|
||||
Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
|
||||
reference::{Reference, ReferenceRegion},
|
||||
syntax_index::SyntaxIndexState,
|
||||
text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
|
||||
};
|
||||
|
||||
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// * Consider adding declaration_file_count
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ScoredSnippet {
|
||||
pub identifier: Identifier,
|
||||
pub declaration: Declaration,
|
||||
pub score_components: ScoreInputs,
|
||||
pub scores: Scores,
|
||||
}
|
||||
|
||||
// TODO: Consider having "Concise" style corresponding to `concise_text`
|
||||
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
pub enum SnippetStyle {
|
||||
Signature,
|
||||
Declaration,
|
||||
}
|
||||
|
||||
impl ScoredSnippet {
    /// Returns the score for this snippet with the specified style.
    pub fn score(&self, style: SnippetStyle) -> f32 {
        match style {
            SnippetStyle::Signature => self.scores.signature,
            SnippetStyle::Declaration => self.scores.declaration,
        }
    }

    /// Returns the snippet's size in bytes for the specified style.
    pub fn size(&self, style: SnippetStyle) -> usize {
        // TODO: how to handle truncation?
        match &self.declaration {
            Declaration::File { declaration, .. } => match style {
                SnippetStyle::Signature => declaration.signature_range_in_text.len(),
                SnippetStyle::Declaration => declaration.text.len(),
            },
            Declaration::Buffer { declaration, .. } => match style {
                SnippetStyle::Signature => declaration.signature_range.len(),
                SnippetStyle::Declaration => declaration.item_range.len(),
            },
        }
    }

    /// Score per byte, used to rank snippets of different lengths fairly.
    ///
    /// NOTE(review): a zero `size` yields inf/NaN here, which would poison
    /// the `OrderedFloat` sort in `scored_snippets` — confirm empty snippets
    /// cannot occur, or guard against division by zero.
    pub fn score_density(&self, style: SnippetStyle) -> f32 {
        self.score(style) / (self.size(style)) as f32
    }
}
|
||||
|
||||
/// Looks up declarations for every referenced identifier, computes per-
/// declaration score inputs relative to the cursor/excerpt, and returns the
/// snippets sorted by their best score density.
///
/// NOTE(review): the final sort is ascending (best snippet last) — confirm
/// callers consume from the end, otherwise the order is inverted.
pub fn scored_snippets(
    index: &SyntaxIndexState,
    excerpt: &EditPredictionExcerpt,
    excerpt_text: &EditPredictionExcerptText,
    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
    cursor_offset: usize,
    current_buffer: &BufferSnapshot,
) -> Vec<ScoredSnippet> {
    // Identifier statistics for the whole excerpt body...
    let containing_range_identifier_occurrences =
        IdentifierOccurrences::within_string(&excerpt_text.body);
    let cursor_point = cursor_offset.to_point(&current_buffer);

    // ...and for the few lines immediately around the cursor (2 above
    // through the end of the cursor's line).
    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
    let end_point = Point::new(cursor_point.row + 1, 0);
    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
        &current_buffer
            .text_for_range(start_point..end_point)
            .collect::<String>(),
    );

    let mut snippets = identifier_to_references
        .into_iter()
        .flat_map(|(identifier, references)| {
            let declarations =
                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
            let declaration_count = declarations.len();

            declarations
                .iter()
                // For each declaration produce (is_same_file, line distance,
                // declaration), skipping same-file declarations that overlap
                // the excerpt (the model already sees that text).
                .filter_map(|declaration| match declaration {
                    Declaration::Buffer {
                        buffer_id,
                        declaration: buffer_declaration,
                        ..
                    } => {
                        let is_same_file = buffer_id == &current_buffer.remote_id();

                        if is_same_file {
                            range_intersection(
                                &buffer_declaration.item_range.to_offset(&current_buffer),
                                &excerpt.range,
                            )
                            .is_none()
                            .then(|| {
                                let declaration_line = buffer_declaration
                                    .item_range
                                    .start
                                    .to_point(current_buffer)
                                    .row;
                                (
                                    true,
                                    (cursor_point.row as i32 - declaration_line as i32).abs()
                                        as u32,
                                    declaration,
                                )
                            })
                        } else {
                            Some((false, 0, declaration))
                        }
                    }
                    Declaration::File { .. } => {
                        // We can assume that a file declaration is in a different file,
                        // because the current one must be open
                        Some((false, 0, declaration))
                    }
                })
                // Rank declarations of this identifier by distance to cursor.
                .sorted_by_key(|&(_, distance, _)| distance)
                .enumerate()
                .map(
                    |(
                        declaration_line_distance_rank,
                        (is_same_file, declaration_line_distance, declaration),
                    )| {
                        let same_file_declaration_count = index.file_declaration_count(declaration);

                        score_snippet(
                            &identifier,
                            &references,
                            declaration.clone(),
                            is_same_file,
                            declaration_line_distance,
                            declaration_line_distance_rank,
                            same_file_declaration_count,
                            declaration_count,
                            &containing_range_identifier_occurrences,
                            &adjacent_identifier_occurrences,
                            cursor_point,
                            current_buffer,
                        )
                    },
                )
                .collect::<Vec<_>>()
        })
        // `score_snippet` returns Option; drop the Nones.
        .flatten()
        .collect::<Vec<_>>();

    // Sort by the better of the two style densities (ascending).
    snippets.sort_unstable_by_key(|snippet| {
        OrderedFloat(
            snippet
                .score_density(SnippetStyle::Declaration)
                .max(snippet.score_density(SnippetStyle::Signature)),
        )
    });

    snippets
}
|
||||
|
||||
/// Returns the overlapping portion of `a` and `b`, or `None` when they do
/// not overlap (an empty overlap counts as no overlap).
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let start = a.start.clone().max(b.start.clone());
    let end = a.end.clone().min(b.end.clone());
    if start >= end {
        None
    } else {
        Some(start..end)
    }
}
|
||||
|
||||
fn score_snippet(
|
||||
identifier: &Identifier,
|
||||
references: &[Reference],
|
||||
declaration: Declaration,
|
||||
is_same_file: bool,
|
||||
declaration_line_distance: u32,
|
||||
declaration_line_distance_rank: usize,
|
||||
same_file_declaration_count: usize,
|
||||
declaration_count: usize,
|
||||
containing_range_identifier_occurrences: &IdentifierOccurrences,
|
||||
adjacent_identifier_occurrences: &IdentifierOccurrences,
|
||||
cursor: Point,
|
||||
current_buffer: &BufferSnapshot,
|
||||
) -> Option<ScoredSnippet> {
|
||||
let is_referenced_nearby = references
|
||||
.iter()
|
||||
.any(|r| r.region == ReferenceRegion::Nearby);
|
||||
let is_referenced_in_breadcrumb = references
|
||||
.iter()
|
||||
.any(|r| r.region == ReferenceRegion::Breadcrumb);
|
||||
let reference_count = references.len();
|
||||
let reference_line_distance = references
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let reference_line = r.range.start.to_point(current_buffer).row as i32;
|
||||
(cursor.row as i32 - reference_line).abs() as u32
|
||||
})
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
|
||||
let item_signature_occurrences =
|
||||
IdentifierOccurrences::within_string(&declaration.signature_text().0);
|
||||
let containing_range_vs_item_jaccard = jaccard_similarity(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_source_occurrences,
|
||||
);
|
||||
let containing_range_vs_signature_jaccard = jaccard_similarity(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_signature_occurrences,
|
||||
);
|
||||
let adjacent_vs_item_jaccard =
|
||||
jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_jaccard =
|
||||
jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
|
||||
|
||||
let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_source_occurrences,
|
||||
);
|
||||
let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
|
||||
containing_range_identifier_occurrences,
|
||||
&item_signature_occurrences,
|
||||
);
|
||||
let adjacent_vs_item_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
|
||||
let adjacent_vs_signature_weighted_overlap =
|
||||
weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
|
||||
|
||||
let score_components = ScoreInputs {
|
||||
is_same_file,
|
||||
is_referenced_nearby,
|
||||
is_referenced_in_breadcrumb,
|
||||
reference_line_distance,
|
||||
declaration_line_distance,
|
||||
declaration_line_distance_rank,
|
||||
reference_count,
|
||||
same_file_declaration_count,
|
||||
declaration_count,
|
||||
containing_range_vs_item_jaccard,
|
||||
containing_range_vs_signature_jaccard,
|
||||
adjacent_vs_item_jaccard,
|
||||
adjacent_vs_signature_jaccard,
|
||||
containing_range_vs_item_weighted_overlap,
|
||||
containing_range_vs_signature_weighted_overlap,
|
||||
adjacent_vs_item_weighted_overlap,
|
||||
adjacent_vs_signature_weighted_overlap,
|
||||
};
|
||||
|
||||
Some(ScoredSnippet {
|
||||
identifier: identifier.clone(),
|
||||
declaration: declaration,
|
||||
scores: score_components.score(),
|
||||
score_components,
|
||||
})
|
||||
}
|
||||
|
||||
/// Raw feature values used to score a snippet, retained so they can be
/// serialized alongside the resulting [`Scores`].
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
    /// Whether the declaration lives in the same file as the cursor.
    pub is_same_file: bool,
    /// Whether any reference to the identifier falls in the `Nearby` region.
    pub is_referenced_nearby: bool,
    /// Whether any reference to the identifier falls in the `Breadcrumb` region.
    pub is_referenced_in_breadcrumb: bool,
    /// Number of references to the identifier found in the excerpt.
    pub reference_count: usize,
    /// Candidate-declaration count within the cursor's file (denominator when
    /// `is_same_file`).
    pub same_file_declaration_count: usize,
    /// Total candidate-declaration count (denominator otherwise).
    pub declaration_count: usize,
    /// Minimum line distance between the cursor and any reference.
    pub reference_line_distance: u32,
    /// Line distance between the cursor and the declaration.
    pub declaration_line_distance: u32,
    /// Rank of this declaration by line distance (not yet used by `score`).
    pub declaration_line_distance_rank: usize,
    // Text-similarity features comparing identifier occurrences around the cursor
    // ("containing range" / "adjacent") against the declaration's full text
    // ("item") or just its signature.
    pub containing_range_vs_item_jaccard: f32,
    pub containing_range_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub containing_range_vs_item_weighted_overlap: f32,
    pub containing_range_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
}
|
||||
|
||||
/// Final snippet scores, one per way the declaration could be included.
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score for including only the declaration's signature.
    pub signature: f32,
    /// Score for including the entire declaration.
    pub declaration: f32,
}
|
||||
|
||||
impl ScoreInputs {
|
||||
fn score(&self) -> Scores {
|
||||
// Score related to how likely this is the correct declaration, range 0 to 1
|
||||
let accuracy_score = if self.is_same_file {
|
||||
// TODO: use declaration_line_distance_rank
|
||||
1.0 / self.same_file_declaration_count as f32
|
||||
} else {
|
||||
1.0 / self.declaration_count as f32
|
||||
};
|
||||
|
||||
// Score related to the distance between the reference and cursor, range 0 to 1
|
||||
let distance_score = if self.is_referenced_nearby {
|
||||
1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
|
||||
} else {
|
||||
// same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
|
||||
0.5
|
||||
};
|
||||
|
||||
// For now instead of linear combination, the scores are just multiplied together.
|
||||
let combined_score = 10.0 * accuracy_score * distance_score;
|
||||
|
||||
Scores {
|
||||
signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
|
||||
// declaration score gets boosted both by being multipled by 2 and by there being more
|
||||
// weighted overlap.
|
||||
declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
|
||||
}
|
||||
}
|
||||
}
|
||||
220
crates/edit_prediction_context/src/edit_prediction_context.rs
Normal file
220
crates/edit_prediction_context/src/edit_prediction_context.rs
Normal file
@@ -0,0 +1,220 @@
|
||||
mod declaration;
|
||||
mod declaration_scoring;
|
||||
mod excerpt;
|
||||
mod outline;
|
||||
mod reference;
|
||||
mod syntax_index;
|
||||
mod text_similarity;
|
||||
|
||||
pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
|
||||
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use language::BufferSnapshot;
|
||||
pub use reference::references_in_excerpt;
|
||||
pub use syntax_index::SyntaxIndex;
|
||||
use text::{Point, ToOffset as _};
|
||||
|
||||
use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
|
||||
|
||||
/// Context gathered for an edit-prediction request: the excerpt selected around
/// the cursor, its extracted text, and the scored declaration snippets.
pub struct EditPredictionContext {
    pub excerpt: EditPredictionExcerpt,
    pub excerpt_text: EditPredictionExcerptText,
    pub snippets: Vec<ScoredSnippet>,
}
|
||||
|
||||
impl EditPredictionContext {
    /// Gathers edit-prediction context on a background task: selects an excerpt
    /// around `cursor_point`, collects identifier references within it, and scores
    /// declaration snippets from the syntax index against those references.
    pub fn gather(
        cursor_point: Point,
        buffer: BufferSnapshot,
        excerpt_options: EditPredictionExcerptOptions,
        syntax_index: Entity<SyntaxIndex>,
        cx: &mut App,
    ) -> Task<Self> {
        // Clone the index-state handle synchronously; its lock is taken on the
        // background thread below.
        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
        cx.background_spawn(async move {
            let index_state = index_state.lock().await;

            // NOTE(review): `select_from_buffer` returns `None` when even the cursor's
            // line exceeds `max_bytes`, so this unwrap can panic — confirm callers
            // guarantee the options make that impossible.
            let excerpt =
                EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
                    .unwrap();
            let excerpt_text = excerpt.text(&buffer);
            let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
            let cursor_offset = cursor_point.to_offset(&buffer);

            let snippets = scored_snippets(
                &index_state,
                &excerpt,
                &excerpt_text,
                references,
                cursor_offset,
                &buffer,
            );

            Self {
                excerpt,
                excerpt_text,
                snippets,
            }
        })
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{Entity, TestAppContext};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    use crate::{EditPredictionExcerptOptions, SyntaxIndex};

    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Let the syntax index pick up the newly opened buffer.
        cx.run_until_parked();

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

        let context = cx
            .update(|cx| {
                EditPredictionContext::gather(
                    cursor_point,
                    buffer_snapshot,
                    EditPredictionExcerptOptions {
                        max_bytes: 40,
                        min_bytes: 10,
                        target_before_cursor_over_total_bytes: 0.5,
                        include_parent_signatures: false,
                    },
                    index,
                    cx,
                )
            })
            .await;

        // The only expected snippet is the `process_data` declaration in c.rs.
        assert_eq!(context.snippets.len(), 1);
        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
        drop(buffer);
    }

    /// Sets up settings, a fake filesystem with three Rust files, a test project,
    /// and a syntax index over it.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Minimal Rust language definition with the highlight and outline queries
    /// the index needs.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}
|
||||
616
crates/edit_prediction_context/src/excerpt.rs
Normal file
616
crates/edit_prediction_context/src/excerpt.rs
Normal file
@@ -0,0 +1,616 @@
|
||||
use language::BufferSnapshot;
|
||||
use std::ops::Range;
|
||||
use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
|
||||
use tree_sitter::{Node, TreeCursor};
|
||||
use util::RangeExt;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// - Test parent signatures
|
||||
//
|
||||
// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
|
||||
// planning.
|
||||
//
|
||||
// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
|
||||
// paragraph).
|
||||
//
|
||||
// - Truncation of long lines.
|
||||
//
|
||||
// - Filter outer syntax layers that don't support edit prediction.
|
||||
|
||||
/// Parameters controlling how large the excerpt around the cursor may be and
/// how it is positioned relative to the cursor.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptOptions {
    /// Limit for the number of bytes in the window around the cursor.
    pub max_bytes: usize,
    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
    /// in an excerpt smaller than this, it will fall back on line-based selection.
    pub min_bytes: usize,
    /// Target ratio of bytes before the cursor divided by total bytes in the window.
    pub target_before_cursor_over_total_bytes: f32,
    /// Whether to include parent signatures
    pub include_parent_signatures: bool,
}
|
||||
|
||||
/// An excerpt chosen around the cursor, identified by byte ranges into the buffer.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
    /// Byte range of the excerpt body.
    pub range: Range<usize>,
    /// Byte ranges of the signatures of enclosing outline items.
    pub parent_signature_ranges: Vec<Range<usize>>,
    /// Total size in bytes: the body plus all parent signature ranges.
    pub size: usize,
}
|
||||
|
||||
/// Text extracted from the buffer for an excerpt: its body and the text of each
/// parent signature range.
#[derive(Clone)]
pub struct EditPredictionExcerptText {
    pub body: String,
    pub parent_signatures: Vec<String>,
}
|
||||
|
||||
impl EditPredictionExcerpt {
    /// Extracts the excerpt's body text and parent signature texts from `buffer`.
    pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
        let body = buffer
            .text_for_range(self.range.clone())
            .collect::<String>();
        let parent_signatures = self
            .parent_signature_ranges
            .iter()
            .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
            .collect();
        EditPredictionExcerptText {
            body,
            parent_signatures,
        }
    }

    /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
    /// on TreeSitter structure and approximately targeting a goal ratio of bytes before vs after
    /// the cursor. When `include_parent_signatures` is true, the excerpt also includes the
    /// signatures of parent outline items.
    ///
    /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
    /// expansion.
    ///
    /// Returns `None` if the line around the cursor doesn't fit.
    pub fn select_from_buffer(
        query_point: Point,
        buffer: &BufferSnapshot,
        options: &EditPredictionExcerptOptions,
    ) -> Option<Self> {
        // Whole buffer fits within the window: no selection needed.
        if buffer.len() <= options.max_bytes {
            log::debug!(
                "using entire file for excerpt since source length ({}) <= window max bytes ({})",
                buffer.len(),
                options.max_bytes
            );
            return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
        }

        // `query_range` spans the full line containing the cursor.
        let query_offset = query_point.to_offset(buffer);
        let query_range = Point::new(query_point.row, 0).to_offset(buffer)
            ..Point::new(query_point.row + 1, 0).to_offset(buffer);
        if query_range.len() >= options.max_bytes {
            // Even the cursor's own line doesn't fit.
            return None;
        }

        // TODO: Don't compute text / annotation_range / skip converting to and from anchors.
        let outline_items = if options.include_parent_signatures {
            buffer
                .outline_items_containing(query_range.clone(), false, None)
                .into_iter()
                .flat_map(|item| {
                    // `flat_map` over `Option` skips items without a signature range.
                    Some(ExcerptOutlineItem {
                        item_range: item.range.to_offset(&buffer),
                        signature_range: item.signature_range?.to_offset(&buffer),
                    })
                })
                .collect()
        } else {
            Vec::new()
        };

        let excerpt_selector = ExcerptSelector {
            query_offset,
            query_range,
            outline_items: &outline_items,
            buffer,
            options,
        };

        if let Some(excerpt_ranges) = excerpt_selector.select_tree_sitter_nodes() {
            if excerpt_ranges.size >= options.min_bytes {
                return Some(excerpt_ranges);
            }
            log::debug!(
                "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
                excerpt_ranges.size,
                options.min_bytes
            );
        } else {
            log::debug!(
                "couldn't find excerpt via tree-sitter, falling back on line-based selection"
            );
        }

        excerpt_selector.select_lines()
    }

    /// Creates an excerpt, computing `size` as the body length plus the length of
    /// every parent signature range.
    fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
        let size = range.len()
            + parent_signature_ranges
                .iter()
                .map(|r| r.len())
                .sum::<usize>();
        Self {
            range,
            parent_signature_ranges,
            size,
        }
    }

    /// Returns a copy of this excerpt with its body grown to `new_range`, keeping
    /// only the parent signatures that remain outside the new body.
    fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
        if !new_range.contains_inclusive(&self.range) {
            // this is an issue because parent_signature_ranges may be incorrect
            log::error!("bug: with_expanded_range called with disjoint range");
        }
        let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
        let mut size = new_range.len();
        for range in &self.parent_signature_ranges {
            // NOTE(review): stops at the first signature whose range contains the new
            // body — presumably that signature (and any that follow) is now covered by
            // the body itself; confirm the containment direction is intended.
            if range.contains_inclusive(&new_range) {
                break;
            }
            parent_signature_ranges.push(range.clone());
            size += range.len();
        }
        Self {
            range: new_range,
            parent_signature_ranges,
            size,
        }
    }

    /// Number of bytes contributed by parent signatures alone.
    fn parent_signatures_size(&self) -> usize {
        self.size - self.range.len()
    }
}
|
||||
|
||||
/// State shared by the two excerpt-selection strategies (AST-based and line-based).
struct ExcerptSelector<'a> {
    /// Byte offset of the cursor.
    query_offset: usize,
    /// Byte range of the full line containing the cursor.
    query_range: Range<usize>,
    /// Outline items enclosing `query_range`; empty unless parent signatures are enabled.
    outline_items: &'a [ExcerptOutlineItem],
    buffer: &'a BufferSnapshot,
    options: &'a EditPredictionExcerptOptions,
}
|
||||
|
||||
/// An outline item's full byte range paired with the byte range of just its signature.
struct ExcerptOutlineItem {
    item_range: Range<usize>,
    signature_range: Range<usize>,
}
|
||||
|
||||
impl<'a> ExcerptSelector<'a> {
|
||||
/// Finds the largest node that is smaller than the window size and contains `query_range`.
|
||||
fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
|
||||
let selected_layer_root = self.select_syntax_layer()?;
|
||||
let mut cursor = selected_layer_root.walk();
|
||||
|
||||
loop {
|
||||
let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
|
||||
..node_line_end(cursor.node()).to_offset(&self.buffer);
|
||||
if excerpt_range.contains_inclusive(&self.query_range) {
|
||||
let excerpt = self.make_excerpt(excerpt_range);
|
||||
if excerpt.size <= self.options.max_bytes {
|
||||
return Some(self.expand_to_siblings(&mut cursor, excerpt));
|
||||
}
|
||||
} else {
|
||||
// TODO: Should still be able to handle this case via AST nodes. For example, this
|
||||
// can happen if the cursor is between two methods in a large class file.
|
||||
return None;
|
||||
}
|
||||
|
||||
if cursor
|
||||
.goto_first_child_for_byte(self.query_range.start)
|
||||
.is_none()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
|
||||
fn select_syntax_layer(&self) -> Option<Node<'_>> {
|
||||
let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
|
||||
let mut largest: Option<Node<'_>> = None;
|
||||
for layer in self
|
||||
.buffer
|
||||
.syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
|
||||
{
|
||||
let layer_range = layer.node().byte_range();
|
||||
if !layer_range.contains_inclusive(&self.query_range) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if layer_range.len() > self.options.max_bytes {
|
||||
match &smallest_exceeding_max_len {
|
||||
None => smallest_exceeding_max_len = Some(layer.node()),
|
||||
Some(existing) => {
|
||||
if layer_range.len() < existing.byte_range().len() {
|
||||
smallest_exceeding_max_len = Some(layer.node());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match &largest {
|
||||
None => largest = Some(layer.node()),
|
||||
Some(existing) if layer_range.len() > existing.byte_range().len() => {
|
||||
largest = Some(layer.node())
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
smallest_exceeding_max_len.or(largest)
|
||||
}
|
||||
|
||||
// motivation for this and `goto_previous_named_sibling` is to avoid including things like
|
||||
// trailing unnamed "}" in body nodes
|
||||
fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
|
||||
while cursor.goto_next_sibling() {
|
||||
if cursor.node().is_named() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
|
||||
while cursor.goto_previous_sibling() {
|
||||
if cursor.node().is_named() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn expand_to_siblings(
|
||||
&self,
|
||||
cursor: &mut TreeCursor,
|
||||
mut excerpt: EditPredictionExcerpt,
|
||||
) -> EditPredictionExcerpt {
|
||||
let mut forward_cursor = cursor.clone();
|
||||
let backward_cursor = cursor;
|
||||
let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
|
||||
let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
|
||||
loop {
|
||||
if backward_done && forward_done {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut forward = None;
|
||||
while !forward_done {
|
||||
let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
|
||||
if new_end > excerpt.range.end {
|
||||
let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
|
||||
if new_excerpt.size <= self.options.max_bytes {
|
||||
forward = Some(new_excerpt);
|
||||
break;
|
||||
} else {
|
||||
log::debug!("halting forward expansion, as it doesn't fit");
|
||||
forward_done = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
|
||||
}
|
||||
|
||||
let mut backward = None;
|
||||
while !backward_done {
|
||||
let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
|
||||
if new_start < excerpt.range.start {
|
||||
let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
|
||||
if new_excerpt.size <= self.options.max_bytes {
|
||||
backward = Some(new_excerpt);
|
||||
break;
|
||||
} else {
|
||||
log::debug!("halting backward expansion, as it doesn't fit");
|
||||
backward_done = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
backward_done = !Self::goto_previous_named_sibling(backward_cursor);
|
||||
}
|
||||
|
||||
let go_forward = match (forward, backward) {
|
||||
(Some(forward), Some(backward)) => {
|
||||
let go_forward = self.is_better_excerpt(&forward, &backward);
|
||||
if go_forward {
|
||||
excerpt = forward;
|
||||
} else {
|
||||
excerpt = backward;
|
||||
}
|
||||
go_forward
|
||||
}
|
||||
(Some(forward), None) => {
|
||||
log::debug!("expanding forward, since backward expansion has halted");
|
||||
excerpt = forward;
|
||||
true
|
||||
}
|
||||
(None, Some(backward)) => {
|
||||
log::debug!("expanding backward, since forward expansion has halted");
|
||||
excerpt = backward;
|
||||
false
|
||||
}
|
||||
(None, None) => break,
|
||||
};
|
||||
|
||||
if go_forward {
|
||||
forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
|
||||
} else {
|
||||
backward_done = !Self::goto_previous_named_sibling(backward_cursor);
|
||||
}
|
||||
}
|
||||
|
||||
excerpt
|
||||
}
|
||||
|
||||
fn select_lines(&self) -> Option<EditPredictionExcerpt> {
|
||||
// early return if line containing query_offset is already too large
|
||||
let excerpt = self.make_excerpt(self.query_range.clone());
|
||||
if excerpt.size > self.options.max_bytes {
|
||||
log::debug!(
|
||||
"excerpt for cursor line is {} bytes, which exceeds the window",
|
||||
excerpt.size
|
||||
);
|
||||
return None;
|
||||
}
|
||||
let signatures_size = excerpt.parent_signatures_size();
|
||||
let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
|
||||
|
||||
let before_bytes =
|
||||
(self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
|
||||
|
||||
let start_point = {
|
||||
let offset = self.query_offset.saturating_sub(before_bytes);
|
||||
let point = offset.to_point(self.buffer);
|
||||
Point::new(point.row + 1, 0)
|
||||
};
|
||||
let start_offset = start_point.to_offset(&self.buffer);
|
||||
let end_point = {
|
||||
let offset = start_offset + bytes_remaining;
|
||||
let point = offset.to_point(self.buffer);
|
||||
Point::new(point.row, 0)
|
||||
};
|
||||
let end_offset = end_point.to_offset(&self.buffer);
|
||||
|
||||
// this could be expanded further since recalculated `signature_size` may be smaller, but
|
||||
// skipping that for now for simplicity
|
||||
//
|
||||
// TODO: could also consider checking if lines immediately before / after fit.
|
||||
let excerpt = self.make_excerpt(start_offset..end_offset);
|
||||
if excerpt.size > self.options.max_bytes {
|
||||
log::error!(
|
||||
"bug: line-based excerpt selection has size {}, \
|
||||
which is {} bytes larger than the max size",
|
||||
excerpt.size,
|
||||
excerpt.size - self.options.max_bytes
|
||||
);
|
||||
}
|
||||
return Some(excerpt);
|
||||
}
|
||||
|
||||
fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
|
||||
let parent_signature_ranges = self
|
||||
.outline_items
|
||||
.iter()
|
||||
.filter(|item| item.item_range.contains_inclusive(&range))
|
||||
.map(|item| item.signature_range.clone())
|
||||
.collect();
|
||||
EditPredictionExcerpt::new(range, parent_signature_ranges)
|
||||
}
|
||||
|
||||
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
|
||||
fn is_better_excerpt(
|
||||
&self,
|
||||
forward: &EditPredictionExcerpt,
|
||||
backward: &EditPredictionExcerpt,
|
||||
) -> bool {
|
||||
let forward_ratio = self.excerpt_range_ratio(forward);
|
||||
let backward_ratio = self.excerpt_range_ratio(backward);
|
||||
let forward_delta =
|
||||
(forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
|
||||
let backward_delta =
|
||||
(backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
|
||||
let forward_is_better = forward_delta <= backward_delta;
|
||||
if forward_is_better {
|
||||
log::debug!(
|
||||
"expanding forward since {} is closer than {} to {}",
|
||||
forward_ratio,
|
||||
backward_ratio,
|
||||
self.options.target_before_cursor_over_total_bytes
|
||||
);
|
||||
} else {
|
||||
log::debug!(
|
||||
"expanding backward since {} is closer than {} to {}",
|
||||
backward_ratio,
|
||||
forward_ratio,
|
||||
self.options.target_before_cursor_over_total_bytes
|
||||
);
|
||||
}
|
||||
forward_is_better
|
||||
}
|
||||
|
||||
/// Returns the ratio of bytes before the cursor over bytes within the range.
|
||||
fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
|
||||
let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
|
||||
log::error!("bug: edit prediction cursor offset is not outside the excerpt");
|
||||
return 0.0;
|
||||
};
|
||||
bytes_before_cursor as f32 / excerpt.range.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
fn node_line_start(node: Node) -> Point {
|
||||
Point::new(node.start_position().row as u32, 0)
|
||||
}
|
||||
|
||||
fn node_line_end(node: Node) -> Point {
|
||||
Point::new(node.end_position().row as u32 + 1, 0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use util::test::{generate_marked_text, marked_text_offsets_by};

    fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
        buffer.read_with(cx, |buffer, _| buffer.snapshot())
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    /// Splits marked text into (plain text, cursor offset, expected excerpt range).
    /// `ˇ` marks the cursor; `«`/`»` delimit the expected excerpt.
    fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
        let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
        (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
    }

    /// Runs excerpt selection on the marked `text` and asserts the selected range
    /// matches the markers, fits the budget, and contains the cursor.
    fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
        let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);

        let buffer = create_buffer(&text, cx);
        let cursor_point = cursor.to_point(&buffer);

        let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
            .expect("Should select an excerpt");
        pretty_assertions::assert_eq!(
            generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
            generate_marked_text(&text, &[expected_excerpt], false)
        );
        assert!(excerpt.size <= options.max_bytes);
        assert!(excerpt.range.contains(&cursor));
    }

    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    let ˇy = 2;
»    let z = 3;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 20,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn foo() {}

«fn main() {
    let x = 1;
    let ˇy = 2;
    let z = 3;
}
»
fn bar() {}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 65,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
«    let x = 1;
    let ˇy = 2;
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 50,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    if true {
        let ˇy = 2;
    }
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 45,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
«    let a = 1;
    let b = 2;
    let c = 3;
    let ˇd = 4;
    let e = 5;
    let f = 6;
»
    let g = 7;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 120,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.6,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }
}
|
||||
126
crates/edit_prediction_context/src/outline.rs
Normal file
126
crates/edit_prediction_context/src/outline.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use language::{BufferSnapshot, SyntaxMapMatches};
|
||||
use std::{cmp::Reverse, ops::Range};
|
||||
|
||||
use crate::declaration::Identifier;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// * how to handle multiple name captures? for now last one wins
|
||||
//
|
||||
// * annotation ranges
|
||||
//
|
||||
// * new "signature" capture for outline queries
|
||||
//
|
||||
// * Check parent behavior of "int x, y = 0" declarations in a test
|
||||
|
||||
/// A declaration found by running a language's outline query over a buffer.
pub struct OutlineDeclaration {
    /// Index of the innermost enclosing declaration within the same `Vec`,
    /// populated by `declarations_overlapping_range`; `None` for top-level items.
    pub parent_index: Option<usize>,
    /// Name text (from the name capture) paired with the language it came from.
    pub identifier: Identifier,
    /// Byte range of the whole item (the `item` capture).
    pub item_range: Range<usize>,
    /// Byte range spanning the first through last signature-contributing
    /// captures (name / context / extra-context).
    pub signature_range: Range<usize>,
}
|
||||
|
||||
/// Returns outline declarations for the entire buffer, sorted by position with
/// `parent_index` populated (see `declarations_overlapping_range`).
pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
    declarations_overlapping_range(0..buffer.len(), buffer)
}
|
||||
|
||||
pub fn declarations_overlapping_range(
|
||||
range: Range<usize>,
|
||||
buffer: &BufferSnapshot,
|
||||
) -> Vec<OutlineDeclaration> {
|
||||
let mut declarations = OutlineIterator::new(range, buffer).collect::<Vec<_>>();
|
||||
declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end)));
|
||||
|
||||
let mut parent_stack: Vec<(usize, Range<usize>)> = Vec::new();
|
||||
for (index, declaration) in declarations.iter_mut().enumerate() {
|
||||
while let Some((top_parent_index, top_parent_range)) = parent_stack.last() {
|
||||
if declaration.item_range.start >= top_parent_range.end {
|
||||
parent_stack.pop();
|
||||
} else {
|
||||
declaration.parent_index = Some(*top_parent_index);
|
||||
break;
|
||||
}
|
||||
}
|
||||
parent_stack.push((index, declaration.item_range.clone()));
|
||||
}
|
||||
declarations
|
||||
}
|
||||
|
||||
/// Iterates outline items without being ordered w.r.t. nested items and without populating
/// `parent`.
pub struct OutlineIterator<'a> {
    // Needed to extract name text for each match via `text_for_range`.
    buffer: &'a BufferSnapshot,
    // Live cursor over the buffer's outline-query matches.
    matches: SyntaxMapMatches<'a>,
}
|
||||
|
||||
impl<'a> OutlineIterator<'a> {
    /// Creates an iterator over outline-query matches that overlap `range`.
    /// Languages without an outline query contribute no matches.
    pub fn new(range: Range<usize>, buffer: &'a BufferSnapshot) -> Self {
        let matches = buffer.syntax.matches(range, &buffer.text, |grammar| {
            grammar.outline_config.as_ref().map(|c| &c.query)
        });

        Self { buffer, matches }
    }
}
|
||||
|
||||
impl<'a> Iterator for OutlineIterator<'a> {
    type Item = OutlineDeclaration;

    /// Converts the next outline-query match that has a name, an item range,
    /// and at least one signature capture into an `OutlineDeclaration`.
    /// Matches missing any of those are skipped.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(mat) = self.matches.peek() {
            // The grammar for this match must have an outline config, since
            // only outline queries were selected in `new`.
            let config = self.matches.grammars()[mat.grammar_index]
                .outline_config
                .as_ref()
                .unwrap();

            let mut name_range = None;
            let mut item_range = None;
            let mut signature_start = None;
            let mut signature_end = None;

            // Signature spans from the first signature-contributing capture to
            // the last, in capture order.
            let mut add_to_signature = |range: Range<usize>| {
                if signature_start.is_none() {
                    signature_start = Some(range.start);
                }
                signature_end = Some(range.end);
            };

            for capture in mat.captures {
                let range = capture.node.byte_range();
                if capture.index == config.name_capture_ix {
                    // Multiple name captures: last one wins (see module TODO).
                    name_range = Some(range.clone());
                    add_to_signature(range);
                } else if Some(capture.index) == config.context_capture_ix
                    || Some(capture.index) == config.extra_context_capture_ix
                {
                    add_to_signature(range);
                } else if capture.index == config.item_capture_ix {
                    item_range = Some(range.clone());
                }
            }

            // Capture language id before advancing: `advance` invalidates `mat`.
            let language_id = mat.language.id();
            self.matches.advance();

            if let Some(name_range) = name_range
                && let Some(item_range) = item_range
                && let Some(signature_start) = signature_start
                && let Some(signature_end) = signature_end
            {
                let name = self
                    .buffer
                    .text_for_range(name_range)
                    .collect::<String>()
                    .into();

                return Some(OutlineDeclaration {
                    identifier: Identifier { name, language_id },
                    item_range: item_range,
                    signature_range: signature_start..signature_end,
                    // Populated later by `declarations_overlapping_range`.
                    parent_index: None,
                });
            }
        }
        None
    }
}
|
||||
109
crates/edit_prediction_context/src/reference.rs
Normal file
109
crates/edit_prediction_context/src/reference.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
use language::BufferSnapshot;
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::{
|
||||
declaration::Identifier,
|
||||
excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
|
||||
};
|
||||
|
||||
/// An identifier occurrence found in excerpt text.
#[derive(Debug)]
pub struct Reference {
    pub identifier: Identifier,
    /// Byte range of the occurrence (buffer-relative, per the syntax match).
    pub range: Range<usize>,
    /// Which part of the excerpt the occurrence came from.
    pub region: ReferenceRegion,
}
|
||||
|
||||
/// Region of the excerpt a `Reference` was found in.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ReferenceRegion {
    /// Found in a parent signature ("breadcrumb") of the excerpt.
    Breadcrumb,
    /// Found in the excerpt body itself.
    Nearby,
}
|
||||
|
||||
pub fn references_in_excerpt(
|
||||
excerpt: &EditPredictionExcerpt,
|
||||
excerpt_text: &EditPredictionExcerptText,
|
||||
snapshot: &BufferSnapshot,
|
||||
) -> HashMap<Identifier, Vec<Reference>> {
|
||||
let mut references = identifiers_in_range(
|
||||
excerpt.range.clone(),
|
||||
excerpt_text.body.as_str(),
|
||||
ReferenceRegion::Nearby,
|
||||
snapshot,
|
||||
);
|
||||
|
||||
for (range, text) in excerpt
|
||||
.parent_signature_ranges
|
||||
.iter()
|
||||
.zip(excerpt_text.parent_signatures.iter())
|
||||
{
|
||||
references.extend(identifiers_in_range(
|
||||
range.clone(),
|
||||
text.as_str(),
|
||||
ReferenceRegion::Breadcrumb,
|
||||
snapshot,
|
||||
));
|
||||
}
|
||||
|
||||
let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
|
||||
for reference in references {
|
||||
identifier_to_references
|
||||
.entry(reference.identifier.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(reference);
|
||||
}
|
||||
identifier_to_references
|
||||
}
|
||||
|
||||
/// Finds all nodes which have a "variable" match from the highlights query within the offset range.
///
/// `range_text` must be the buffer text for `range`; identifier text is sliced
/// out of it using buffer offsets rebased by `range.start`.
pub fn identifiers_in_range(
    range: Range<usize>,
    range_text: &str,
    reference_region: ReferenceRegion,
    buffer: &BufferSnapshot,
) -> Vec<Reference> {
    let mut matches = buffer
        .syntax
        .matches(range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    let mut references = Vec::new();
    // Last range pushed, used to drop consecutive duplicate captures.
    let mut last_added_range = None;
    while let Some(mat) = matches.peek() {
        let config = matches.grammars()[mat.grammar_index]
            .highlights_config
            .as_ref();

        for capture in mat.captures {
            if let Some(config) = config {
                if config.identifier_capture_indices.contains(&capture.index) {
                    let node_range = capture.node.byte_range();

                    // sometimes multiple highlight queries match - this deduplicates them
                    if Some(node_range.clone()) == last_added_range {
                        continue;
                    }

                    // Rebase buffer offsets into `range_text`.
                    // NOTE(review): assumes node_range lies within `range`;
                    // out-of-range matches would panic here — confirm upstream.
                    let identifier_text =
                        &range_text[node_range.start - range.start..node_range.end - range.start];
                    references.push(Reference {
                        identifier: Identifier {
                            name: identifier_text.into(),
                            language_id: mat.language.id(),
                        },
                        range: node_range.clone(),
                        region: reference_region,
                    });
                    last_added_range = Some(node_range);
                }
            }
        }

        matches.advance();
    }
    references
}
|
||||
853
crates/edit_prediction_context/src/syntax_index.rs
Normal file
853
crates/edit_prediction_context/src/syntax_index.rs
Normal file
@@ -0,0 +1,853 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::lock::Mutex;
|
||||
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
|
||||
use language::{Buffer, BufferEvent};
|
||||
use project::buffer_store::{BufferStore, BufferStoreEvent};
|
||||
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
|
||||
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
|
||||
use slotmap::SlotMap;
|
||||
use text::BufferId;
|
||||
use util::{ResultExt as _, debug_panic, some_or_debug_panic};
|
||||
|
||||
use crate::declaration::{
|
||||
BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
|
||||
};
|
||||
use crate::outline::declarations_in_buffer;
|
||||
|
||||
// TODO:
|
||||
//
|
||||
// * Skip for remote projects
|
||||
//
|
||||
// * Consider making SyntaxIndex not an Entity.
|
||||
|
||||
// Potential future improvements:
|
||||
//
|
||||
// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
|
||||
// references are present and their scores.
|
||||
|
||||
// Potential future optimizations:
|
||||
//
|
||||
// * Cache of buffers for files
|
||||
//
|
||||
// * Parse files directly instead of loading into a Rope. Make SyntaxMap generic to handle embedded
|
||||
// languages? Will also need to find line boundaries, but that can be done by scanning characters in
|
||||
// the flat representation.
|
||||
//
|
||||
// * Use something similar to slotmap without key versions.
|
||||
//
|
||||
// * Concurrent slotmap
|
||||
//
|
||||
// * Use queue for parsing
|
||||
//
|
||||
|
||||
/// Indexes declarations across a project's worktree files and open buffers,
/// keeping `SyntaxIndexState` up to date as they change.
pub struct SyntaxIndex {
    // Shared behind an async Mutex so background tasks can update it; see
    // `with_state` for the fast-path/spawn access pattern.
    state: Arc<Mutex<SyntaxIndexState>>,
    // Weak so the index does not keep the project alive.
    project: WeakEntity<Project>,
}
|
||||
|
||||
/// Mutable state of the syntax index.
#[derive(Default)]
pub struct SyntaxIndexState {
    // All declarations, keyed by slotmap id; values are either `Declaration::File`
    // or `Declaration::Buffer`.
    declarations: SlotMap<DeclarationId, Declaration>,
    // Reverse index: identifier -> ids of declarations with that identifier.
    identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
    // Per-file declaration lists plus the in-flight parse task.
    files: HashMap<ProjectEntryId, FileState>,
    // Per-open-buffer declaration lists plus the in-flight parse task.
    buffers: HashMap<BufferId, BufferState>,
}
|
||||
|
||||
/// Index state for one on-disk file.
#[derive(Debug, Default)]
struct FileState {
    // Ids of this file's declarations in `SyntaxIndexState::declarations`.
    declarations: Vec<DeclarationId>,
    // Held so the in-flight update task is dropped (cancelled) with the state.
    task: Option<Task<()>>,
}
|
||||
|
||||
/// Index state for one open buffer (shadows the corresponding file's entries).
#[derive(Default)]
struct BufferState {
    // Ids of this buffer's declarations in `SyntaxIndexState::declarations`.
    declarations: Vec<DeclarationId>,
    // Held so the in-flight update task is dropped (cancelled) with the state.
    task: Option<Task<()>>,
}
|
||||
|
||||
impl SyntaxIndex {
|
||||
    /// Creates the index: subscribes to worktree and buffer-store events,
    /// then kicks off indexing of every existing file and open buffer.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let mut this = Self {
            project: project.downgrade(),
            state: Arc::new(Mutex::new(SyntaxIndexState::default())),
        };

        let worktree_store = project.read(cx).worktree_store();
        cx.subscribe(&worktree_store, Self::handle_worktree_store_event)
            .detach();

        // Snapshots are collected first so `update_file` (which needs `cx`)
        // isn't called while borrowing the worktree store.
        for worktree in worktree_store
            .read(cx)
            .worktrees()
            .map(|w| w.read(cx).snapshot())
            .collect::<Vec<_>>()
        {
            // `files(false, 0)`: iterate file entries from the start of the
            // worktree — NOTE(review): confirm flag semantics against Worktree.
            for entry in worktree.files(false, 0) {
                this.update_file(
                    entry.id,
                    ProjectPath {
                        worktree_id: worktree.id(),
                        path: entry.path.clone(),
                    },
                    cx,
                );
            }
        }

        // Index buffers that are already open, then subscribe for new ones.
        let buffer_store = project.read(cx).buffer_store().clone();
        for buffer in buffer_store.read(cx).buffers().collect::<Vec<_>>() {
            this.register_buffer(&buffer, cx);
        }
        cx.subscribe(&buffer_store, Self::handle_buffer_store_event)
            .detach();

        this
    }
|
||||
|
||||
    /// Keeps file state in sync with worktree changes: removed entries are
    /// dropped from the index, added/updated entries are re-parsed.
    fn handle_worktree_store_event(
        &mut self,
        _worktree_store: Entity<WorktreeStore>,
        event: &WorktreeStoreEvent,
        cx: &mut Context<Self>,
    ) {
        use WorktreeStoreEvent::*;
        match event {
            WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
                // Downgrade so the spawned task doesn't keep the state alive.
                let state = Arc::downgrade(&self.state);
                let worktree_id = *worktree_id;
                let updated_entries_set = updated_entries_set.clone();
                cx.spawn(async move |this, cx| {
                    let Some(state) = state.upgrade() else { return };
                    for (path, entry_id, path_change) in updated_entries_set.iter() {
                        if let PathChange::Removed = path_change {
                            // Drop the file's state; its declarations are
                            // NOTE(review): not removed from `declarations`/
                            // `identifiers` here — confirm that's intended.
                            state.lock().await.files.remove(entry_id);
                        } else {
                            let project_path = ProjectPath {
                                worktree_id,
                                path: path.clone(),
                            };
                            this.update(cx, |this, cx| {
                                this.update_file(*entry_id, project_path, cx);
                            })
                            .ok();
                        }
                    }
                })
                .detach();
            }
            WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
                let project_entry_id = *project_entry_id;
                self.with_state(cx, move |state| {
                    state.files.remove(&project_entry_id);
                })
            }
            _ => {}
        }
    }
|
||||
|
||||
    /// Registers newly added buffers; all other buffer-store events are
    /// intentionally ignored (listed explicitly so new variants are a
    /// compile-time prompt to reconsider).
    fn handle_buffer_store_event(
        &mut self,
        _buffer_store: Entity<BufferStore>,
        event: &BufferStoreEvent,
        cx: &mut Context<Self>,
    ) {
        use BufferStoreEvent::*;
        match event {
            BufferAdded(buffer) => self.register_buffer(buffer, cx),
            BufferOpened { .. }
            | BufferChangedFilePath { .. }
            | BufferDropped { .. }
            | SharedBufferClosed { .. } => {}
        }
    }
|
||||
|
||||
    /// Shared handle to the index state; callers lock it to query declarations.
    pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
        &self.state
    }
|
||||
|
||||
    /// Runs `f` against the index state: synchronously if the lock is free,
    /// otherwise on a background task once the lock is acquired. Callers must
    /// therefore not rely on `f` having run when this returns.
    fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
        // Fast path: uncontended lock, run inline.
        if let Some(mut state) = self.state.try_lock() {
            f(&mut state);
            return;
        }
        // Slow path: wait for the lock off the main thread. Weak upgrade lets
        // the task become a no-op if the index was dropped meanwhile.
        let state = Arc::downgrade(&self.state);
        cx.background_spawn(async move {
            let Some(state) = state.upgrade() else {
                return None;
            };
            let mut state = state.lock().await;
            Some(f(&mut state))
        })
        .detach();
    }
|
||||
|
||||
    /// Starts tracking a buffer: removes its declarations when the buffer is
    /// released, re-indexes on events (see `handle_buffer_event`), and indexes
    /// its current contents immediately.
    fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
        let buffer_id = buffer.read(cx).remote_id();
        cx.observe_release(buffer, move |this, _buffer, cx| {
            this.with_state(cx, move |state| {
                if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
                    SyntaxIndexState::remove_buffer_declarations(
                        &buffer_state.declarations,
                        &mut state.declarations,
                        &mut state.identifiers,
                    );
                }
            })
        })
        .detach();
        cx.subscribe(buffer, Self::handle_buffer_event).detach();

        self.update_buffer(buffer.clone(), cx);
    }
|
||||
|
||||
fn handle_buffer_event(
|
||||
&mut self,
|
||||
buffer: Entity<Buffer>,
|
||||
event: &BufferEvent,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
BufferEvent::Edited => self.update_buffer(buffer, cx),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
    /// Re-parses a buffer's declarations and replaces its entries in the index.
    ///
    /// Pipeline: wait for parsing to go idle and take a snapshot (foreground),
    /// extract declarations from the snapshot (background), then swap the
    /// buffer's declaration set in the shared state. Storing the task in
    /// `BufferState.task` cancels any previous in-flight update for the buffer.
    fn update_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
        let buffer = buffer_entity.read(cx);

        // Buffers without a project entry (e.g. untitled) are not indexed.
        let Some(project_entry_id) =
            project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
        else {
            return;
        };
        let buffer_id = buffer.remote_id();

        let mut parse_status = buffer.parse_status();
        let snapshot_task = cx.spawn({
            let weak_buffer = buffer_entity.downgrade();
            async move |_, cx| {
                // Snapshot only once the syntax tree is up to date.
                while *parse_status.borrow() != language::ParseStatus::Idle {
                    parse_status.changed().await?;
                }
                weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
            }
        });

        let parse_task = cx.background_spawn(async move {
            let snapshot = snapshot_task.await?;
            let rope = snapshot.text.as_rope().clone();

            // Keep each item's parent index so parent links can be rebuilt
            // against the freshly inserted declaration ids below.
            anyhow::Ok((
                declarations_in_buffer(&snapshot)
                    .into_iter()
                    .map(|item| {
                        (
                            item.parent_index,
                            BufferDeclaration::from_outline(item, &rope),
                        )
                    })
                    .collect::<Vec<_>>(),
                rope,
            ))
        });

        let task = cx.spawn({
            async move |this, cx| {
                let Ok((declarations, rope)) = parse_task.await else {
                    return;
                };

                this.update(cx, move |this, cx| {
                    this.with_state(cx, move |state| {
                        let buffer_state = state
                            .buffers
                            .entry(buffer_id)
                            .or_insert_with(Default::default);

                        // Drop the previous generation of this buffer's
                        // declarations before inserting the new one.
                        SyntaxIndexState::remove_buffer_declarations(
                            &buffer_state.declarations,
                            &mut state.declarations,
                            &mut state.identifiers,
                        );

                        let mut new_ids = Vec::with_capacity(declarations.len());
                        state.declarations.reserve(declarations.len());
                        for (parent_index, mut declaration) in declarations {
                            // Parents precede children in the parsed output, so
                            // the parent's new id is already in `new_ids`.
                            declaration.parent = parent_index
                                .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));

                            let identifier = declaration.identifier.clone();
                            let declaration_id = state.declarations.insert(Declaration::Buffer {
                                rope: rope.clone(),
                                buffer_id,
                                declaration,
                                project_entry_id,
                            });
                            new_ids.push(declaration_id);

                            state
                                .identifiers
                                .entry(identifier)
                                .or_default()
                                .insert(declaration_id);
                        }

                        buffer_state.declarations = new_ids;
                    });
                })
                .ok();
            }
        });

        // Stash the task; replacing an existing one cancels the old update.
        self.with_state(cx, move |state| {
            state
                .buffers
                .entry(buffer_id)
                .or_insert_with(Default::default)
                .task = Some(task)
        });
    }
|
||||
|
||||
    /// Loads and parses an on-disk file and replaces its entries in the index.
    ///
    /// Mirrors `update_buffer`, but loads the file into a temporary local
    /// buffer first. Storing the task in `FileState.task` cancels any previous
    /// in-flight update for the same entry.
    fn update_file(
        &mut self,
        entry_id: ProjectEntryId,
        project_path: ProjectPath,
        cx: &mut Context<Self>,
    ) {
        let Some(project) = self.project.upgrade() else {
            return;
        };
        let project = project.read(cx);
        let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else {
            return;
        };
        let language_registry = project.languages().clone();

        let snapshot_task = worktree.update(cx, |worktree, cx| {
            let load_task = worktree.load_file(&project_path.path, cx);
            cx.spawn(async move |_this, cx| {
                let loaded_file = load_task.await?;
                // Unrecognized extensions yield `None`; the buffer still
                // parses, just without an outline.
                let language = language_registry
                    .language_for_file_path(&project_path.path)
                    .await
                    .log_err();

                let buffer = cx.new(|cx| {
                    let mut buffer = Buffer::local(loaded_file.text, cx);
                    buffer.set_language(language, cx);
                    buffer
                })?;

                // Wait for the initial parse before snapshotting.
                let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
                while *parse_status.borrow() != language::ParseStatus::Idle {
                    parse_status.changed().await?;
                }

                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
            })
        });

        let parse_task = cx.background_spawn(async move {
            let snapshot = snapshot_task.await?;
            let rope = snapshot.as_rope();
            let declarations = declarations_in_buffer(&snapshot)
                .into_iter()
                .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
                .collect::<Vec<_>>();
            anyhow::Ok(declarations)
        });

        let task = cx.spawn({
            async move |this, cx| {
                // TODO: how to handle errors?
                let Ok(declarations) = parse_task.await else {
                    return;
                };
                this.update(cx, |this, cx| {
                    this.with_state(cx, move |state| {
                        let file_state =
                            state.files.entry(entry_id).or_insert_with(Default::default);

                        // Remove the previous generation of this file's
                        // declarations. NOTE(review): duplicates
                        // `remove_buffer_declarations`; consider reusing it —
                        // confirm the borrow checker allows.
                        for old_declaration_id in &file_state.declarations {
                            let Some(declaration) = state.declarations.remove(*old_declaration_id)
                            else {
                                debug_panic!("declaration not found");
                                continue;
                            };
                            if let Some(identifier_declarations) =
                                state.identifiers.get_mut(declaration.identifier())
                            {
                                identifier_declarations.remove(old_declaration_id);
                            }
                        }

                        let mut new_ids = Vec::with_capacity(declarations.len());
                        state.declarations.reserve(declarations.len());

                        for (parent_index, mut declaration) in declarations {
                            // Parents precede children, so the parent's new id
                            // is already in `new_ids`.
                            declaration.parent = parent_index
                                .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));

                            let identifier = declaration.identifier.clone();
                            let declaration_id = state.declarations.insert(Declaration::File {
                                project_entry_id: entry_id,
                                declaration,
                            });
                            new_ids.push(declaration_id);

                            state
                                .identifiers
                                .entry(identifier)
                                .or_default()
                                .insert(declaration_id);
                        }

                        file_state.declarations = new_ids;
                    });
                })
                .ok();
            }
        });

        // Stash the task; replacing an existing one cancels the old update.
        self.with_state(cx, move |state| {
            state
                .files
                .entry(entry_id)
                .or_insert_with(Default::default)
                .task = Some(task);
        });
    }
|
||||
}
|
||||
|
||||
impl SyntaxIndexState {
|
||||
    /// Looks up a declaration by id; `None` if it has been removed.
    pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
        self.declarations.get(id)
    }
|
||||
|
||||
/// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
|
||||
///
|
||||
/// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
|
||||
pub fn declarations_for_identifier<const N: usize>(
|
||||
&self,
|
||||
identifier: &Identifier,
|
||||
) -> Vec<Declaration> {
|
||||
// make sure to not have a large stack allocation
|
||||
assert!(N < 32);
|
||||
|
||||
let Some(declaration_ids) = self.identifiers.get(&identifier) else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
let mut result = Vec::with_capacity(N);
|
||||
let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
|
||||
let mut file_declarations = Vec::new();
|
||||
|
||||
for declaration_id in declaration_ids {
|
||||
let declaration = self.declarations.get(*declaration_id);
|
||||
let Some(declaration) = some_or_debug_panic(declaration) else {
|
||||
continue;
|
||||
};
|
||||
match declaration {
|
||||
Declaration::Buffer {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
included_buffer_entry_ids.push(*project_entry_id);
|
||||
result.push(declaration.clone());
|
||||
if result.len() == N {
|
||||
return Vec::new();
|
||||
}
|
||||
}
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
if !included_buffer_entry_ids.contains(&project_entry_id) {
|
||||
file_declarations.push(declaration.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for declaration in file_declarations {
|
||||
match declaration {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => {
|
||||
if !included_buffer_entry_ids.contains(&project_entry_id) {
|
||||
result.push(declaration);
|
||||
|
||||
if result.len() == N {
|
||||
return Vec::new();
|
||||
}
|
||||
}
|
||||
}
|
||||
Declaration::Buffer { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
|
||||
match declaration {
|
||||
Declaration::File {
|
||||
project_entry_id, ..
|
||||
} => self
|
||||
.files
|
||||
.get(project_entry_id)
|
||||
.map(|file_state| file_state.declarations.len())
|
||||
.unwrap_or_default(),
|
||||
Declaration::Buffer { buffer_id, .. } => self
|
||||
.buffers
|
||||
.get(buffer_id)
|
||||
.map(|buffer_state| buffer_state.declarations.len())
|
||||
.unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Removes a buffer's previous declaration generation from both the
    /// declarations slotmap and the identifier reverse index.
    fn remove_buffer_declarations(
        old_declaration_ids: &[DeclarationId],
        declarations: &mut SlotMap<DeclarationId, Declaration>,
        identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
    ) {
        for old_declaration_id in old_declaration_ids {
            let Some(declaration) = declarations.remove(*old_declaration_id) else {
                // Ids in a buffer's list should always be present.
                debug_panic!("declaration not found");
                continue;
            };
            if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
                identifier_declarations.remove(old_declaration_id);
            }
        }
    }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::{path::Path, sync::Arc};
|
||||
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use text::OffsetRangeExt as _;
|
||||
use util::path;
|
||||
|
||||
use crate::syntax_index::SyntaxIndex;
|
||||
|
||||
    // Files indexed from disk (no open buffers) yield `Declaration::File`
    // entries; `main` is defined in both c.rs and a.rs of the fixture.
    #[gpui::test]
    async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
        let (project, index, rust_lang_id) = init_test(cx).await;
        let main = Identifier {
            name: "main".into(),
            language_id: rust_lang_id,
        };

        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
        let index_state = index_state.lock().await;
        cx.update(|cx| {
            let decls = index_state.declarations_for_identifier::<8>(&main);
            assert_eq!(decls.len(), 2);

            // Byte offsets of the fixture's `fn main` items in each file.
            let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
            assert_eq!(decl.identifier, main.clone());
            assert_eq!(decl.item_range_in_file, 32..280);

            let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
            assert_eq!(decl.identifier, main);
            assert_eq!(decl.item_range_in_file, 0..98);
        });
    }
|
||||
|
||||
    // `test_process_data` lives inside `mod tests` in c.rs; its file
    // declaration should link to the `tests` module as parent, which itself
    // has no parent.
    #[gpui::test]
    async fn test_parents_in_file(cx: &mut TestAppContext) {
        let (project, index, rust_lang_id) = init_test(cx).await;
        let test_process_data = Identifier {
            name: "test_process_data".into(),
            language_id: rust_lang_id,
        };

        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
        let index_state = index_state.lock().await;
        cx.update(|cx| {
            let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
            assert_eq!(decls.len(), 1);

            let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
            assert_eq!(decl.identifier, test_process_data);

            let parent_id = decl.parent.unwrap();
            let parent = index_state.declaration(parent_id).unwrap();
            let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
            assert_eq!(
                parent_decl.identifier,
                Identifier {
                    name: "tests".into(),
                    language_id: rust_lang_id
                }
            );
            assert_eq!(parent_decl.parent, None);
        });
    }
|
||||
|
||||
    // Same parent-link check as `test_parents_in_file`, but with c.rs opened
    // as a buffer, so declarations come back as `Declaration::Buffer`.
    #[gpui::test]
    async fn test_parents_in_buffer(cx: &mut TestAppContext) {
        let (project, index, rust_lang_id) = init_test(cx).await;
        let test_process_data = Identifier {
            name: "test_process_data".into(),
            language_id: rust_lang_id,
        };

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Let the index process the newly opened buffer.
        cx.run_until_parked();

        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
        let index_state = index_state.lock().await;
        cx.update(|cx| {
            let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
            assert_eq!(decls.len(), 1);

            let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
            assert_eq!(decl.identifier, test_process_data);

            let parent_id = decl.parent.unwrap();
            let parent = index_state.declaration(parent_id).unwrap();
            let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
            assert_eq!(
                parent_decl.identifier,
                Identifier {
                    name: "tests".into(),
                    language_id: rust_lang_id
                }
            );
            assert_eq!(parent_decl.parent, None);
        });

        // Keep the buffer alive until the assertions above have run.
        drop(buffer);
    }
|
||||
|
||||
    // With N = 1 and two `main` declarations in the fixture, the limit is
    // exceeded and `declarations_for_identifier` returns an empty vec.
    // NOTE(review): "limt" in the name is a typo for "limit"; left unchanged
    // here to keep this edit comment-only.
    #[gpui::test]
    async fn test_declarations_limt(cx: &mut TestAppContext) {
        let (_, index, rust_lang_id) = init_test(cx).await;

        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
        let index_state = index_state.lock().await;
        let decls = index_state.declarations_for_identifier::<1>(&Identifier {
            name: "main".into(),
            language_id: rust_lang_id,
        });
        assert_eq!(decls.len(), 0);
    }
|
||||
|
||||
    // While c.rs is open, its buffer declaration shadows the file declaration;
    // after the buffer is released, the file declaration is used again.
    #[gpui::test]
    async fn test_buffer_shadow(cx: &mut TestAppContext) {
        let (project, index, rust_lang_id) = init_test(cx).await;

        let main = Identifier {
            name: "main".into(),
            language_id: rust_lang_id,
        };

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
        // Scope the lock so it's released before the buffer is dropped below.
        {
            let index_state = index_state_arc.lock().await;

            cx.update(|cx| {
                let decls = index_state.declarations_for_identifier::<8>(&main);
                assert_eq!(decls.len(), 2);
                let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
                assert_eq!(decl.identifier, main);
                assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);

                expect_file_decl("a.rs", &decls[1], &project, cx);
            });
        }

        // Drop the buffer and wait for release
        cx.update(|_| {
            drop(buffer);
        });
        cx.run_until_parked();

        let index_state = index_state_arc.lock().await;

        cx.update(|cx| {
            let decls = index_state.declarations_for_identifier::<8>(&main);
            assert_eq!(decls.len(), 2);
            expect_file_decl("c.rs", &decls[0], &project, cx);
            expect_file_decl("a.rs", &decls[1], &project, cx);
        });
    }
|
||||
|
||||
fn expect_buffer_decl<'a>(
|
||||
path: &str,
|
||||
declaration: &'a Declaration,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> &'a BufferDeclaration {
|
||||
if let Declaration::Buffer {
|
||||
declaration,
|
||||
project_entry_id,
|
||||
..
|
||||
} = declaration
|
||||
{
|
||||
let project_path = project
|
||||
.read(cx)
|
||||
.path_for_entry(*project_entry_id, cx)
|
||||
.unwrap();
|
||||
assert_eq!(project_path.path.as_ref(), Path::new(path),);
|
||||
declaration
|
||||
} else {
|
||||
panic!("Expected a buffer declaration, found {:?}", declaration);
|
||||
}
|
||||
}
|
||||
|
||||
fn expect_file_decl<'a>(
|
||||
path: &str,
|
||||
declaration: &'a Declaration,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> &'a FileDeclaration {
|
||||
if let Declaration::File {
|
||||
declaration,
|
||||
project_entry_id: file,
|
||||
} = declaration
|
||||
{
|
||||
assert_eq!(
|
||||
project
|
||||
.read(cx)
|
||||
.path_for_entry(*file, cx)
|
||||
.unwrap()
|
||||
.path
|
||||
.as_ref(),
|
||||
Path::new(path),
|
||||
);
|
||||
declaration
|
||||
} else {
|
||||
panic!("Expected a file declaration, found {:?}", declaration);
|
||||
}
|
||||
}
|
||||
|
||||
    /// Builds the shared fixture: a fake project with three Rust files
    /// (a.rs: `main`/`add`; b.rs: `Config`; c.rs: `main`/`process_data` plus a
    /// nested `mod tests`), registers the Rust language, creates the
    /// `SyntaxIndex`, and waits for initial indexing.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // Let the initial file-indexing tasks finish before tests query state.
        cx.run_until_parked();

        (project, index, lang_id)
    }
|
||||
|
||||
fn rust_lang() -> Language {
|
||||
Language::new(
|
||||
LanguageConfig {
|
||||
name: "Rust".into(),
|
||||
matcher: LanguageMatcher {
|
||||
path_suffixes: vec!["rs".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
)
|
||||
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
241
crates/edit_prediction_context/src/text_similarity.rs
Normal file
241
crates/edit_prediction_context/src/text_similarity.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
use regex::Regex;
|
||||
use std::{collections::HashMap, sync::LazyLock};
|
||||
|
||||
use crate::reference::Reference;
|
||||
|
||||
// TODO: Consider implementing sliding window similarity matching like
|
||||
// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
|
||||
//
|
||||
// That implementation could actually be more efficient - no need to track words in the window that
|
||||
// are not in the query.
|
||||
|
||||
// Matches maximal runs of word characters, i.e. identifier-like tokens in source text.
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
|
||||
|
||||
/// Multiset of lowercased identifier *parts* (camel/snake/kebab/pascal segments)
/// extracted from a piece of code, used by the similarity functions below.
#[derive(Debug)]
pub struct IdentifierOccurrences {
    // Lowercased identifier part -> number of occurrences.
    identifier_to_count: HashMap<String, usize>,
    // Sum of all counts (total number of parts seen, including duplicates).
    total_count: usize,
}
|
||||
|
||||
impl IdentifierOccurrences {
|
||||
pub fn within_string(code: &str) -> Self {
|
||||
Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn within_references(references: &[Reference]) -> Self {
|
||||
Self::from_iterator(
|
||||
references
|
||||
.iter()
|
||||
.map(|reference| reference.identifier.name.as_ref()),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
|
||||
let mut identifier_to_count = HashMap::new();
|
||||
let mut total_count = 0;
|
||||
for identifier in identifier_iterator {
|
||||
// TODO: Score matches that match case higher?
|
||||
//
|
||||
// TODO: Also include unsplit identifier?
|
||||
for identifier_part in split_identifier(identifier) {
|
||||
identifier_to_count
|
||||
.entry(identifier_part.to_lowercase())
|
||||
.and_modify(|count| *count += 1)
|
||||
.or_insert(1);
|
||||
total_count += 1;
|
||||
}
|
||||
}
|
||||
IdentifierOccurrences {
|
||||
identifier_to_count,
|
||||
total_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Splits camelcase / snakecase / kebabcase / pascalcase
|
||||
//
|
||||
// TODO: Make this more efficient / elegant.
|
||||
// Splits an identifier into its parts: camelCase, PascalCase, snake_case,
// kebab-case, and acronym runs ("XMLParser" -> ["XML", "Parser"]).
//
// Slices on byte offsets obtained from `char_indices`, so multi-byte
// (non-ASCII) characters are handled without panicking. (The previous
// implementation indexed the string with *char* positions, which panics or
// mis-slices on non-ASCII identifiers.)
fn split_identifier<'a>(identifier: &'a str) -> Vec<&'a str> {
    let chars: Vec<(usize, char)> = identifier.char_indices().collect();
    let mut parts = Vec::new();
    // Byte offset where the current part began.
    let mut start = 0;

    for (i, &(byte_ix, ch)) in chars.iter().enumerate() {
        // Explicit delimiters end the current part and belong to no part.
        if ch == '_' || ch == '-' {
            if byte_ix > start {
                parts.push(&identifier[start..byte_ix]);
            }
            start = byte_ix + ch.len_utf8();
            continue;
        }

        if i > 0 {
            let prev = chars[i - 1].1;
            // lower/digit -> Upper boundary: "camelCase" -> "camel" | "Case".
            let camel_boundary =
                (prev.is_lowercase() || prev.is_ascii_digit()) && ch.is_uppercase();
            // Uppercase run followed by lowercase: "XMLParser" -> "XML" | "Parser".
            let acronym_boundary = prev.is_uppercase()
                && ch.is_uppercase()
                && chars
                    .get(i + 1)
                    .is_some_and(|&(_, next)| next.is_lowercase());
            if (camel_boundary || acronym_boundary) && byte_ix > start {
                parts.push(&identifier[start..byte_ix]);
                start = byte_ix;
            }
        }
    }

    // Remaining tail (empty when the identifier ended with a delimiter).
    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }

    parts
}
|
||||
|
||||
pub fn jaccard_similarity<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
let intersection = set_a
|
||||
.identifier_to_count
|
||||
.keys()
|
||||
.filter(|key| set_b.identifier_to_count.contains_key(*key))
|
||||
.count();
|
||||
let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
|
||||
intersection as f32 / union as f32
|
||||
}
|
||||
|
||||
// TODO
|
||||
#[allow(dead_code)]
|
||||
pub fn overlap_coefficient<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
let intersection = set_a
|
||||
.identifier_to_count
|
||||
.keys()
|
||||
.filter(|key| set_b.identifier_to_count.contains_key(*key))
|
||||
.count();
|
||||
intersection as f32 / set_a.identifier_to_count.len() as f32
|
||||
}
|
||||
|
||||
// TODO
|
||||
#[allow(dead_code)]
|
||||
pub fn weighted_jaccard_similarity<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
|
||||
let mut numerator = 0;
|
||||
let mut denominator_a = 0;
|
||||
let mut used_count_b = 0;
|
||||
for (symbol, count_a) in set_a.identifier_to_count.iter() {
|
||||
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
|
||||
numerator += count_a.min(count_b);
|
||||
denominator_a += count_a.max(count_b);
|
||||
used_count_b += count_b;
|
||||
}
|
||||
|
||||
let denominator = denominator_a + (set_b.total_count - used_count_b);
|
||||
if denominator == 0 {
|
||||
0.0
|
||||
} else {
|
||||
numerator as f32 / denominator as f32
|
||||
}
|
||||
}
|
||||
|
||||
pub fn weighted_overlap_coefficient<'a>(
|
||||
mut set_a: &'a IdentifierOccurrences,
|
||||
mut set_b: &'a IdentifierOccurrences,
|
||||
) -> f32 {
|
||||
if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
|
||||
std::mem::swap(&mut set_a, &mut set_b);
|
||||
}
|
||||
|
||||
let mut numerator = 0;
|
||||
for (symbol, count_a) in set_a.identifier_to_count.iter() {
|
||||
let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
|
||||
numerator += count_a.min(count_b);
|
||||
}
|
||||
|
||||
let denominator = set_a.total_count.min(set_b.total_count);
|
||||
if denominator == 0 {
|
||||
0.0
|
||||
} else {
|
||||
numerator as f32 / denominator as f32
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::*;

    // Covers each supported casing convention plus acronym runs.
    #[test]
    fn test_split_identifier() {
        assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
        assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
        assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
        assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
        assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
    }

    // Checks all four similarity metrics against hand-computed values for a
    // pair of code lines with known part counts and overlaps.
    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let set_a = IdentifierOccurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let set_b = IdentifierOccurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the same except for 3 more due to the non-overlapping duplicates
        assert_eq!(
            weighted_jaccard_similarity(&set_a, &set_b),
            7.0 / (7.0 + 7.0 + 3.0)
        );

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
        assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
        // the smaller set, 10.
        assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
    }
}
|
||||
35
crates/edit_prediction_context/src/wip_requests.rs
Normal file
35
crates/edit_prediction_context/src/wip_requests.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
|
||||
// `zeta_context.rs` in cloud.
|
||||
//
|
||||
// * Run excerpt selection at several different sizes, send the largest size with offsets within for
|
||||
// the smaller sizes.
|
||||
//
|
||||
// * Longer event history.
|
||||
//
|
||||
// * Many more snippets than could fit in model context - allows ranking experimentation.
|
||||
|
||||
/// Work-in-progress sketch of the request payload for the zeta2 edit-prediction
/// endpoint (see the discussion notes above).
pub struct Zeta2Request {
    // Recent editing events preceding this request.
    pub event_history: Vec<Event>,
    // Text around the cursor; subsets below index into it.
    pub excerpt: String,
    // Smaller excerpt windows, expressed as ranges within `excerpt` — per the
    // notes above, several sizes are selected and only the largest is sent.
    pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
    /// Within `excerpt`
    pub cursor_position: usize,
    // Signature texts, referenced by index from subsets and declarations.
    pub signatures: Vec<String>,
    // Candidate declarations gathered as context (more than fit in model
    // context, to allow ranking experimentation — see notes above).
    pub retrieved_declarations: Vec<ReferencedDeclaration>,
}
|
||||
|
||||
/// A smaller window of the request's excerpt, with its enclosing signatures.
pub struct Zeta2ExcerptSubset {
    /// Within `excerpt` text.
    pub excerpt_range: Range<usize>,
    /// Within `signatures`.
    pub parent_signatures: Vec<usize>,
}
|
||||
|
||||
/// A declaration retrieved as context for the prediction request.
pub struct ReferencedDeclaration {
    // Full text of the declaration.
    pub text: Arc<str>,
    /// Range within `text`
    pub signature_range: Range<usize>,
    /// Indices within `signatures`.
    pub parent_signatures: Vec<usize>,
    // A bunch of score metrics
}
|
||||
@@ -20549,7 +20549,9 @@ impl Editor {
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
self.update_lsp_data(false, Some(buffer_id), window, cx);
|
||||
if self.active_diagnostics != ActiveDiagnostic::All {
|
||||
self.update_lsp_data(false, Some(buffer_id), window, cx);
|
||||
}
|
||||
cx.emit(EditorEvent::ExcerptsAdded {
|
||||
buffer: buffer.clone(),
|
||||
predecessor: *predecessor,
|
||||
|
||||
@@ -19265,7 +19265,7 @@ async fn test_expand_diff_hunk_at_excerpt_boundary(cx: &mut TestAppContext) {
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
// When the start of a hunk coincides with the start of its excerpt,
|
||||
// the hunk is expanded. When the start of a a hunk is earlier than
|
||||
// the hunk is expanded. When the start of a hunk is earlier than
|
||||
// the start of its excerpt, the hunk is not expanded.
|
||||
cx.assert_state_with_diff(
|
||||
"
|
||||
|
||||
@@ -9694,7 +9694,7 @@ impl EditorScrollbars {
|
||||
editor_bounds.bottom_left(),
|
||||
size(
|
||||
// The horizontal viewport size differs from the space available for the
|
||||
// horizontal scrollbar, so we have to manually stich it together here.
|
||||
// horizontal scrollbar, so we have to manually stitch it together here.
|
||||
editor_bounds.size.width - right_margin,
|
||||
scrollbar_width,
|
||||
),
|
||||
|
||||
@@ -521,6 +521,14 @@ impl PickerDelegate for BranchListDelegate {
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.tooltip({
|
||||
let branch_name = entry.branch.name().to_string();
|
||||
if entry.is_new {
|
||||
Tooltip::text(format!("Create branch \"{}\"", branch_name))
|
||||
} else {
|
||||
Tooltip::text(branch_name)
|
||||
}
|
||||
})
|
||||
.child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
|
||||
@@ -3748,7 +3748,10 @@ impl GitPanel {
|
||||
.custom_scrollbars(
|
||||
Scrollbars::for_settings::<GitPanelSettings>()
|
||||
.tracked_scroll_handle(self.scroll_handle.clone())
|
||||
.with_track_along(ScrollAxes::Horizontal),
|
||||
.with_track_along(
|
||||
ScrollAxes::Horizontal,
|
||||
cx.theme().colors().panel_background,
|
||||
),
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
|
||||
@@ -115,7 +115,7 @@ seahash = "4.1"
|
||||
semantic_version.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
slotmap = "1.0.6"
|
||||
slotmap.workspace = true
|
||||
smallvec.workspace = true
|
||||
smol.workspace = true
|
||||
stacksafe.workspace = true
|
||||
|
||||
@@ -151,9 +151,9 @@ impl From<Hsla> for Rgba {
|
||||
};
|
||||
|
||||
Rgba {
|
||||
r,
|
||||
g,
|
||||
b,
|
||||
r: r.clamp(0., 1.),
|
||||
g: g.clamp(0., 1.),
|
||||
b: b.clamp(0., 1.),
|
||||
a: color.a,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,10 @@ unsafe fn build_classes() {
|
||||
APP_DELEGATE_CLASS = unsafe {
|
||||
let mut decl = ClassDecl::new("GPUIApplicationDelegate", class!(NSResponder)).unwrap();
|
||||
decl.add_ivar::<*mut c_void>(MAC_PLATFORM_IVAR);
|
||||
decl.add_method(
|
||||
sel!(applicationWillFinishLaunching:),
|
||||
will_finish_launching as extern "C" fn(&mut Object, Sel, id),
|
||||
);
|
||||
decl.add_method(
|
||||
sel!(applicationDidFinishLaunching:),
|
||||
did_finish_launching as extern "C" fn(&mut Object, Sel, id),
|
||||
@@ -1356,6 +1360,23 @@ unsafe fn get_mac_platform(object: &mut Object) -> &MacPlatform {
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn will_finish_launching(_this: &mut Object, _: Sel, _: id) {
|
||||
unsafe {
|
||||
let user_defaults: id = msg_send![class!(NSUserDefaults), standardUserDefaults];
|
||||
|
||||
// The autofill heuristic controller causes slowdown and high CPU usage.
|
||||
// We don't know exactly why. This disables the full heuristic controller.
|
||||
//
|
||||
// Adapted from: https://github.com/ghostty-org/ghostty/pull/8625
|
||||
let name = ns_string("NSAutoFillHeuristicControllerEnabled");
|
||||
let existing_value: id = msg_send![user_defaults, objectForKey: name];
|
||||
if existing_value == nil {
|
||||
let false_value: id = msg_send![class!(NSNumber), numberWithBool:false];
|
||||
let _: () = msg_send![user_defaults, setObject: false_value forKey: name];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn did_finish_launching(this: &mut Object, _: Sel, _: id) {
|
||||
unsafe {
|
||||
let app: id = msg_send![APP_CLASS, sharedApplication];
|
||||
|
||||
@@ -1016,7 +1016,7 @@ fn handle_gpu_device_lost(
|
||||
all_windows: &std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
|
||||
text_system: &std::sync::Weak<DirectWriteTextSystem>,
|
||||
) {
|
||||
// Here we wait a bit to ensure the the system has time to recover from the device lost state.
|
||||
// Here we wait a bit to ensure the system has time to recover from the device lost state.
|
||||
// If we don't wait, the final drawing result will be blank.
|
||||
std::thread::sleep(std::time::Duration::from_millis(350));
|
||||
|
||||
|
||||
@@ -684,8 +684,16 @@ impl PlatformWindow for WindowsWindow {
|
||||
.executor
|
||||
.spawn(async move {
|
||||
this.set_window_placement().log_err();
|
||||
unsafe { SetActiveWindow(hwnd).log_err() };
|
||||
unsafe { SetFocus(Some(hwnd)).log_err() };
|
||||
|
||||
unsafe {
|
||||
// If the window is minimized, restore it.
|
||||
if IsIconic(hwnd).as_bool() {
|
||||
ShowWindowAsync(hwnd, SW_RESTORE).ok().log_err();
|
||||
}
|
||||
|
||||
SetActiveWindow(hwnd).log_err();
|
||||
SetFocus(Some(hwnd)).log_err();
|
||||
}
|
||||
|
||||
// premium ragebait by windows, this is needed because the window
|
||||
// must have received an input event to be able to set itself to foreground
|
||||
|
||||
@@ -318,6 +318,12 @@ pub fn read_proxy_from_env() -> Option<Url> {
|
||||
.and_then(|env| env.parse().ok())
|
||||
}
|
||||
|
||||
pub fn read_no_proxy_from_env() -> Option<String> {
|
||||
const ENV_VARS: &[&str] = &["NO_PROXY", "no_proxy"];
|
||||
|
||||
ENV_VARS.iter().find_map(|var| std::env::var(var).ok())
|
||||
}
|
||||
|
||||
pub struct BlockedHttpClient;
|
||||
|
||||
impl BlockedHttpClient {
|
||||
|
||||
@@ -68,7 +68,7 @@ With both approaches, would need to record the buffer version and use that when
|
||||
|
||||
* Mode to navigate to source code on every element change while picking.
|
||||
|
||||
* Tracking of more source locations - currently the source location is often in a ui compoenent. Ideally this would have a way for the components to indicate that they are probably not the source location the user is looking for.
|
||||
* Tracking of more source locations - currently the source location is often in a ui component. Ideally this would have a way for the components to indicate that they are probably not the source location the user is looking for.
|
||||
|
||||
- Could have `InspectorElementId` be `Vec<(ElementId, Option<Location>)>`, but if there are multiple code paths that construct the same element this would cause them to be considered different.
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ struct BufferBranchState {
|
||||
/// state of a buffer.
|
||||
pub struct BufferSnapshot {
|
||||
pub text: text::BufferSnapshot,
|
||||
pub(crate) syntax: SyntaxSnapshot,
|
||||
pub syntax: SyntaxSnapshot,
|
||||
file: Option<Arc<dyn File>>,
|
||||
diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>,
|
||||
remote_selections: TreeMap<ReplicaId, SelectionSet>,
|
||||
@@ -660,7 +660,10 @@ impl HighlightedTextBuilder {
|
||||
syntax_snapshot: &'a SyntaxSnapshot,
|
||||
) -> BufferChunks<'a> {
|
||||
let captures = syntax_snapshot.captures(range.clone(), snapshot, |grammar| {
|
||||
grammar.highlights_query.as_ref()
|
||||
grammar
|
||||
.highlights_config
|
||||
.as_ref()
|
||||
.map(|config| &config.query)
|
||||
});
|
||||
|
||||
let highlight_maps = captures
|
||||
@@ -3246,7 +3249,10 @@ impl BufferSnapshot {
|
||||
|
||||
fn get_highlights(&self, range: Range<usize>) -> (SyntaxMapCaptures<'_>, Vec<HighlightMap>) {
|
||||
let captures = self.syntax.captures(range, &self.text, |grammar| {
|
||||
grammar.highlights_query.as_ref()
|
||||
grammar
|
||||
.highlights_config
|
||||
.as_ref()
|
||||
.map(|config| &config.query)
|
||||
});
|
||||
let highlight_maps = captures
|
||||
.grammars()
|
||||
@@ -3310,18 +3316,25 @@ impl BufferSnapshot {
|
||||
|
||||
/// Iterates over every [`SyntaxLayer`] in the buffer.
|
||||
pub fn syntax_layers(&self) -> impl Iterator<Item = SyntaxLayer<'_>> + '_ {
|
||||
self.syntax
|
||||
.layers_for_range(0..self.len(), &self.text, true)
|
||||
self.syntax_layers_for_range(0..self.len(), true)
|
||||
}
|
||||
|
||||
pub fn syntax_layer_at<D: ToOffset>(&self, position: D) -> Option<SyntaxLayer<'_>> {
|
||||
let offset = position.to_offset(self);
|
||||
self.syntax
|
||||
.layers_for_range(offset..offset, &self.text, false)
|
||||
self.syntax_layers_for_range(offset..offset, false)
|
||||
.filter(|l| l.node().end_byte() > offset)
|
||||
.last()
|
||||
}
|
||||
|
||||
pub fn syntax_layers_for_range<D: ToOffset>(
|
||||
&self,
|
||||
range: Range<D>,
|
||||
include_hidden: bool,
|
||||
) -> impl Iterator<Item = SyntaxLayer<'_>> + '_ {
|
||||
self.syntax
|
||||
.layers_for_range(range, &self.text, include_hidden)
|
||||
}
|
||||
|
||||
pub fn smallest_syntax_layer_containing<D: ToOffset>(
|
||||
&self,
|
||||
range: Range<D>,
|
||||
@@ -3859,9 +3872,12 @@ impl BufferSnapshot {
|
||||
text: item.text,
|
||||
highlight_ranges: item.highlight_ranges,
|
||||
name_ranges: item.name_ranges,
|
||||
body_range: item.body_range.map(|body_range| {
|
||||
self.anchor_after(body_range.start)..self.anchor_before(body_range.end)
|
||||
}),
|
||||
signature_range: item
|
||||
.signature_range
|
||||
.map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)),
|
||||
body_range: item
|
||||
.body_range
|
||||
.map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)),
|
||||
annotation_range: annotation_row_range.map(|annotation_range| {
|
||||
self.anchor_after(Point::new(annotation_range.start, 0))
|
||||
..self.anchor_before(Point::new(
|
||||
@@ -3901,38 +3917,51 @@ impl BufferSnapshot {
|
||||
|
||||
let mut open_point = None;
|
||||
let mut close_point = None;
|
||||
let mut buffer_ranges = Vec::new();
|
||||
for capture in mat.captures {
|
||||
let node_is_name;
|
||||
if capture.index == config.name_capture_ix {
|
||||
node_is_name = true;
|
||||
} else if Some(capture.index) == config.context_capture_ix
|
||||
|| (Some(capture.index) == config.extra_context_capture_ix && include_extra_context)
|
||||
{
|
||||
node_is_name = false;
|
||||
} else {
|
||||
if Some(capture.index) == config.open_capture_ix {
|
||||
open_point = Some(Point::from_ts_point(capture.node.end_position()));
|
||||
} else if Some(capture.index) == config.close_capture_ix {
|
||||
close_point = Some(Point::from_ts_point(capture.node.start_position()));
|
||||
}
|
||||
|
||||
continue;
|
||||
let mut signature_start = None;
|
||||
let mut signature_end = None;
|
||||
let mut extend_signature_range = |node: tree_sitter::Node| {
|
||||
if signature_start.is_none() {
|
||||
signature_start = Some(Point::from_ts_point(node.start_position()));
|
||||
}
|
||||
signature_end = Some(Point::from_ts_point(node.end_position()));
|
||||
};
|
||||
|
||||
let mut range = capture.node.start_byte()..capture.node.end_byte();
|
||||
let start = capture.node.start_position();
|
||||
if capture.node.end_position().row > start.row {
|
||||
let mut buffer_ranges = Vec::new();
|
||||
let mut add_to_buffer_ranges = |node: tree_sitter::Node, node_is_name| {
|
||||
let mut range = node.start_byte()..node.end_byte();
|
||||
let start = node.start_position();
|
||||
if node.end_position().row > start.row {
|
||||
range.end = range.start + self.line_len(start.row as u32) as usize - start.column;
|
||||
}
|
||||
|
||||
if !range.is_empty() {
|
||||
buffer_ranges.push((range, node_is_name));
|
||||
}
|
||||
};
|
||||
|
||||
for capture in mat.captures {
|
||||
if capture.index == config.name_capture_ix {
|
||||
add_to_buffer_ranges(capture.node, true);
|
||||
extend_signature_range(capture.node);
|
||||
} else if Some(capture.index) == config.context_capture_ix
|
||||
|| (Some(capture.index) == config.extra_context_capture_ix && include_extra_context)
|
||||
{
|
||||
add_to_buffer_ranges(capture.node, false);
|
||||
extend_signature_range(capture.node);
|
||||
} else {
|
||||
if Some(capture.index) == config.open_capture_ix {
|
||||
open_point = Some(Point::from_ts_point(capture.node.end_position()));
|
||||
} else if Some(capture.index) == config.close_capture_ix {
|
||||
close_point = Some(Point::from_ts_point(capture.node.start_position()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if buffer_ranges.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut text = String::new();
|
||||
let mut highlight_ranges = Vec::new();
|
||||
let mut name_ranges = Vec::new();
|
||||
@@ -3941,7 +3970,6 @@ impl BufferSnapshot {
|
||||
true,
|
||||
);
|
||||
let mut last_buffer_range_end = 0;
|
||||
|
||||
for (buffer_range, is_name) in buffer_ranges {
|
||||
let space_added = !text.is_empty() && buffer_range.start > last_buffer_range_end;
|
||||
if space_added {
|
||||
@@ -3983,12 +4011,17 @@ impl BufferSnapshot {
|
||||
last_buffer_range_end = buffer_range.end;
|
||||
}
|
||||
|
||||
let signature_range = signature_start
|
||||
.zip(signature_end)
|
||||
.map(|(start, end)| start..end);
|
||||
|
||||
Some(OutlineItem {
|
||||
depth: 0, // We'll calculate the depth later
|
||||
range: item_point_range,
|
||||
text,
|
||||
highlight_ranges,
|
||||
name_ranges,
|
||||
signature_range,
|
||||
body_range: open_point.zip(close_point).map(|(start, end)| start..end),
|
||||
annotation_range: None,
|
||||
})
|
||||
|
||||
@@ -81,7 +81,9 @@ pub use language_registry::{
|
||||
};
|
||||
pub use lsp::{LanguageServerId, LanguageServerName};
|
||||
pub use outline::*;
|
||||
pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer, ToTreeSitterPoint, TreeSitterOptions};
|
||||
pub use syntax_map::{
|
||||
OwnedSyntaxLayer, SyntaxLayer, SyntaxMapMatches, ToTreeSitterPoint, TreeSitterOptions,
|
||||
};
|
||||
pub use text::{AnchorRangeExt, LineEnding};
|
||||
pub use tree_sitter::{Node, Parser, Tree, TreeCursor};
|
||||
|
||||
@@ -1154,7 +1156,7 @@ pub struct Grammar {
|
||||
id: GrammarId,
|
||||
pub ts_language: tree_sitter::Language,
|
||||
pub(crate) error_query: Option<Query>,
|
||||
pub(crate) highlights_query: Option<Query>,
|
||||
pub highlights_config: Option<HighlightsConfig>,
|
||||
pub(crate) brackets_config: Option<BracketsConfig>,
|
||||
pub(crate) redactions_config: Option<RedactionConfig>,
|
||||
pub(crate) runnable_config: Option<RunnableConfig>,
|
||||
@@ -1168,6 +1170,11 @@ pub struct Grammar {
|
||||
pub(crate) highlight_map: Mutex<HighlightMap>,
|
||||
}
|
||||
|
||||
pub struct HighlightsConfig {
|
||||
pub query: Query,
|
||||
pub identifier_capture_indices: Vec<u32>,
|
||||
}
|
||||
|
||||
struct IndentConfig {
|
||||
query: Query,
|
||||
indent_capture_ix: u32,
|
||||
@@ -1332,7 +1339,7 @@ impl Language {
|
||||
grammar: ts_language.map(|ts_language| {
|
||||
Arc::new(Grammar {
|
||||
id: GrammarId::new(),
|
||||
highlights_query: None,
|
||||
highlights_config: None,
|
||||
brackets_config: None,
|
||||
outline_config: None,
|
||||
text_object_config: None,
|
||||
@@ -1430,7 +1437,29 @@ impl Language {
|
||||
|
||||
pub fn with_highlights_query(mut self, source: &str) -> Result<Self> {
|
||||
let grammar = self.grammar_mut()?;
|
||||
grammar.highlights_query = Some(Query::new(&grammar.ts_language, source)?);
|
||||
let query = Query::new(&grammar.ts_language, source)?;
|
||||
|
||||
let mut identifier_capture_indices = Vec::new();
|
||||
for name in [
|
||||
"variable",
|
||||
"constant",
|
||||
"constructor",
|
||||
"function",
|
||||
"function.method",
|
||||
"function.method.call",
|
||||
"function.special",
|
||||
"property",
|
||||
"type",
|
||||
"type.interface",
|
||||
] {
|
||||
identifier_capture_indices.extend(query.capture_index_for_name(name));
|
||||
}
|
||||
|
||||
grammar.highlights_config = Some(HighlightsConfig {
|
||||
query,
|
||||
identifier_capture_indices,
|
||||
});
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -1856,7 +1885,10 @@ impl Language {
|
||||
let tree = grammar.parse_text(text, None);
|
||||
let captures =
|
||||
SyntaxSnapshot::single_tree_captures(range.clone(), text, &tree, self, |grammar| {
|
||||
grammar.highlights_query.as_ref()
|
||||
grammar
|
||||
.highlights_config
|
||||
.as_ref()
|
||||
.map(|config| &config.query)
|
||||
});
|
||||
let highlight_maps = vec![grammar.highlight_map()];
|
||||
let mut offset = 0;
|
||||
@@ -1885,10 +1917,10 @@ impl Language {
|
||||
|
||||
pub fn set_theme(&self, theme: &SyntaxTheme) {
|
||||
if let Some(grammar) = self.grammar.as_ref()
|
||||
&& let Some(highlights_query) = &grammar.highlights_query
|
||||
&& let Some(highlights_config) = &grammar.highlights_config
|
||||
{
|
||||
*grammar.highlight_map.lock() =
|
||||
HighlightMap::new(highlights_query.capture_names(), theme);
|
||||
HighlightMap::new(highlights_config.query.capture_names(), theme);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2103,8 +2135,9 @@ impl Grammar {
|
||||
|
||||
pub fn highlight_id_for_name(&self, name: &str) -> Option<HighlightId> {
|
||||
let capture_id = self
|
||||
.highlights_query
|
||||
.highlights_config
|
||||
.as_ref()?
|
||||
.query
|
||||
.capture_index_for_name(name)?;
|
||||
Some(self.highlight_map.lock().get(capture_id))
|
||||
}
|
||||
|
||||
@@ -552,6 +552,7 @@ pub struct LanguageSettingsContent {
|
||||
///
|
||||
/// Default: ["..."]
|
||||
#[serde(default)]
|
||||
#[settings_ui(skip)]
|
||||
pub language_servers: Option<Vec<String>>,
|
||||
/// Controls where the `editor::Rewrap` action is allowed for this language.
|
||||
///
|
||||
|
||||
@@ -19,6 +19,7 @@ pub struct OutlineItem<T> {
|
||||
pub text: String,
|
||||
pub highlight_ranges: Vec<(Range<usize>, HighlightStyle)>,
|
||||
pub name_ranges: Vec<Range<usize>>,
|
||||
pub signature_range: Option<Range<T>>,
|
||||
pub body_range: Option<Range<T>>,
|
||||
pub annotation_range: Option<Range<T>>,
|
||||
}
|
||||
@@ -35,6 +36,10 @@ impl<T: ToPoint> OutlineItem<T> {
|
||||
text: self.text.clone(),
|
||||
highlight_ranges: self.highlight_ranges.clone(),
|
||||
name_ranges: self.name_ranges.clone(),
|
||||
signature_range: self
|
||||
.signature_range
|
||||
.as_ref()
|
||||
.map(|r| r.start.to_point(buffer)..r.end.to_point(buffer)),
|
||||
body_range: self
|
||||
.body_range
|
||||
.as_ref()
|
||||
@@ -208,6 +213,7 @@ mod tests {
|
||||
text: "class Foo".to_string(),
|
||||
highlight_ranges: vec![],
|
||||
name_ranges: vec![6..9],
|
||||
signature_range: None,
|
||||
body_range: None,
|
||||
annotation_range: None,
|
||||
},
|
||||
@@ -217,6 +223,7 @@ mod tests {
|
||||
text: "private".to_string(),
|
||||
highlight_ranges: vec![],
|
||||
name_ranges: vec![],
|
||||
signature_range: None,
|
||||
body_range: None,
|
||||
annotation_range: None,
|
||||
},
|
||||
@@ -241,6 +248,7 @@ mod tests {
|
||||
text: "fn process".to_string(),
|
||||
highlight_ranges: vec![],
|
||||
name_ranges: vec![3..10],
|
||||
signature_range: None,
|
||||
body_range: None,
|
||||
annotation_range: None,
|
||||
},
|
||||
@@ -250,6 +258,7 @@ mod tests {
|
||||
text: "struct DataProcessor".to_string(),
|
||||
highlight_ranges: vec![],
|
||||
name_ranges: vec![7..20],
|
||||
signature_range: None,
|
||||
body_range: None,
|
||||
annotation_range: None,
|
||||
},
|
||||
|
||||
@@ -1409,12 +1409,15 @@ fn assert_capture_ranges(
|
||||
) {
|
||||
let mut actual_ranges = Vec::<Range<usize>>::new();
|
||||
let captures = syntax_map.captures(0..buffer.len(), buffer, |grammar| {
|
||||
grammar.highlights_query.as_ref()
|
||||
grammar
|
||||
.highlights_config
|
||||
.as_ref()
|
||||
.map(|config| &config.query)
|
||||
});
|
||||
let queries = captures
|
||||
.grammars()
|
||||
.iter()
|
||||
.map(|grammar| grammar.highlights_query.as_ref().unwrap())
|
||||
.map(|grammar| &grammar.highlights_config.as_ref().unwrap().query)
|
||||
.collect::<Vec<_>>();
|
||||
for capture in captures {
|
||||
let name = &queries[capture.grammar_index].capture_names()[capture.index as usize];
|
||||
|
||||
@@ -29,6 +29,7 @@ copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
deepseek = { workspace = true, features = ["schemars"] }
|
||||
editor.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
google_ai = { workspace = true, features = ["schemars"] }
|
||||
gpui.workspace = true
|
||||
@@ -61,6 +62,7 @@ util.workspace = true
|
||||
vercel = { workspace = true, features = ["schemars"] }
|
||||
workspace-hack.workspace = true
|
||||
x_ai = { workspace = true, features = ["schemars"] }
|
||||
zed_env_vars.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
|
||||
295
crates/language_models/src/api_key.rs
Normal file
295
crates/language_models/src/api_key.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, future};
|
||||
use gpui::{AsyncApp, Context, SharedString, Task};
|
||||
use language_model::AuthenticateError;
|
||||
use std::{
|
||||
fmt::{Display, Formatter},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use zed_env_vars::EnvVar;
|
||||
|
||||
/// Manages a single API key for a language model provider. API keys either come from environment
|
||||
/// variables or the system keychain.
|
||||
///
|
||||
/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
|
||||
/// only used with that URL.
|
||||
pub struct ApiKeyState {
|
||||
url: SharedString,
|
||||
load_status: LoadStatus,
|
||||
load_task: Option<future::Shared<Task<()>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LoadStatus {
|
||||
NotPresent,
|
||||
Error(String),
|
||||
Loaded(ApiKey),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ApiKey {
|
||||
source: ApiKeySource,
|
||||
key: Arc<str>,
|
||||
}
|
||||
|
||||
impl ApiKeyState {
|
||||
pub fn new(url: SharedString) -> Self {
|
||||
Self {
|
||||
url,
|
||||
load_status: LoadStatus::NotPresent,
|
||||
load_task: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_key(&self) -> bool {
|
||||
matches!(self.load_status, LoadStatus::Loaded { .. })
|
||||
}
|
||||
|
||||
pub fn is_from_env_var(&self) -> bool {
|
||||
match &self.load_status {
|
||||
LoadStatus::Loaded(ApiKey {
|
||||
source: ApiKeySource::EnvVar { .. },
|
||||
..
|
||||
}) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
|
||||
/// there is no key or for URL mismatches, and the mismatch case is logged.
|
||||
///
|
||||
/// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
|
||||
/// called with this URL.
|
||||
pub fn key(&self, url: &str) -> Option<Arc<str>> {
|
||||
let api_key = match &self.load_status {
|
||||
LoadStatus::Loaded(api_key) => api_key,
|
||||
_ => return None,
|
||||
};
|
||||
if url == self.url.as_str() {
|
||||
Some(api_key.key.clone())
|
||||
} else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
|
||||
log::warn!(
|
||||
"{} is now being used with URL {}, when initially it was used with URL {}",
|
||||
var_name,
|
||||
url,
|
||||
self.url
|
||||
);
|
||||
Some(api_key.key.clone())
|
||||
} else {
|
||||
// bug case because load_if_needed should be called whenever the url may have changed
|
||||
log::error!(
|
||||
"bug: Attempted to use API key associated with URL {} instead with URL {}",
|
||||
self.url,
|
||||
url
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Set or delete the API key in the system keychain.
|
||||
pub fn store<Ent: 'static>(
|
||||
&mut self,
|
||||
url: SharedString,
|
||||
key: Option<String>,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
|
||||
cx: &Context<Ent>,
|
||||
) -> Task<Result<()>> {
|
||||
if self.is_from_env_var() {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"bug: attempted to store API key in system keychain when API key is from env var",
|
||||
)));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn(async move |ent, cx| {
|
||||
if let Some(key) = &key {
|
||||
credentials_provider
|
||||
.write_credentials(&url, "Bearer", key.as_bytes(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
} else {
|
||||
credentials_provider
|
||||
.delete_credentials(&url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
ent.update(cx, |ent, cx| {
|
||||
let this = get_this(ent);
|
||||
this.url = url;
|
||||
this.load_status = match &key {
|
||||
Some(key) => LoadStatus::Loaded(ApiKey {
|
||||
source: ApiKeySource::SystemKeychain,
|
||||
key: key.as_str().into(),
|
||||
}),
|
||||
None => LoadStatus::NotPresent,
|
||||
};
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Reloads the API key if the current API key is associated with a different URL.
|
||||
///
|
||||
/// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
|
||||
/// interchangeably - URL change should correspond to some user initiated change.
|
||||
pub fn handle_url_change<Ent: 'static>(
|
||||
&mut self,
|
||||
url: SharedString,
|
||||
env_var: &EnvVar,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
|
||||
cx: &mut Context<Ent>,
|
||||
) {
|
||||
if url != self.url {
|
||||
if !self.is_from_env_var() {
|
||||
// loading will continue even though this result task is dropped
|
||||
let _task = self.load_if_needed(url, env_var, get_this, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// If needed, loads the API key associated with the given URL from the system keychain. When a
|
||||
/// non-empty environment variable is provided, it will be used instead. If called when an API
|
||||
/// key was already loaded for a different URL, that key will be cleared before loading.
|
||||
///
|
||||
/// Dropping the returned Task does not cancel key loading.
|
||||
pub fn load_if_needed<Ent: 'static>(
|
||||
&mut self,
|
||||
url: SharedString,
|
||||
env_var: &EnvVar,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
|
||||
cx: &mut Context<Ent>,
|
||||
) -> Task<Result<(), AuthenticateError>> {
|
||||
if let LoadStatus::Loaded { .. } = &self.load_status
|
||||
&& self.url == url
|
||||
{
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
if let Some(key) = &env_var.value
|
||||
&& !key.is_empty()
|
||||
{
|
||||
let api_key = ApiKey::from_env(env_var.name.clone(), key);
|
||||
self.url = url;
|
||||
self.load_status = LoadStatus::Loaded(api_key);
|
||||
self.load_task = None;
|
||||
cx.notify();
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let task = if let Some(load_task) = &self.load_task {
|
||||
load_task.clone()
|
||||
} else {
|
||||
let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
|
||||
self.url = url;
|
||||
self.load_status = LoadStatus::NotPresent;
|
||||
self.load_task = Some(load_task.clone());
|
||||
cx.notify();
|
||||
load_task
|
||||
};
|
||||
|
||||
cx.spawn(async move |ent, cx| {
|
||||
task.await;
|
||||
ent.update(cx, |ent, _cx| {
|
||||
get_this(ent).load_status.clone().into_authenticate_result()
|
||||
})
|
||||
.ok();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn load<Ent: 'static>(
|
||||
url: SharedString,
|
||||
get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
|
||||
cx: &Context<Ent>,
|
||||
) -> Task<()> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
cx.spawn({
|
||||
async move |ent, cx| {
|
||||
let load_status =
|
||||
ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
|
||||
.await;
|
||||
ent.update(cx, |ent, cx| {
|
||||
let this = get_this(ent);
|
||||
this.url = url;
|
||||
this.load_status = load_status;
|
||||
this.load_task = None;
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ApiKey {
|
||||
pub fn key(&self) -> &str {
|
||||
&self.key
|
||||
}
|
||||
|
||||
pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
|
||||
Self {
|
||||
source: ApiKeySource::EnvVar(env_var_name),
|
||||
key: key.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn load_from_system_keychain(
|
||||
url: &str,
|
||||
credentials_provider: &dyn CredentialsProvider,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self, AuthenticateError> {
|
||||
Self::load_from_system_keychain_impl(url, credentials_provider, cx)
|
||||
.await
|
||||
.into_authenticate_result()
|
||||
}
|
||||
|
||||
async fn load_from_system_keychain_impl(
|
||||
url: &str,
|
||||
credentials_provider: &dyn CredentialsProvider,
|
||||
cx: &AsyncApp,
|
||||
) -> LoadStatus {
|
||||
if url.is_empty() {
|
||||
return LoadStatus::NotPresent;
|
||||
}
|
||||
let read_result = credentials_provider.read_credentials(&url, cx).await;
|
||||
let api_key = match read_result {
|
||||
Ok(Some((_, api_key))) => api_key,
|
||||
Ok(None) => return LoadStatus::NotPresent,
|
||||
Err(err) => return LoadStatus::Error(err.to_string()),
|
||||
};
|
||||
let key = match str::from_utf8(&api_key) {
|
||||
Ok(key) => key,
|
||||
Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
|
||||
};
|
||||
LoadStatus::Loaded(Self {
|
||||
source: ApiKeySource::SystemKeychain,
|
||||
key: key.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LoadStatus {
|
||||
fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
|
||||
match self {
|
||||
LoadStatus::Loaded(api_key) => Ok(api_key),
|
||||
LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
|
||||
LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum ApiKeySource {
|
||||
EnvVar(SharedString),
|
||||
SystemKeychain,
|
||||
}
|
||||
|
||||
impl Display for ApiKeySource {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
|
||||
ApiKeySource::SystemKeychain => write!(f, "system keychain"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ use gpui::{App, Context, Entity};
|
||||
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
|
||||
use provider::deepseek::DeepSeekLanguageModelProvider;
|
||||
|
||||
mod api_key;
|
||||
pub mod provider;
|
||||
mod settings;
|
||||
pub mod ui;
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::api_key::ApiKeyState;
|
||||
use crate::ui::InstructionListItem;
|
||||
use anthropic::{
|
||||
AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
|
||||
ToolResultPart, Usage,
|
||||
ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent,
|
||||
ToolResultContent, ToolResultPart, Usage,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::Stream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||
};
|
||||
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, FontStyle, Task, TextStyle, WhiteSpace};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
|
||||
@@ -27,11 +23,12 @@ use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
|
||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
|
||||
@@ -97,91 +94,52 @@ pub struct AnthropicLanguageModelProvider {
|
||||
state: gpui::Entity<State>,
|
||||
}
|
||||
|
||||
const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
|
||||
const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
api_key_state: ApiKeyState,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.ok();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let key = AnthropicLanguageModelProvider::api_key(cx);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let key = key.await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(key.key);
|
||||
this.api_key_from_env = key.from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = AnthropicLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ApiKey {
|
||||
pub key: String,
|
||||
pub from_env: bool,
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = AnthropicLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
@@ -197,30 +155,16 @@ impl AnthropicLanguageModelProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey, AuthenticateError>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.api_url
|
||||
.clone();
|
||||
fn settings(cx: &App) -> &AnthropicSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).anthropic
|
||||
}
|
||||
|
||||
if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
|
||||
Task::ready(Ok(ApiKey {
|
||||
key,
|
||||
from_env: true,
|
||||
}))
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
ANTHROPIC_API_URL.into()
|
||||
} else {
|
||||
cx.spawn(async move |cx| {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
|
||||
Ok(ApiKey {
|
||||
key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
from_env: false,
|
||||
})
|
||||
})
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -275,11 +219,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.available_models
|
||||
.iter()
|
||||
{
|
||||
for model in &AnthropicLanguageModelProvider::settings(cx).available_models {
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
anthropic::Model::Custom {
|
||||
@@ -327,7 +267,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -417,11 +358,11 @@ impl AnthropicModel {
|
||||
> {
|
||||
let http_client = self.http_client.clone();
|
||||
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = AnthropicLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped").into())).boxed();
|
||||
};
|
||||
|
||||
let beta_headers = self.model.beta_headers();
|
||||
@@ -483,7 +424,10 @@ impl LanguageModel for AnthropicModel {
|
||||
}
|
||||
|
||||
fn api_key(&self, cx: &App) -> Option<String> {
|
||||
self.state.read(cx).api_key.clone()
|
||||
self.state.read_with(cx, |state, cx| {
|
||||
let api_url = AnthropicLanguageModelProvider::api_url(cx);
|
||||
state.api_key_state.key(&api_url).map(|key| key.to_string())
|
||||
})
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> u64 {
|
||||
@@ -984,15 +928,17 @@ impl ConfigurationView {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -1001,11 +947,11 @@ impl ConfigurationView {
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
@@ -1040,7 +986,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials...")).into_any()
|
||||
@@ -1079,7 +1025,7 @@ impl Render for ConfigurationView {
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
|
||||
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
|
||||
)
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
@@ -1099,9 +1045,14 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
let api_url = AnthropicLanguageModelProvider::api_url(cx);
|
||||
if api_url == ANTHROPIC_API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
@@ -1112,7 +1063,7 @@ impl Render for ConfigurationView {
|
||||
.icon_position(IconPosition::Start)
|
||||
.disabled(env_var_set)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
|
||||
@@ -215,11 +215,21 @@ impl State {
|
||||
|
||||
self.default_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_model)
|
||||
.find(|model| {
|
||||
response
|
||||
.default_model
|
||||
.as_ref()
|
||||
.is_some_and(|default_model_id| &model.id == default_model_id)
|
||||
})
|
||||
.cloned();
|
||||
self.default_fast_model = models
|
||||
.iter()
|
||||
.find(|model| model.id == response.default_fast_model)
|
||||
.find(|model| {
|
||||
response
|
||||
.default_fast_model
|
||||
.as_ref()
|
||||
.is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
|
||||
})
|
||||
.cloned();
|
||||
self.recommended_models = response
|
||||
.recommended_models
|
||||
@@ -541,29 +551,36 @@ where
|
||||
|
||||
impl From<ApiError> for LanguageModelCompletionError {
|
||||
fn from(error: ApiError) -> Self {
|
||||
if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body)
|
||||
&& cloud_error.code.starts_with("upstream_http_")
|
||||
{
|
||||
let status = if let Some(status) = cloud_error.upstream_status {
|
||||
status
|
||||
} else if cloud_error.code.ends_with("_error") {
|
||||
error.status
|
||||
} else {
|
||||
// If there's a status code in the code string (e.g. "upstream_http_429")
|
||||
// then use that; otherwise, see if the JSON contains a status code.
|
||||
cloud_error
|
||||
.code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code_str| code_str.parse::<u16>().ok())
|
||||
.and_then(|code| StatusCode::from_u16(code).ok())
|
||||
.unwrap_or(error.status)
|
||||
};
|
||||
if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
|
||||
if cloud_error.code.starts_with("upstream_http_") {
|
||||
let status = if let Some(status) = cloud_error.upstream_status {
|
||||
status
|
||||
} else if cloud_error.code.ends_with("_error") {
|
||||
error.status
|
||||
} else {
|
||||
// If there's a status code in the code string (e.g. "upstream_http_429")
|
||||
// then use that; otherwise, see if the JSON contains a status code.
|
||||
cloud_error
|
||||
.code
|
||||
.strip_prefix("upstream_http_")
|
||||
.and_then(|code_str| code_str.parse::<u16>().ok())
|
||||
.and_then(|code| StatusCode::from_u16(code).ok())
|
||||
.unwrap_or(error.status)
|
||||
};
|
||||
|
||||
return LanguageModelCompletionError::UpstreamProviderError {
|
||||
message: cloud_error.message,
|
||||
status,
|
||||
retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
|
||||
};
|
||||
return LanguageModelCompletionError::UpstreamProviderError {
|
||||
message: cloud_error.message,
|
||||
status,
|
||||
retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
|
||||
};
|
||||
}
|
||||
|
||||
return LanguageModelCompletionError::from_http_status(
|
||||
PROVIDER_NAME,
|
||||
error.status,
|
||||
cloud_error.message,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
let retry_after = None;
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use deepseek::DEEPSEEK_API_URL;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::Stream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{
|
||||
AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
|
||||
WhiteSpace,
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
|
||||
Window,
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
@@ -21,16 +21,19 @@ use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, prelude::*};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
|
||||
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
|
||||
|
||||
const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
@@ -59,95 +62,48 @@ pub struct DeepSeekLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
api_key_state: ApiKeyState,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.deepseek
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.deepseek
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.deepseek
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl DeepSeekLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
@@ -160,7 +116,20 @@ impl DeepSeekLanguageModelProvider {
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &DeepSeekSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).deepseek
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
DEEPSEEK_API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,11 +168,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
|
||||
models.insert("deepseek-chat", deepseek::Model::Chat);
|
||||
models.insert("deepseek-reasoner", deepseek::Model::Reasoner);
|
||||
|
||||
for available_model in AllLanguageModelSettings::get_global(cx)
|
||||
.deepseek
|
||||
.available_models
|
||||
.iter()
|
||||
{
|
||||
for available_model in &Self::settings(cx).available_models {
|
||||
models.insert(
|
||||
&available_model.name,
|
||||
deepseek::Model::Custom {
|
||||
@@ -240,7 +205,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,15 +225,20 @@ impl DeepSeekLanguageModel {
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let api_key = api_key.context("Missing DeepSeek API Key")?;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
};
|
||||
let request =
|
||||
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
@@ -610,7 +581,7 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx);
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
@@ -618,12 +589,10 @@ impl ConfigurationView {
|
||||
let state = self.state.clone();
|
||||
cx.spawn(async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -631,10 +600,12 @@ impl ConfigurationView {
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn(async move |_, cx| state.update(cx, |state, cx| state.reset_api_key(cx))?.await)
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
cx.spawn(async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
@@ -672,7 +643,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials...")).into_any()
|
||||
@@ -706,8 +677,7 @@ impl Render for ConfigurationView {
|
||||
)
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"Or set the {} environment variable.",
|
||||
DEEPSEEK_API_KEY_VAR
|
||||
"Or set the {API_KEY_ENV_VAR_NAME} environment variable."
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
@@ -727,9 +697,17 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {}", DEEPSEEK_API_KEY_VAR)
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
|
||||
} else {
|
||||
"API key configured".to_string()
|
||||
let api_url = DeepSeekLanguageModelProvider::api_url(cx);
|
||||
if api_url == DEEPSEEK_API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"API key configured for {}",
|
||||
truncate_and_trailoff(&api_url, 32)
|
||||
)
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
|
||||
@@ -2,13 +2,14 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
||||
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
|
||||
use google_ai::{
|
||||
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
|
||||
ThinkingConfig, UsageMetadata,
|
||||
};
|
||||
use gpui::{
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
|
||||
Window,
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
@@ -26,19 +27,19 @@ use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::sync::{
|
||||
Arc,
|
||||
Arc, LazyLock,
|
||||
atomic::{self, AtomicU64},
|
||||
};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use zed_env_vars::EnvVar;
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::api_key::ApiKey;
|
||||
use crate::api_key::ApiKeyState;
|
||||
use crate::ui::InstructionListItem;
|
||||
|
||||
use super::anthropic::ApiKey;
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
|
||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
|
||||
|
||||
@@ -91,101 +92,56 @@ pub struct GoogleLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
api_key_state: ApiKeyState,
|
||||
}
|
||||
|
||||
const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY";
|
||||
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";
|
||||
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
|
||||
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
|
||||
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
|
||||
// Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
|
||||
EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
|
||||
});
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.google
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = GoogleLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.google
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.google
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = GoogleLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl GoogleLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
@@ -201,30 +157,32 @@ impl GoogleLanguageModelProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
|
||||
pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
|
||||
if let Some(key) = API_KEY_ENV_VAR.value.clone() {
|
||||
return Task::ready(Ok(key));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.google
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) {
|
||||
Task::ready(Ok(ApiKey {
|
||||
key,
|
||||
from_env: true,
|
||||
}))
|
||||
} else {
|
||||
cx.spawn(async move |cx| {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
let api_url = Self::api_url(cx).to_string();
|
||||
cx.spawn(async move |cx| {
|
||||
Ok(
|
||||
ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
.key()
|
||||
.to_string(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
Ok(ApiKey {
|
||||
key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
from_env: false,
|
||||
})
|
||||
})
|
||||
fn settings(cx: &App) -> &GoogleSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).google
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
google_ai::API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -269,10 +227,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.google
|
||||
.available_models
|
||||
{
|
||||
for model in &GoogleLanguageModelProvider::settings(cx).available_models {
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
google_ai::Model::Custom {
|
||||
@@ -317,7 +272,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,11 +296,11 @@ impl GoogleLanguageModel {
|
||||
> {
|
||||
let http_client = self.http_client.clone();
|
||||
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).google;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = GoogleLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
@@ -418,13 +374,16 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
let model_id = self.model.request_id().to_string();
|
||||
let request = into_google(request, model_id, self.model.mode());
|
||||
let http_client = self.http_client.clone();
|
||||
let api_key = self.state.read(cx).api_key.clone();
|
||||
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).google;
|
||||
let api_url = settings.api_url.clone();
|
||||
let api_url = GoogleLanguageModelProvider::api_url(cx);
|
||||
let api_key = self.state.read(cx).api_key_state.key(&api_url);
|
||||
|
||||
async move {
|
||||
let api_key = api_key.context("Missing Google API key")?;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
}
|
||||
.into());
|
||||
};
|
||||
let response = google_ai::count_tokens(
|
||||
http_client.as_ref(),
|
||||
&api_url,
|
||||
@@ -852,20 +811,22 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx);
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -874,11 +835,11 @@ impl ConfigurationView {
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
@@ -913,7 +874,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials...")).into_any()
|
||||
@@ -950,7 +911,7 @@ impl Render for ConfigurationView {
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."),
|
||||
format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
|
||||
)
|
||||
.size(LabelSize::Small).color(Color::Muted),
|
||||
)
|
||||
@@ -969,9 +930,14 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {GEMINI_API_KEY_VAR} environment variable.")
|
||||
format!("API key set in {} environment variable", API_KEY_ENV_VAR.name)
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
let api_url = GoogleLanguageModelProvider::api_url(cx);
|
||||
if api_url == google_ai::API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
@@ -982,7 +948,7 @@ impl Render for ConfigurationView {
|
||||
.icon_position(IconPosition::Start)
|
||||
.disabled(env_var_set)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset.")))
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
|
||||
Window,
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
@@ -14,24 +14,28 @@ use language_model::{
|
||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use mistral::StreamResponse;
|
||||
use mistral::{MISTRAL_API_URL, StreamResponse};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
|
||||
|
||||
const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct MistralSettings {
|
||||
pub api_url: String,
|
||||
@@ -56,96 +60,48 @@ pub struct MistralLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
api_key_state: ApiKeyState,
|
||||
}
|
||||
|
||||
const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY";
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.mistral
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = MistralLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.mistral
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.mistral
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = MistralLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl MistralLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
@@ -160,6 +116,19 @@ impl MistralLanguageModelProvider {
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &MistralSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).mistral
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
mistral::MISTRAL_API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for MistralLanguageModelProvider {
|
||||
@@ -202,10 +171,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.mistral
|
||||
.available_models
|
||||
{
|
||||
for model in &Self::settings(cx).available_models {
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
mistral::Model::Custom {
|
||||
@@ -254,7 +220,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,15 +243,20 @@ impl MistralLanguageModel {
|
||||
Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
|
||||
> {
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).mistral;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = MistralLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let api_key = api_key.context("Missing Mistral API Key")?;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
};
|
||||
let request =
|
||||
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
@@ -780,20 +752,22 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx);
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -802,11 +776,11 @@ impl ConfigurationView {
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
@@ -841,7 +815,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials...")).into_any()
|
||||
@@ -878,7 +852,7 @@ impl Render for ConfigurationView {
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."),
|
||||
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
|
||||
)
|
||||
.size(LabelSize::Small).color(Color::Muted),
|
||||
)
|
||||
@@ -897,9 +871,14 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.")
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
let api_url = MistralLanguageModelProvider::api_url(cx);
|
||||
if api_url == MISTRAL_API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
@@ -910,7 +889,7 @@ impl Render for ConfigurationView {
|
||||
.icon_position(IconPosition::Start)
|
||||
.disabled(env_var_set)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable.")))
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use fs::Fs;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{Stream, TryFutureExt, stream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
@@ -10,20 +11,25 @@ use language_model::{
|
||||
LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
|
||||
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use menu;
|
||||
use ollama::{
|
||||
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionCall,
|
||||
OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion,
|
||||
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OLLAMA_API_URL,
|
||||
OllamaFunctionCall, OllamaFunctionTool, OllamaToolCall, get_models, show_model,
|
||||
stream_chat_completion,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use settings::{Settings, SettingsStore, update_settings_file};
|
||||
use std::pin::Pin;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use ui::{ButtonLike, Indicator, List, prelude::*};
|
||||
use util::ResultExt;
|
||||
use ui::{ButtonLike, ElevationIndex, List, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::api_key::ApiKeyState;
|
||||
use crate::ui::InstructionListItem;
|
||||
|
||||
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
|
||||
@@ -33,6 +39,9 @@ const OLLAMA_SITE: &str = "https://ollama.com/";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
|
||||
|
||||
const API_KEY_ENV_VAR_NAME: &str = "OLLAMA_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct OllamaSettings {
|
||||
pub api_url: String,
|
||||
@@ -63,25 +72,61 @@ pub struct OllamaLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key_state: ApiKeyState,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<ollama::Model>,
|
||||
fetched_models: Vec<ollama::Model>,
|
||||
fetch_model_task: Option<Task<Result<()>>>,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
!self.available_models.is_empty()
|
||||
!self.fetched_models.is_empty()
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
let task = self
|
||||
.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx);
|
||||
|
||||
self.fetched_models.clear();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
|
||||
.ok();
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
let task = self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
|
||||
// Always try to fetch models - if no API key is needed (local Ollama), it will work
|
||||
// If API key is needed and provided, it will work
|
||||
// If API key is needed and not provided, it will fail gracefully
|
||||
cx.spawn(async move |this, cx| {
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
|
||||
.ok();
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let http_client = Arc::clone(&self.http_client);
|
||||
let api_url = settings.api_url.clone();
|
||||
let api_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
let api_key = self.api_key_state.key(&api_url);
|
||||
|
||||
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
|
||||
cx.spawn(async move |this, cx| {
|
||||
let models = get_models(http_client.as_ref(), &api_url, None).await?;
|
||||
let models =
|
||||
get_models(http_client.as_ref(), &api_url, api_key.as_deref(), None).await?;
|
||||
|
||||
let tasks = models
|
||||
.into_iter()
|
||||
@@ -92,9 +137,12 @@ impl State {
|
||||
.map(|model| {
|
||||
let http_client = Arc::clone(&http_client);
|
||||
let api_url = api_url.clone();
|
||||
let api_key = api_key.clone();
|
||||
async move {
|
||||
let name = model.name.as_str();
|
||||
let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
|
||||
let capabilities =
|
||||
show_model(http_client.as_ref(), &api_url, api_key.as_deref(), name)
|
||||
.await?;
|
||||
let ollama_model = ollama::Model::new(
|
||||
name,
|
||||
None,
|
||||
@@ -119,7 +167,7 @@ impl State {
|
||||
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.available_models = ollama_models;
|
||||
this.fetched_models = ollama_models;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
@@ -129,15 +177,6 @@ impl State {
|
||||
let task = self.fetch_models(cx);
|
||||
self.fetch_model_task.replace(task);
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let fetch_models_task = self.fetch_models(cx);
|
||||
cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
|
||||
}
|
||||
}
|
||||
|
||||
impl OllamaLanguageModelProvider {
|
||||
@@ -145,30 +184,47 @@ impl OllamaLanguageModelProvider {
|
||||
let this = Self {
|
||||
http_client: http_client.clone(),
|
||||
state: cx.new(|cx| {
|
||||
let subscription = cx.observe_global::<SettingsStore>({
|
||||
let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
|
||||
cx.observe_global::<SettingsStore>({
|
||||
let mut last_settings = OllamaLanguageModelProvider::settings(cx).clone();
|
||||
move |this: &mut State, cx| {
|
||||
let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
if &settings != new_settings {
|
||||
settings = new_settings.clone();
|
||||
this.restart_fetch_models_task(cx);
|
||||
let current_settings = OllamaLanguageModelProvider::settings(cx);
|
||||
let settings_changed = current_settings != &last_settings;
|
||||
if settings_changed {
|
||||
let url_changed = last_settings.api_url != current_settings.api_url;
|
||||
last_settings = current_settings.clone();
|
||||
if url_changed {
|
||||
this.fetched_models.clear();
|
||||
this.authenticate(cx).detach();
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
|
||||
State {
|
||||
http_client,
|
||||
available_models: Default::default(),
|
||||
fetched_models: Default::default(),
|
||||
fetch_model_task: None,
|
||||
_subscription: subscription,
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
}),
|
||||
};
|
||||
this.state
|
||||
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
|
||||
this
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &OllamaSettings {
|
||||
&AllLanguageModelSettings::get_global(cx).ollama
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
OLLAMA_API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for OllamaLanguageModelProvider {
|
||||
@@ -208,16 +264,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||
let mut models: HashMap<String, ollama::Model> = HashMap::new();
|
||||
|
||||
// Add models from the Ollama API
|
||||
for model in self.state.read(cx).available_models.iter() {
|
||||
for model in self.state.read(cx).fetched_models.iter() {
|
||||
models.insert(model.name.clone(), model.clone());
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in AllLanguageModelSettings::get_global(cx)
|
||||
.ollama
|
||||
.available_models
|
||||
.iter()
|
||||
{
|
||||
for model in &OllamaLanguageModelProvider::settings(cx).available_models {
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
ollama::Model {
|
||||
@@ -240,6 +292,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||
model,
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
state: self.state.clone(),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
@@ -267,7 +320,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.fetch_models(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,6 +330,7 @@ pub struct OllamaLanguageModel {
|
||||
model: ollama::Model,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
request_limiter: RateLimiter,
|
||||
state: gpui::Entity<State>,
|
||||
}
|
||||
|
||||
impl OllamaLanguageModel {
|
||||
@@ -454,15 +509,17 @@ impl LanguageModel for OllamaLanguageModel {
|
||||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok(api_url) = cx.update(|cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
settings.api_url.clone()
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
|
||||
let stream =
|
||||
stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request)
|
||||
.await?;
|
||||
let stream = map_to_language_model_completion_events(stream);
|
||||
Ok(stream)
|
||||
});
|
||||
@@ -574,39 +631,221 @@ fn map_to_language_model_completion_events(
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: gpui::Entity<SingleLineInput>,
|
||||
api_url_editor: gpui::Entity<SingleLineInput>,
|
||||
state: gpui::Entity<State>,
|
||||
loading_models_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let loading_models_task = Some(cx.spawn_in(window, {
|
||||
let state = state.clone();
|
||||
async move |this, cx| {
|
||||
if let Some(task) = state
|
||||
.update(cx, |state, cx| state.authenticate(cx))
|
||||
.log_err()
|
||||
{
|
||||
task.await.log_err();
|
||||
}
|
||||
this.update(cx, |this, cx| {
|
||||
this.loading_models_task = None;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}));
|
||||
let api_key_editor =
|
||||
cx.new(|cx| SingleLineInput::new(window, cx, "63e02e...").label("API key"));
|
||||
|
||||
let api_url_editor = cx.new(|cx| {
|
||||
let input = SingleLineInput::new(window, cx, OLLAMA_API_URL).label("API URL");
|
||||
input.set_text(OllamaLanguageModelProvider::api_url(cx), window, cx);
|
||||
input
|
||||
});
|
||||
|
||||
cx.observe(&state, |_, _, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
Self {
|
||||
api_key_editor,
|
||||
api_url_editor,
|
||||
state,
|
||||
loading_models_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_connection(&self, cx: &mut App) {
|
||||
self.state
|
||||
.update(cx, |state, cx| state.fetch_models(cx))
|
||||
.detach_and_log_err(cx);
|
||||
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn save_api_url(&mut self, cx: &mut Context<Self>) {
|
||||
let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string();
|
||||
let current_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
if !api_url.is_empty() && &api_url != ¤t_url {
|
||||
let fs = <dyn Fs>::global(cx);
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
|
||||
if let Some(settings) = settings.ollama.as_mut() {
|
||||
settings.api_url = Some(api_url);
|
||||
} else {
|
||||
settings.ollama = Some(crate::settings::OllamaSettingsContent {
|
||||
api_url: Some(api_url),
|
||||
available_models: None,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_url_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
let fs = <dyn Fs>::global(cx);
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
|
||||
if let Some(settings) = settings.ollama.as_mut() {
|
||||
settings.api_url = Some(OLLAMA_API_URL.into());
|
||||
}
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_instructions() -> Div {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Label::new(
|
||||
"Run LLMs locally on your machine with Ollama, or connect to an Ollama server. \
|
||||
Can provide access to Llama, Mistral, Gemma, and hundreds of other models.",
|
||||
))
|
||||
.child(Label::new("To use local Ollama:"))
|
||||
.child(
|
||||
List::new()
|
||||
.child(InstructionListItem::new(
|
||||
"Download and install Ollama from",
|
||||
Some("ollama.com"),
|
||||
Some("https://ollama.com/download"),
|
||||
))
|
||||
.child(InstructionListItem::text_only(
|
||||
"Start Ollama and download a model: `ollama run gpt-oss:20b`",
|
||||
))
|
||||
.child(InstructionListItem::text_only(
|
||||
"Click 'Connect' below to start using Ollama in Zed",
|
||||
)),
|
||||
)
|
||||
.child(Label::new(
|
||||
"Alternatively, you can connect to an Ollama server by specifying its \
|
||||
URL and API key (may not be required):",
|
||||
))
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &Context<Self>) -> Div {
|
||||
let state = self.state.read(cx);
|
||||
let env_var_set = state.api_key_state.is_from_env_var();
|
||||
|
||||
if !state.api_key_state.has_key() {
|
||||
v_flex()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(self.api_key_editor.clone())
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.")
|
||||
)
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
} else {
|
||||
h_flex()
|
||||
.p_3()
|
||||
.justify_between()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().elevated_surface_background)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(
|
||||
Label::new(
|
||||
if env_var_set {
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
|
||||
} else {
|
||||
"API key configured".to_string()
|
||||
}
|
||||
)
|
||||
)
|
||||
)
|
||||
.child(
|
||||
Button::new("reset-api-key", "Reset API Key")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Undo)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn render_api_url_editor(&self, cx: &Context<Self>) -> Div {
|
||||
let api_url = OllamaLanguageModelProvider::api_url(cx);
|
||||
let custom_api_url_set = api_url != OLLAMA_API_URL;
|
||||
|
||||
if custom_api_url_set {
|
||||
h_flex()
|
||||
.p_3()
|
||||
.justify_between()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().elevated_surface_background)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(v_flex().gap_1().child(Label::new(api_url))),
|
||||
)
|
||||
.child(
|
||||
Button::new("reset-api-url", "Reset API URL")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Undo)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.on_click(
|
||||
cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
v_flex()
|
||||
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
|
||||
this.save_api_url(cx);
|
||||
cx.notify();
|
||||
}))
|
||||
.gap_2()
|
||||
.child(self.api_url_editor.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -614,98 +853,83 @@ impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let is_authenticated = self.state.read(cx).is_authenticated();
|
||||
|
||||
let ollama_intro =
|
||||
"Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";
|
||||
|
||||
if self.loading_models_task.is_some() {
|
||||
div().child(Label::new("Loading models...")).into_any()
|
||||
} else {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
v_flex().gap_1().child(Label::new(ollama_intro)).child(
|
||||
List::new()
|
||||
.child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant."))
|
||||
.child(InstructionListItem::text_only(
|
||||
"Once installed, try `ollama run llama3.2`",
|
||||
)),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.map(|this| {
|
||||
if is_authenticated {
|
||||
this.child(
|
||||
Button::new("ollama-site", "Ollama")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
|
||||
.into_any_element(),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
Button::new(
|
||||
"download_ollama_button",
|
||||
"Download Ollama",
|
||||
)
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Self::render_instructions())
|
||||
.child(self.render_api_url_editor(cx))
|
||||
.child(self.render_api_key_editor(cx))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.map(|this| {
|
||||
if is_authenticated {
|
||||
this.child(
|
||||
Button::new("ollama-site", "Ollama")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
|
||||
.into_any_element(),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
Button::new("download_ollama_button", "Download Ollama")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(move |_, _, cx| {
|
||||
cx.open_url(OLLAMA_DOWNLOAD_URL)
|
||||
})
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
})
|
||||
.child(
|
||||
Button::new("view-models", "View All Models")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
|
||||
),
|
||||
)
|
||||
.map(|this| {
|
||||
if is_authenticated {
|
||||
this.child(
|
||||
ButtonLike::new("connected")
|
||||
.disabled(true)
|
||||
.cursor_style(gpui::CursorStyle::Arrow)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Indicator::dot().color(Color::Success))
|
||||
.child(Label::new("Connected"))
|
||||
.into_any_element(),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
Button::new("retry_ollama_models", "Connect")
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon(IconName::PlayFilled)
|
||||
.on_click(cx.listener(move |this, _, _, cx| {
|
||||
)
|
||||
}
|
||||
})
|
||||
.child(
|
||||
Button::new("view-models", "View All Models")
|
||||
.style(ButtonStyle::Subtle)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
|
||||
),
|
||||
)
|
||||
.map(|this| {
|
||||
if is_authenticated {
|
||||
this.child(
|
||||
ButtonLike::new("connected")
|
||||
.disabled(true)
|
||||
.cursor_style(gpui::CursorStyle::Arrow)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new("Connected"))
|
||||
.into_any_element(),
|
||||
),
|
||||
)
|
||||
} else {
|
||||
this.child(
|
||||
Button::new("retry_ollama_models", "Connect")
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon(IconName::PlayOutlined)
|
||||
.on_click(
|
||||
cx.listener(move |this, _, _, cx| {
|
||||
this.retry_connection(cx)
|
||||
})),
|
||||
)
|
||||
}
|
||||
})
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
|
||||
use futures::Stream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
@@ -14,24 +12,29 @@ use language_model::{
|
||||
RateLimiter, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
|
||||
use open_ai::{
|
||||
ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
|
||||
use ui::{ElevationIndex, List, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
|
||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
|
||||
|
||||
const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct OpenAiSettings {
|
||||
pub api_url: String,
|
||||
@@ -54,132 +57,48 @@ pub struct OpenAiLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
last_api_url: String,
|
||||
_subscription: Subscription,
|
||||
api_key_state: ApiKeyState,
|
||||
}
|
||||
|
||||
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = OpenAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
self.get_api_key(cx)
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = OpenAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAiLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let initial_api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
last_api_url: initial_api_url.clone(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let current_api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.api_url
|
||||
.clone();
|
||||
|
||||
if this.last_api_url != current_api_url {
|
||||
this.last_api_url = current_api_url;
|
||||
if !this.api_key_from_env {
|
||||
this.api_key = None;
|
||||
let spawn_task = cx.spawn(async move |handle, cx| {
|
||||
if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
|
||||
if let Err(_) = task.await {
|
||||
handle
|
||||
.update(cx, |this, _| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
});
|
||||
spawn_task.detach();
|
||||
}
|
||||
}
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
@@ -194,6 +113,19 @@ impl OpenAiLanguageModelProvider {
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &OpenAiSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).openai
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
open_ai::OPEN_AI_API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
|
||||
@@ -236,10 +168,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.available_models
|
||||
{
|
||||
for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
open_ai::Model::Custom {
|
||||
@@ -278,7 +207,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,11 +228,12 @@ impl OpenAiLanguageModel {
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).openai;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = OpenAiLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
@@ -802,45 +733,35 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
self.api_key_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
@@ -850,7 +771,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
@@ -872,10 +793,11 @@ impl Render for ConfigurationView {
|
||||
)
|
||||
.child(self.api_key_editor.clone())
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
|
||||
)
|
||||
.size(LabelSize::Small).color(Color::Muted),
|
||||
Label::new(format!(
|
||||
"You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
@@ -898,9 +820,14 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
let api_url = OpenAiLanguageModelProvider::api_url(cx);
|
||||
if api_url == OPEN_AI_API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
@@ -911,7 +838,7 @@ impl Render for ConfigurationView {
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use convert_case::{Case, Casing};
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
@@ -17,12 +15,12 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
|
||||
use ui::{ElevationIndex, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use zed_env_vars::EnvVar;
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::api_key::ApiKeyState;
|
||||
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
@@ -70,124 +68,67 @@ pub struct OpenAiCompatibleLanguageModelProvider {
|
||||
|
||||
pub struct State {
|
||||
id: Arc<str>,
|
||||
env_var_name: Arc<str>,
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
api_key_env_var: EnvVar,
|
||||
api_key_state: ApiKeyState,
|
||||
settings: OpenAiCompatibleSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = self.settings.api_url.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = SharedString::new(self.settings.api_url.as_str());
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = self.settings.api_url.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let env_var_name = self.env_var_name.clone();
|
||||
let api_url = self.settings.api_url.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
self.get_api_key(cx)
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = SharedString::new(self.settings.api_url.clone());
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&self.api_key_env_var,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAiCompatibleLanguageModelProvider {
|
||||
pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
|
||||
AllLanguageModelSettings::get_global(cx)
|
||||
crate::AllLanguageModelSettings::get_global(cx)
|
||||
.openai_compatible
|
||||
.get(id)
|
||||
}
|
||||
|
||||
let state = cx.new(|cx| State {
|
||||
id: id.clone(),
|
||||
env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
|
||||
settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into();
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
|
||||
return;
|
||||
};
|
||||
if &this.settings != &settings {
|
||||
if settings.api_url != this.settings.api_url && !this.api_key_from_env {
|
||||
let spawn_task = cx.spawn(async move |handle, cx| {
|
||||
if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
|
||||
if let Err(_) = task.await {
|
||||
handle
|
||||
.update(cx, |this, _| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
});
|
||||
spawn_task.detach();
|
||||
}
|
||||
|
||||
let api_url = SharedString::new(settings.api_url.as_str());
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&this.api_key_env_var,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
this.settings = settings;
|
||||
cx.notify();
|
||||
}
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
let settings = resolve_settings(&id, cx).cloned().unwrap_or_default();
|
||||
State {
|
||||
id: id.clone(),
|
||||
api_key_env_var: EnvVar::new(api_key_env_var_name),
|
||||
api_key_state: ApiKeyState::new(SharedString::new(settings.api_url.as_str())),
|
||||
settings,
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
@@ -274,7 +215,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -296,10 +238,15 @@ impl OpenAiCompatibleLanguageModel {
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
|
||||
(state.api_key.clone(), state.settings.api_url.clone())
|
||||
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, _cx| {
|
||||
let api_url = &state.settings.api_url;
|
||||
(
|
||||
state.api_key_state.key(api_url),
|
||||
state.settings.api_url.clone(),
|
||||
)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let provider = self.provider_name.clone();
|
||||
@@ -469,56 +416,47 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
self.api_key_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
fn should_render_editor(&self, cx: &Context<Self>) -> bool {
|
||||
!self.state.read(cx).is_authenticated()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_name = self.state.read(cx).env_var_name.clone();
|
||||
let state = self.state.read(cx);
|
||||
let env_var_set = state.api_key_state.is_from_env_var();
|
||||
let env_var_name = &state.api_key_env_var.name;
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
@@ -550,9 +488,9 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {env_var_name} environment variable.")
|
||||
format!("API key set in {env_var_name} environment variable")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
format!("API key configured for {}", truncate_and_trailoff(&state.settings.api_url, 32))
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
|
||||
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
|
||||
use gpui::{
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
@@ -15,24 +14,28 @@ use language_model::{
|
||||
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use open_router::{
|
||||
Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models,
|
||||
stream_completion,
|
||||
Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, Provider, ResponseStreamEvent,
|
||||
list_models,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
|
||||
|
||||
const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct OpenRouterSettings {
|
||||
pub api_url: String,
|
||||
@@ -90,93 +93,37 @@ pub struct OpenRouterLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
api_key_state: ApiKeyState,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
available_models: Vec<open_router::Model>,
|
||||
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
|
||||
settings: OpenRouterSettings,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY";
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.open_router
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.open_router
|
||||
.api_url
|
||||
.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.restart_fetch_models_task(cx);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = AllLanguageModelSettings::get_global(cx)
|
||||
.open_router
|
||||
.api_url
|
||||
.clone();
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
let task = self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
this.restart_fetch_models_task(cx);
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
let result = task.await;
|
||||
this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
|
||||
.ok();
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
@@ -184,10 +131,9 @@ impl State {
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<(), LanguageModelCompletionError>> {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = settings.api_url.clone();
|
||||
let Some(api_key) = self.api_key.clone() else {
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
let Some(api_key) = self.api_key_state.key(&api_url) else {
|
||||
return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
}));
|
||||
@@ -216,33 +162,52 @@ impl State {
|
||||
if self.is_authenticated() {
|
||||
let task = self.fetch_models(cx);
|
||||
self.fetch_models_task.replace(task);
|
||||
} else {
|
||||
self.available_models = Vec::new();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenRouterLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
http_client: http_client.clone(),
|
||||
available_models: Vec::new(),
|
||||
fetch_models_task: None,
|
||||
settings: OpenRouterSettings::default(),
|
||||
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let current_settings = &AllLanguageModelSettings::get_global(cx).open_router;
|
||||
let settings_changed = current_settings != &this.settings;
|
||||
if settings_changed {
|
||||
this.settings = current_settings.clone();
|
||||
this.restart_fetch_models_task(cx);
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>({
|
||||
let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
|
||||
move |this: &mut State, cx| {
|
||||
let current_settings = OpenRouterLanguageModelProvider::settings(cx);
|
||||
let settings_changed = current_settings != &last_settings;
|
||||
if settings_changed {
|
||||
last_settings = current_settings.clone();
|
||||
this.authenticate(cx).detach();
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
http_client: http_client.clone(),
|
||||
available_models: Vec::new(),
|
||||
fetch_models_task: None,
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &OpenRouterSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).open_router
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
OPEN_ROUTER_API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(OpenRouterLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
@@ -287,10 +252,7 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
|
||||
let mut models_from_api = self.state.read(cx).available_models.clone();
|
||||
let mut settings_models = Vec::new();
|
||||
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.open_router
|
||||
.available_models
|
||||
{
|
||||
for model in &Self::settings(cx).available_models {
|
||||
settings_models.push(open_router::Model {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
@@ -338,7 +300,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -366,14 +329,11 @@ impl OpenRouterLanguageModel {
|
||||
>,
|
||||
> {
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
|
||||
(state.api_key.clone(), settings.api_url.clone())
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"App state dropped"
|
||||
))))
|
||||
.boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped").into())).boxed();
|
||||
};
|
||||
|
||||
async move {
|
||||
@@ -382,7 +342,8 @@ impl OpenRouterLanguageModel {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
};
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let request =
|
||||
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
request.await.map_err(Into::into)
|
||||
}
|
||||
.boxed()
|
||||
@@ -830,20 +791,22 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx);
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -852,11 +815,11 @@ impl ConfigurationView {
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
@@ -891,7 +854,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials...")).into_any()
|
||||
@@ -928,7 +891,7 @@ impl Render for ConfigurationView {
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."),
|
||||
format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
|
||||
)
|
||||
.size(LabelSize::Small).color(Color::Muted),
|
||||
)
|
||||
@@ -947,9 +910,14 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.")
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||
if api_url == OPEN_ROUTER_API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
@@ -960,7 +928,7 @@ impl Render for ConfigurationView {
|
||||
.icon_position(IconPosition::Start)
|
||||
.disabled(env_var_set)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable.")))
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
@@ -10,24 +9,26 @@ use language_model::{
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, RateLimiter, Role,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::ResponseStreamEvent;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
use vercel::Model;
|
||||
|
||||
use ui::{ElevationIndex, List, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use vercel::{Model, VERCEL_API_URL};
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
|
||||
|
||||
const API_KEY_ENV_VAR_NAME: &str = "VERCEL_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct VercelSettings {
|
||||
pub api_url: String,
|
||||
@@ -49,103 +50,48 @@ pub struct VercelLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
api_key_state: ApiKeyState,
|
||||
}
|
||||
|
||||
const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY";
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = VercelLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = VercelLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl VercelLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
@@ -160,6 +106,19 @@ impl VercelLanguageModelProvider {
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &VercelSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).vercel
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
VERCEL_API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for VercelLanguageModelProvider {
|
||||
@@ -200,10 +159,7 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
|
||||
}
|
||||
}
|
||||
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.vercel
|
||||
.available_models
|
||||
{
|
||||
for model in &Self::settings(cx).available_models {
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
vercel::Model::Custom {
|
||||
@@ -241,7 +197,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,16 +218,12 @@ impl VercelLanguageModel {
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
(state.api_key.clone(), api_url)
|
||||
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = VercelLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
@@ -466,45 +419,35 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
self.api_key_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
@@ -514,7 +457,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
@@ -534,7 +477,7 @@ impl Render for ConfigurationView {
|
||||
.child(self.api_key_editor.clone())
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed."
|
||||
"You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
@@ -559,9 +502,14 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {VERCEL_API_KEY_VAR} environment variable.")
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
let api_url = VercelLanguageModelProvider::api_url(cx);
|
||||
if api_url == VERCEL_API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
@@ -572,7 +520,7 @@ impl Render for ConfigurationView {
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable.")))
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
use futures::{FutureExt, StreamExt, future, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
@@ -10,23 +9,25 @@ use language_model::{
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::ResponseStreamEvent;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use strum::IntoEnumIterator;
|
||||
use x_ai::Model;
|
||||
|
||||
use ui::{ElevationIndex, List, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, truncate_and_trailoff};
|
||||
use x_ai::{Model, XAI_API_URL};
|
||||
use zed_env_vars::{EnvVar, env_var};
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
use crate::{api_key::ApiKeyState, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: &str = "x_ai";
|
||||
const PROVIDER_NAME: &str = "xAI";
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
|
||||
|
||||
const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY";
|
||||
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct XAiSettings {
|
||||
@@ -49,103 +50,48 @@ pub struct XAiLanguageModelProvider {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
api_key_state: ApiKeyState,
|
||||
}
|
||||
|
||||
const XAI_API_KEY_VAR: &str = "XAI_API_KEY";
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
self.api_key_state.has_key()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let api_url = XAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state
|
||||
.store(api_url, api_key, |this| &mut this.api_key_state, cx)
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
let api_url = XAiLanguageModelProvider::api_url(cx);
|
||||
self.api_key_state.load_if_needed(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl XAiLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
|
||||
let state = cx.new(|cx| {
|
||||
cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
|
||||
let api_url = Self::api_url(cx);
|
||||
this.api_key_state.handle_url_change(
|
||||
api_url,
|
||||
&API_KEY_ENV_VAR,
|
||||
|this| &mut this.api_key_state,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}),
|
||||
})
|
||||
.detach();
|
||||
State {
|
||||
api_key_state: ApiKeyState::new(Self::api_url(cx)),
|
||||
}
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
@@ -160,6 +106,19 @@ impl XAiLanguageModelProvider {
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &XAiSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).x_ai
|
||||
}
|
||||
|
||||
fn api_url(cx: &App) -> SharedString {
|
||||
let api_url = &Self::settings(cx).api_url;
|
||||
if api_url.is_empty() {
|
||||
XAI_API_URL.into()
|
||||
} else {
|
||||
SharedString::new(api_url.as_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for XAiLanguageModelProvider {
|
||||
@@ -172,11 +131,11 @@ impl LanguageModelProviderState for XAiLanguageModelProvider {
|
||||
|
||||
impl LanguageModelProvider for XAiLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -200,10 +159,7 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
|
||||
}
|
||||
}
|
||||
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.x_ai
|
||||
.available_models
|
||||
{
|
||||
for model in &Self::settings(cx).available_models {
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
x_ai::Model::Custom {
|
||||
@@ -241,7 +197,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
self.state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,20 +218,20 @@ impl XAiLanguageModel {
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
x_ai::XAI_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
(state.api_key.clone(), api_url)
|
||||
|
||||
let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
|
||||
let api_url = XAiLanguageModelProvider::api_url(cx);
|
||||
(state.api_key_state.key(&api_url), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
return future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let api_key = api_key.context("Missing xAI API Key")?;
|
||||
let Some(api_key) = api_key else {
|
||||
return Err(LanguageModelCompletionError::NoApiKey {
|
||||
provider: PROVIDER_NAME,
|
||||
});
|
||||
};
|
||||
let request =
|
||||
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
@@ -295,11 +252,11 @@ impl LanguageModel for XAiLanguageModel {
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
@@ -456,45 +413,35 @@ impl ConfigurationView {
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
|
||||
if api_key.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// url changes can cause the editor to be displayed again
|
||||
self.api_key_editor
|
||||
.update(cx, |editor, cx| editor.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
self.api_key_editor
|
||||
.update(cx, |input, cx| input.set_text("", window, cx));
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(None, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
@@ -504,7 +451,7 @@ impl ConfigurationView {
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
@@ -524,7 +471,7 @@ impl Render for ConfigurationView {
|
||||
.child(self.api_key_editor.clone())
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed."
|
||||
"You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
@@ -549,9 +496,14 @@ impl Render for ConfigurationView {
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {XAI_API_KEY_VAR} environment variable.")
|
||||
format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
let api_url = XAiLanguageModelProvider::api_url(cx);
|
||||
if api_url == XAI_API_URL {
|
||||
"API key configured".to_string()
|
||||
} else {
|
||||
format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
@@ -562,7 +514,7 @@ impl Render for ConfigurationView {
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable.")))
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
|
||||
@@ -30,6 +30,10 @@ impl BasedPyrightBanner {
|
||||
_subscriptions: [subscription],
|
||||
}
|
||||
}
|
||||
|
||||
fn onboarding_banner_enabled(&self) -> bool {
|
||||
!self.dismissed && self.have_basedpyright
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<ToolbarItemEvent> for BasedPyrightBanner {}
|
||||
@@ -38,7 +42,7 @@ impl Render for BasedPyrightBanner {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div()
|
||||
.id("basedpyright-banner")
|
||||
.when(!self.dismissed && self.have_basedpyright, |el| {
|
||||
.when(self.onboarding_banner_enabled(), |el| {
|
||||
el.child(
|
||||
Banner::new()
|
||||
.child(
|
||||
@@ -81,6 +85,9 @@ impl ToolbarItemView for BasedPyrightBanner {
|
||||
_window: &mut ui::Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> ToolbarItemLocation {
|
||||
if !self.onboarding_banner_enabled() {
|
||||
return ToolbarItemLocation::Hidden;
|
||||
}
|
||||
if let Some(item) = active_pane_item
|
||||
&& let Some(editor) = item.act_as::<Editor>(cx)
|
||||
&& let Some(path) = editor.update(cx, |editor, cx| editor.target_file_abs_path(cx))
|
||||
|
||||
@@ -12,7 +12,8 @@ use theme::ActiveTheme;
|
||||
use tree_sitter::{Node, TreeCursor};
|
||||
use ui::{
|
||||
ButtonCommon, ButtonLike, Clickable, Color, ContextMenu, FluentBuilder as _, IconButton,
|
||||
IconName, Label, LabelCommon, LabelSize, PopoverMenu, StyledExt, Tooltip, h_flex, v_flex,
|
||||
IconName, Label, LabelCommon, LabelSize, PopoverMenu, StyledExt, Tooltip, WithScrollbar,
|
||||
h_flex, v_flex,
|
||||
};
|
||||
use workspace::{
|
||||
Event as WorkspaceEvent, SplitDirection, ToolbarItemEvent, ToolbarItemLocation,
|
||||
@@ -487,7 +488,7 @@ impl SyntaxTreeView {
|
||||
}
|
||||
|
||||
impl Render for SyntaxTreeView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div()
|
||||
.flex_1()
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
@@ -512,6 +513,8 @@ impl Render for SyntaxTreeView {
|
||||
.text_bg(cx.theme().colors().background)
|
||||
.into_any_element(),
|
||||
)
|
||||
.vertical_scrollbar_for(self.list_scroll_handle.clone(), window, cx)
|
||||
.into_any_element()
|
||||
} else {
|
||||
let inner_content = v_flex()
|
||||
.items_center()
|
||||
@@ -540,6 +543,7 @@ impl Render for SyntaxTreeView {
|
||||
.size_full()
|
||||
.justify_center()
|
||||
.child(inner_content)
|
||||
.into_any_element()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -57,6 +57,7 @@ pet-core.workspace = true
|
||||
pet-fs.workspace = true
|
||||
pet-poetry.workspace = true
|
||||
pet-reporter.workspace = true
|
||||
pet-virtualenv.workspace = true
|
||||
pet.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
|
||||
@@ -10,4 +10,365 @@
|
||||
(raw_string_literal)
|
||||
(interpreted_string_literal)
|
||||
] @injection.content
|
||||
(#set! injection.language "regex")))
|
||||
(#set! injection.language "regex")
|
||||
))
|
||||
|
||||
; INJECT SQL
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*sql\\s*\\*\\/") ; /* sql */ or /*sql*/
|
||||
(#set! injection.language "sql")
|
||||
)
|
||||
|
||||
; INJECT JSON
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*json\\s*\\*\\/") ; /* json */ or /*json*/
|
||||
(#set! injection.language "json")
|
||||
)
|
||||
|
||||
; INJECT YAML
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*yaml\\s*\\*\\/") ; /* yaml */ or /*yaml*/
|
||||
(#set! injection.language "yaml")
|
||||
)
|
||||
|
||||
; INJECT XML
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*xml\\s*\\*\\/") ; /* xml */ or /*xml*/
|
||||
(#set! injection.language "xml")
|
||||
)
|
||||
|
||||
; INJECT HTML
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*html\\s*\\*\\/") ; /* html */ or /*html*/
|
||||
(#set! injection.language "html")
|
||||
)
|
||||
|
||||
; INJECT JS
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*js\\s*\\*\\/") ; /* js */ or /*js*/
|
||||
(#set! injection.language "javascript")
|
||||
)
|
||||
|
||||
; INJECT CSS
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*css\\s*\\*\\/") ; /* css */ or /*css*/
|
||||
(#set! injection.language "css")
|
||||
)
|
||||
|
||||
; INJECT LUA
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*lua\\s*\\*\\/") ; /* lua */ or /*lua*/
|
||||
(#set! injection.language "lua")
|
||||
)
|
||||
|
||||
; INJECT BASH
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*bash\\s*\\*\\/") ; /* bash */ or /*bash*/
|
||||
(#set! injection.language "bash")
|
||||
)
|
||||
|
||||
; INJECT CSV
|
||||
(
|
||||
[
|
||||
; var, const or short declaration of raw or interpreted string literal
|
||||
((comment) @comment
|
||||
.
|
||||
(expression_list
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a literal element (to struct field eg.)
|
||||
((comment) @comment
|
||||
.
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content
|
||||
))
|
||||
|
||||
; when passing as a function parameter
|
||||
((comment) @comment
|
||||
.
|
||||
[
|
||||
(interpreted_string_literal)
|
||||
(raw_string_literal)
|
||||
] @injection.content)
|
||||
]
|
||||
|
||||
(#match? @comment "^\\/\\*\\s*csv\\s*\\*\\/") ; /* csv */ or /*csv*/
|
||||
(#set! injection.language "csv")
|
||||
)
|
||||
|
||||
@@ -286,6 +286,7 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
|
||||
"HEEX",
|
||||
"HTML",
|
||||
"JavaScript",
|
||||
"TypeScript",
|
||||
"PHP",
|
||||
"Svelte",
|
||||
"TSX",
|
||||
|
||||
@@ -16,6 +16,7 @@ use node_runtime::{NodeRuntime, VersionStrategy};
|
||||
use pet_core::Configuration;
|
||||
use pet_core::os_environment::Environment;
|
||||
use pet_core::python_environment::{PythonEnvironment, PythonEnvironmentKind};
|
||||
use pet_virtualenv::is_virtualenv_dir;
|
||||
use project::Fs;
|
||||
use project::lsp_store::language_server_settings;
|
||||
use serde_json::{Value, json};
|
||||
@@ -900,6 +901,21 @@ fn python_module_name_from_relative_path(relative_path: &str) -> String {
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn is_python_env_global(k: &PythonEnvironmentKind) -> bool {
|
||||
matches!(
|
||||
k,
|
||||
PythonEnvironmentKind::Homebrew
|
||||
| PythonEnvironmentKind::Pyenv
|
||||
| PythonEnvironmentKind::GlobalPaths
|
||||
| PythonEnvironmentKind::MacPythonOrg
|
||||
| PythonEnvironmentKind::MacCommandLineTools
|
||||
| PythonEnvironmentKind::LinuxGlobal
|
||||
| PythonEnvironmentKind::MacXCode
|
||||
| PythonEnvironmentKind::WindowsStore
|
||||
| PythonEnvironmentKind::WindowsRegistry
|
||||
)
|
||||
}
|
||||
|
||||
fn python_env_kind_display(k: &PythonEnvironmentKind) -> &'static str {
|
||||
match k {
|
||||
PythonEnvironmentKind::Conda => "Conda",
|
||||
@@ -966,6 +982,26 @@ async fn get_worktree_venv_declaration(worktree_root: &Path) -> Option<String> {
|
||||
Some(venv_name.trim().to_string())
|
||||
}
|
||||
|
||||
fn get_venv_parent_dir(env: &PythonEnvironment) -> Option<PathBuf> {
|
||||
// If global, we aren't a virtual environment
|
||||
if let Some(kind) = env.kind
|
||||
&& is_python_env_global(&kind)
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check to be sure we are a virtual environment using pet's most generic
|
||||
// virtual environment type, VirtualEnv
|
||||
let venv = env
|
||||
.executable
|
||||
.as_ref()
|
||||
.and_then(|p| p.parent())
|
||||
.and_then(|p| p.parent())
|
||||
.filter(|p| is_virtualenv_dir(p))?;
|
||||
|
||||
venv.parent().map(|parent| parent.to_path_buf())
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolchainLister for PythonToolchainProvider {
|
||||
async fn list(
|
||||
@@ -1025,11 +1061,15 @@ impl ToolchainLister for PythonToolchainProvider {
|
||||
});
|
||||
|
||||
// Compare project paths against worktree root
|
||||
let proj_ordering = || match (&lhs.project, &rhs.project) {
|
||||
(Some(l), Some(r)) => (r == &wr).cmp(&(l == &wr)),
|
||||
(Some(l), None) if l == &wr => Ordering::Less,
|
||||
(None, Some(r)) if r == &wr => Ordering::Greater,
|
||||
_ => Ordering::Equal,
|
||||
let proj_ordering = || {
|
||||
let lhs_project = lhs.project.clone().or_else(|| get_venv_parent_dir(lhs));
|
||||
let rhs_project = rhs.project.clone().or_else(|| get_venv_parent_dir(rhs));
|
||||
match (&lhs_project, &rhs_project) {
|
||||
(Some(l), Some(r)) => (r == &wr).cmp(&(l == &wr)),
|
||||
(Some(l), None) if l == &wr => Ordering::Less,
|
||||
(None, Some(r)) if r == &wr => Ordering::Greater,
|
||||
_ => Ordering::Equal,
|
||||
}
|
||||
};
|
||||
|
||||
// Compare environment priorities
|
||||
@@ -1131,7 +1171,7 @@ impl ToolchainLister for PythonToolchainProvider {
|
||||
let activate_keyword = match shell {
|
||||
ShellKind::Cmd => ".",
|
||||
ShellKind::Nushell => "overlay use",
|
||||
ShellKind::Powershell => ".",
|
||||
ShellKind::PowerShell => ".",
|
||||
ShellKind::Fish => "source",
|
||||
ShellKind::Csh => "source",
|
||||
ShellKind::Posix => "source",
|
||||
@@ -1141,7 +1181,7 @@ impl ToolchainLister for PythonToolchainProvider {
|
||||
ShellKind::Csh => "activate.csh",
|
||||
ShellKind::Fish => "activate.fish",
|
||||
ShellKind::Nushell => "activate.nu",
|
||||
ShellKind::Powershell => "activate.ps1",
|
||||
ShellKind::PowerShell => "activate.ps1",
|
||||
ShellKind::Cmd => "activate.bat",
|
||||
};
|
||||
let path = prefix.join(BINARY_DIR).join(activate_script_name);
|
||||
@@ -1165,7 +1205,7 @@ impl ToolchainLister for PythonToolchainProvider {
|
||||
ShellKind::Fish => Some(format!("\"{pyenv}\" shell - fish {version}")),
|
||||
ShellKind::Posix => Some(format!("\"{pyenv}\" shell - sh {version}")),
|
||||
ShellKind::Nushell => Some(format!("\"{pyenv}\" shell - nu {version}")),
|
||||
ShellKind::Powershell => None,
|
||||
ShellKind::PowerShell => None,
|
||||
ShellKind::Csh => None,
|
||||
ShellKind::Cmd => None,
|
||||
})
|
||||
|
||||
@@ -146,6 +146,7 @@ impl LspAdapter for TailwindLspAdapter {
|
||||
"html": "html",
|
||||
"css": "css",
|
||||
"javascript": "javascript",
|
||||
"typescript": "typescript",
|
||||
"typescriptreact": "typescriptreact",
|
||||
},
|
||||
})))
|
||||
@@ -178,6 +179,7 @@ impl LspAdapter for TailwindLspAdapter {
|
||||
(LanguageName::new("HTML"), "html".to_string()),
|
||||
(LanguageName::new("CSS"), "css".to_string()),
|
||||
(LanguageName::new("JavaScript"), "javascript".to_string()),
|
||||
(LanguageName::new("TypeScript"), "typescript".to_string()),
|
||||
(LanguageName::new("TSX"), "typescriptreact".to_string()),
|
||||
(LanguageName::new("Svelte"), "svelte".to_string()),
|
||||
(LanguageName::new("Elixir"), "phoenix-heex".to_string()),
|
||||
|
||||
@@ -21,9 +21,11 @@ word_characters = ["#", "$"]
|
||||
prettier_parser_name = "typescript"
|
||||
tab_size = 2
|
||||
debuggers = ["JavaScript"]
|
||||
scope_opt_in_language_servers = ["tailwindcss-language-server"]
|
||||
|
||||
[overrides.string]
|
||||
completion_query_characters = ["."]
|
||||
completion_query_characters = ["-", "."]
|
||||
opt_into_language_servers = ["tailwindcss-language-server"]
|
||||
prefer_label_for_snippet = true
|
||||
|
||||
[overrides.function_name_before_type_arguments]
|
||||
|
||||
@@ -1079,7 +1079,7 @@ impl Element for MarkdownElement {
|
||||
{
|
||||
builder.modify_current_div(|el| {
|
||||
let content_range = parser::extract_code_block_content_range(
|
||||
parsed_markdown.source()[range.clone()].trim(),
|
||||
&parsed_markdown.source()[range.clone()],
|
||||
);
|
||||
let content_range = content_range.start + range.start
|
||||
..content_range.end + range.start;
|
||||
@@ -1110,7 +1110,7 @@ impl Element for MarkdownElement {
|
||||
{
|
||||
builder.modify_current_div(|el| {
|
||||
let content_range = parser::extract_code_block_content_range(
|
||||
parsed_markdown.source()[range.clone()].trim(),
|
||||
&parsed_markdown.source()[range.clone()],
|
||||
);
|
||||
let content_range = content_range.start + range.start
|
||||
..content_range.end + range.start;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user