Compare commits
74 Commits
remove-d2d
...
acp-champa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f81993574e | ||
|
|
bc1f861d3f | ||
|
|
4f2d6a9ea9 | ||
|
|
604a88f6e3 | ||
|
|
a4fe8c6972 | ||
|
|
5d621bef78 | ||
|
|
27877325bc | ||
|
|
9e1c7fdfea | ||
|
|
387ee1be8d | ||
|
|
afc8cf6098 | ||
|
|
84d6a0fae9 | ||
|
|
afb5c4147a | ||
|
|
8890f590b1 | ||
|
|
76f0f9163d | ||
|
|
8acb58b6e5 | ||
|
|
6052115825 | ||
|
|
561ccf86aa | ||
|
|
8563ed2252 | ||
|
|
ac75593198 | ||
|
|
a3a3f111f8 | ||
|
|
faa45c53d7 | ||
|
|
b31f893408 | ||
|
|
f888f3fc0b | ||
|
|
b01d1872cc | ||
|
|
e5c6a596a9 | ||
|
|
106aa0d9cc | ||
|
|
f7f90593ac | ||
|
|
8be3f48f37 | ||
|
|
76a8293cc6 | ||
|
|
2315962e18 | ||
|
|
f8673dacf5 | ||
|
|
72d354de6c | ||
|
|
09b93caa9b | ||
|
|
7c169fc9b5 | ||
|
|
2b36d4ec94 | ||
|
|
4a82b6c5ee | ||
|
|
5feb759c20 | ||
|
|
410348deb0 | ||
|
|
8e7f1899e1 | ||
|
|
aea1d48184 | ||
|
|
c946b98ea1 | ||
|
|
c6947ee4f0 | ||
|
|
b59f992928 | ||
|
|
0a21b845fa | ||
|
|
6a8be1714e | ||
|
|
a2aea00253 | ||
|
|
98c66eddb8 | ||
|
|
558bbfffae | ||
|
|
89ed0b9601 | ||
|
|
4b9334b910 | ||
|
|
47af878ebb | ||
|
|
5488398986 | ||
|
|
b1a7993544 | ||
|
|
b90fd4287f | ||
|
|
e1e2775b80 | ||
|
|
ed104ec5e0 | ||
|
|
67a491df50 | ||
|
|
f003036aec | ||
|
|
fbc784d323 | ||
|
|
296bb66b65 | ||
|
|
02d3043ec5 | ||
|
|
30739041a4 | ||
|
|
27708143ec | ||
|
|
738296345e | ||
|
|
81c111510f | ||
|
|
f028ca4d1a | ||
|
|
6656403ce8 | ||
|
|
254c6be42b | ||
|
|
745e4b5f1e | ||
|
|
912ab505b2 | ||
|
|
b48faddaf4 | ||
|
|
477731d77d | ||
|
|
ced3d09f10 | ||
|
|
0f395df9a8 |
8
.github/actions/build_docs/action.yml
vendored
@@ -19,7 +19,7 @@ runs:
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: ./script/linux
|
||||
|
||||
- name: Check for broken links
|
||||
- name: Check for broken links (in MD)
|
||||
uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
|
||||
with:
|
||||
args: --no-progress --exclude '^http' './docs/src/**/*'
|
||||
@@ -30,3 +30,9 @@ runs:
|
||||
run: |
|
||||
mkdir -p target/deploy
|
||||
mdbook build ./docs --dest-dir=../target/deploy/docs/
|
||||
|
||||
- name: Check for broken links (in HTML)
|
||||
uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
|
||||
with:
|
||||
args: --no-progress --exclude '^http' 'target/deploy/docs/'
|
||||
fail: true
|
||||
|
||||
332
Cargo.lock
generated
@@ -19,6 +19,7 @@ dependencies = [
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"markdown",
|
||||
"project",
|
||||
"serde",
|
||||
@@ -114,7 +115,6 @@ dependencies = [
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"proto",
|
||||
"rand 0.8.5",
|
||||
"ref-cast",
|
||||
"rope",
|
||||
@@ -138,15 +138,57 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "agent-client-protocol"
|
||||
version = "0.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
|
||||
version = "0.0.14"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"acp_thread",
|
||||
"agent-client-protocol",
|
||||
"agent_servers",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"chrono",
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"env_logger 0.11.8",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"project",
|
||||
"reqwest_client",
|
||||
"rust-embed",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"thiserror 2.0.12",
|
||||
"ui",
|
||||
"util",
|
||||
"uuid",
|
||||
"worktree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "agent_servers"
|
||||
version = "0.1.0"
|
||||
@@ -210,6 +252,7 @@ dependencies = [
|
||||
"acp_thread",
|
||||
"agent",
|
||||
"agent-client-protocol",
|
||||
"agent2",
|
||||
"agent_servers",
|
||||
"agent_settings",
|
||||
"ai_onboarding",
|
||||
@@ -355,10 +398,10 @@ name = "ai_onboarding"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"component",
|
||||
"gpui",
|
||||
"language_model",
|
||||
"proto",
|
||||
"serde",
|
||||
"smallvec",
|
||||
"telemetry",
|
||||
@@ -1075,17 +1118,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-recursion"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7d78656ba01f1b93024b7c3a0467f1608e4be67d725749fdcd7d2c7678fd7a2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-recursion"
|
||||
version = "1.1.1"
|
||||
@@ -2971,7 +3003,6 @@ name = "client"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-recursion 0.3.2",
|
||||
"async-tungstenite",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
@@ -3049,7 +3080,11 @@ dependencies = [
|
||||
name = "cloud_api_types"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"cloud_llm_client",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -4291,41 +4326,6 @@ dependencies = [
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "5.5.3"
|
||||
@@ -4541,37 +4541,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder"
|
||||
version = "0.20.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
|
||||
dependencies = [
|
||||
"derive_builder_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_core"
|
||||
version = "0.20.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_macro"
|
||||
version = "0.20.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
|
||||
dependencies = [
|
||||
"derive_builder_core",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_more"
|
||||
version = "0.99.19"
|
||||
@@ -4982,6 +4951,7 @@ dependencies = [
|
||||
"theme",
|
||||
"time",
|
||||
"tree-sitter-bash",
|
||||
"tree-sitter-c",
|
||||
"tree-sitter-html",
|
||||
"tree-sitter-python",
|
||||
"tree-sitter-rust",
|
||||
@@ -5950,7 +5920,7 @@ dependencies = [
|
||||
"ignore",
|
||||
"libc",
|
||||
"log",
|
||||
"notify",
|
||||
"notify 8.0.0",
|
||||
"objc",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
@@ -7507,18 +7477,16 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "handlebars"
|
||||
version = "6.3.2"
|
||||
version = "5.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "759e2d5aea3287cb1190c8ec394f42866cb5bf74fcbf213f354e3c856ea26098"
|
||||
checksum = "d08485b96a0e6393e9e4d1b8d48cf74ad6c063cd905eb33f42c1ce3f0377539b"
|
||||
dependencies = [
|
||||
"derive_builder",
|
||||
"log",
|
||||
"num-order",
|
||||
"pest",
|
||||
"pest_derive",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.12",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7699,12 +7667,6 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "hex-literal"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bcaaec4551594c969335c98c903c1397853d4198408ea609190f420500f6be71"
|
||||
|
||||
[[package]]
|
||||
name = "hexf-parse"
|
||||
version = "0.2.1"
|
||||
@@ -7882,6 +7844,7 @@ dependencies = [
|
||||
"http 1.3.1",
|
||||
"http-body 1.0.1",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"url",
|
||||
@@ -8189,12 +8152,6 @@ version = "2.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "1.0.3"
|
||||
@@ -8413,6 +8370,17 @@ dependencies = [
|
||||
"zeta",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inotify"
|
||||
version = "0.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8069d3ec154eb856955c1c0fbffefbf5f3c40a104ec912d4797314c1801abff"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"inotify-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inotify"
|
||||
version = "0.11.0"
|
||||
@@ -8566,7 +8534,7 @@ dependencies = [
|
||||
"fnv",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"mio",
|
||||
"mio 1.0.3",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"tempfile",
|
||||
@@ -9148,7 +9116,6 @@ dependencies = [
|
||||
"open_router",
|
||||
"partial-json-fixer",
|
||||
"project",
|
||||
"proto",
|
||||
"release_channel",
|
||||
"schemars",
|
||||
"serde",
|
||||
@@ -9420,7 +9387,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "libwebrtc"
|
||||
version = "0.3.10"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d"
|
||||
dependencies = [
|
||||
"cxx",
|
||||
"jni",
|
||||
@@ -9500,7 +9467,7 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
|
||||
[[package]]
|
||||
name = "livekit"
|
||||
version = "0.7.8"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"futures-util",
|
||||
@@ -9523,7 +9490,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "livekit-api"
|
||||
version = "0.4.2"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"http 0.2.12",
|
||||
@@ -9547,7 +9514,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "livekit-protocol"
|
||||
version = "0.3.9"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"livekit-runtime",
|
||||
@@ -9564,7 +9531,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "livekit-runtime"
|
||||
version = "0.4.0"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d"
|
||||
dependencies = [
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
@@ -9647,9 +9614,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
|
||||
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
@@ -9886,7 +9853,7 @@ name = "markdown_preview"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-recursion 1.1.1",
|
||||
"async-recursion",
|
||||
"collections",
|
||||
"editor",
|
||||
"fs",
|
||||
@@ -10006,9 +9973,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "mdbook"
|
||||
version = "0.4.48"
|
||||
version = "0.4.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b6fbb4ac2d9fd7aa987c3510309ea3c80004a968d063c42f0d34fea070817c1"
|
||||
checksum = "b45a38e19bd200220ef07c892b0157ad3d2365e5b5a267ca01ad12182491eea5"
|
||||
dependencies = [
|
||||
"ammonia",
|
||||
"anyhow",
|
||||
@@ -10018,12 +9985,11 @@ dependencies = [
|
||||
"elasticlunr-rs",
|
||||
"env_logger 0.11.8",
|
||||
"futures-util",
|
||||
"handlebars 6.3.2",
|
||||
"hex",
|
||||
"handlebars 5.1.2",
|
||||
"ignore",
|
||||
"log",
|
||||
"memchr",
|
||||
"notify",
|
||||
"notify 6.1.1",
|
||||
"notify-debouncer-mini",
|
||||
"once_cell",
|
||||
"opener",
|
||||
@@ -10032,7 +9998,6 @@ dependencies = [
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"shlex",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
@@ -10175,6 +10140,18 @@ version = "0.5.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e53debba6bda7a793e5f99b8dacf19e626084f525f7829104ba9898f367d85ff"
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.8.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "1.0.3"
|
||||
@@ -10521,6 +10498,25 @@ dependencies = [
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "notify"
|
||||
version = "6.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"crossbeam-channel",
|
||||
"filetime",
|
||||
"fsevent-sys 4.1.0",
|
||||
"inotify 0.9.6",
|
||||
"kqueue",
|
||||
"libc",
|
||||
"log",
|
||||
"mio 0.8.11",
|
||||
"walkdir",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "notify"
|
||||
version = "8.0.0"
|
||||
@@ -10529,11 +10525,11 @@ dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"filetime",
|
||||
"fsevent-sys 4.1.0",
|
||||
"inotify",
|
||||
"inotify 0.11.0",
|
||||
"kqueue",
|
||||
"libc",
|
||||
"log",
|
||||
"mio",
|
||||
"mio 1.0.3",
|
||||
"notify-types",
|
||||
"walkdir",
|
||||
"windows-sys 0.59.0",
|
||||
@@ -10541,14 +10537,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "notify-debouncer-mini"
|
||||
version = "0.6.0"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a689eb4262184d9a1727f9087cd03883ea716682ab03ed24efec57d7716dccb8"
|
||||
checksum = "5d40b221972a1fc5ef4d858a2f671fb34c75983eb385463dff3780eeff6a9d43"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"log",
|
||||
"notify",
|
||||
"notify-types",
|
||||
"tempfile",
|
||||
"notify 6.1.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -10688,21 +10683,6 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-modular"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17bb261bf36fa7d83f4c294f834e91256769097b3cb505d44831e0a179ac647f"
|
||||
|
||||
[[package]]
|
||||
name = "num-order"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "537b596b97c40fcf8056d153049eb22f481c17ebce72a513ec9286e4986d1bb6"
|
||||
dependencies = [
|
||||
"num-modular",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.2"
|
||||
@@ -10973,14 +10953,21 @@ dependencies = [
|
||||
name = "onboarding"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ai_onboarding",
|
||||
"anyhow",
|
||||
"client",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
"db",
|
||||
"documented",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"gpui",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"menu",
|
||||
"project",
|
||||
"schemars",
|
||||
"serde",
|
||||
@@ -10988,6 +10975,7 @@ dependencies = [
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"vim_mode_setting",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
@@ -11342,9 +11330,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.3"
|
||||
version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
|
||||
checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
@@ -11352,9 +11340,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.10"
|
||||
version = "0.9.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
|
||||
checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
@@ -14759,6 +14747,7 @@ dependencies = [
|
||||
name = "settings_profile_selector"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"client",
|
||||
"editor",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
@@ -14768,6 +14757,7 @@ dependencies = [
|
||||
"project",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"theme",
|
||||
"ui",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
@@ -16232,7 +16222,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_slash_command",
|
||||
"async-recursion 1.1.1",
|
||||
"async-recursion",
|
||||
"breadcrumbs",
|
||||
"client",
|
||||
"collections",
|
||||
@@ -16581,6 +16571,7 @@ dependencies = [
|
||||
"call",
|
||||
"chrono",
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"db",
|
||||
"gpui",
|
||||
@@ -16616,7 +16607,7 @@ dependencies = [
|
||||
"backtrace",
|
||||
"bytes 1.10.1",
|
||||
"libc",
|
||||
"mio",
|
||||
"mio 1.0.3",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
@@ -18594,7 +18585,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "webrtc-sys"
|
||||
version = "0.3.7"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cxx",
|
||||
@@ -18607,15 +18598,13 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "webrtc-sys-build"
|
||||
version = "0.3.6"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=383e5377f8b7de1f8627ee16f0cf11c5293337bd#383e5377f8b7de1f8627ee16f0cf11c5293337bd"
|
||||
source = "git+https://github.com/zed-industries/livekit-rust-sdks?rev=5f04705ac3f356350ae31534ffbc476abc9ea83d#5f04705ac3f356350ae31534ffbc476abc9ea83d"
|
||||
dependencies = [
|
||||
"fs2",
|
||||
"hex-literal",
|
||||
"regex",
|
||||
"reqwest 0.11.27",
|
||||
"scratch",
|
||||
"semver",
|
||||
"sha2",
|
||||
"zip",
|
||||
]
|
||||
|
||||
@@ -18644,7 +18633,6 @@ dependencies = [
|
||||
"serde",
|
||||
"settings",
|
||||
"telemetry",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"vim_mode_setting",
|
||||
@@ -19659,7 +19647,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"any_vec",
|
||||
"anyhow",
|
||||
"async-recursion 1.1.1",
|
||||
"async-recursion",
|
||||
"bincode",
|
||||
"call",
|
||||
"client",
|
||||
@@ -19793,7 +19781,7 @@ dependencies = [
|
||||
"md-5",
|
||||
"memchr",
|
||||
"miniz_oxide",
|
||||
"mio",
|
||||
"mio 1.0.3",
|
||||
"naga",
|
||||
"nix 0.29.0",
|
||||
"nom",
|
||||
@@ -20184,7 +20172,7 @@ dependencies = [
|
||||
"async-io",
|
||||
"async-lock",
|
||||
"async-process",
|
||||
"async-recursion 1.1.1",
|
||||
"async-recursion",
|
||||
"async-task",
|
||||
"async-trait",
|
||||
"blocking",
|
||||
@@ -20617,6 +20605,7 @@ dependencies = [
|
||||
"call",
|
||||
"client",
|
||||
"clock",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
@@ -20637,7 +20626,6 @@ dependencies = [
|
||||
"menu",
|
||||
"postage",
|
||||
"project",
|
||||
"proto",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
@@ -20662,6 +20650,42 @@ dependencies = [
|
||||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeta_cli"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"client",
|
||||
"debug_adapter_extension",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"language",
|
||||
"language_extension",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"languages",
|
||||
"node_runtime",
|
||||
"paths",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"shellexpand 2.1.2",
|
||||
"smol",
|
||||
"terminal_view",
|
||||
"util",
|
||||
"watch",
|
||||
"workspace-hack",
|
||||
"zeta",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.6"
|
||||
|
||||
@@ -4,6 +4,7 @@ members = [
|
||||
"crates/acp_thread",
|
||||
"crates/activity_indicator",
|
||||
"crates/agent",
|
||||
"crates/agent2",
|
||||
"crates/agent_servers",
|
||||
"crates/agent_settings",
|
||||
"crates/agent_ui",
|
||||
@@ -189,6 +190,7 @@ members = [
|
||||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
"crates/zeta",
|
||||
"crates/zeta_cli",
|
||||
"crates/zlog",
|
||||
"crates/zlog_settings",
|
||||
|
||||
@@ -227,6 +229,7 @@ edition = "2024"
|
||||
|
||||
acp_thread = { path = "crates/acp_thread" }
|
||||
agent = { path = "crates/agent" }
|
||||
agent2 = { path = "crates/agent2" }
|
||||
activity_indicator = { path = "crates/activity_indicator" }
|
||||
agent_ui = { path = "crates/agent_ui" }
|
||||
agent_settings = { path = "crates/agent_settings" }
|
||||
@@ -420,7 +423,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = "0.0.11"
|
||||
agent-client-protocol = {path="../agent-client-protocol"}
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
@@ -678,8 +681,6 @@ features = [
|
||||
"UI_ViewManagement",
|
||||
"Wdk_System_SystemServices",
|
||||
"Win32_Globalization",
|
||||
"Win32_Graphics_Direct2D",
|
||||
"Win32_Graphics_Direct2D_Common",
|
||||
"Win32_Graphics_Direct3D",
|
||||
"Win32_Graphics_Direct3D11",
|
||||
"Win32_Graphics_Direct3D_Fxc",
|
||||
@@ -690,7 +691,6 @@ features = [
|
||||
"Win32_Graphics_Dxgi_Common",
|
||||
"Win32_Graphics_Gdi",
|
||||
"Win32_Graphics_Imaging",
|
||||
"Win32_Graphics_Imaging_D2D",
|
||||
"Win32_Networking_WinSock",
|
||||
"Win32_Security",
|
||||
"Win32_Security_Credentials",
|
||||
|
||||
3
assets/icons/editor_atom.svg
Normal file
|
After Width: | Height: | Size: 6.3 KiB |
9
assets/icons/editor_cursor.svg
Normal file
@@ -0,0 +1,9 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path opacity="0.6" d="M3.5 11V5.5L8.5 8L3.5 11Z" fill="black"/>
|
||||
<path opacity="0.4" d="M8.5 14L3.5 11L8.5 8V14Z" fill="black"/>
|
||||
<path opacity="0.6" d="M8.5 5.5H3.5L8.5 2.5L8.5 5.5Z" fill="black"/>
|
||||
<path opacity="0.8" d="M8.5 5.5V2.5L13.5 5.5H8.5Z" fill="black"/>
|
||||
<path opacity="0.2" d="M13.5 11L8.5 14L11 9.5L13.5 11Z" fill="black"/>
|
||||
<path opacity="0.5" d="M13.5 11L11 9.5L13.5 5V11Z" fill="black"/>
|
||||
<path d="M3.5 11V5L8.5 2.11325L13.5 5V11L8.5 13.8868L3.5 11Z" stroke="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 583 B |
10
assets/icons/editor_emacs.svg
Normal file
@@ -0,0 +1,10 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_2716_663)">
|
||||
<path d="M8.47552 2.45453C11.5167 2.45457 13.9814 4.94501 13.9814 8.01623C13.9814 11.0875 11.5167 13.578 8.47552 13.5781C5.43427 13.5781 2.96948 11.0875 2.96948 8.01623C2.9695 4.94498 5.43429 2.45453 8.47552 2.45453ZM10.8795 4.70348C10.7605 4.16887 10.1328 3.85468 9.53627 3.96342C8.97622 4.06552 7.62871 4.45681 7.62057 4.45916C9.29414 4.44469 9.57429 4.4726 9.69939 4.64751C9.77324 4.7508 9.66576 4.89248 9.21944 4.96538C8.73515 5.04447 7.73014 5.13958 7.72343 5.14022C6.75441 5.19776 6.07177 5.20168 5.86705 5.63512C5.73334 5.91827 6.00968 6.16857 6.13082 6.32527C6.64271 6.89455 7.38215 7.20158 7.85809 7.42767C8.03716 7.51274 8.56257 7.67345 8.56257 7.67345C7.01855 7.58853 5.90474 8.06267 5.2514 8.60855C4.51246 9.29204 4.83937 10.1067 6.35327 10.6084C7.24742 10.9047 7.69094 11.0439 9.02473 10.9238C9.81031 10.8815 9.9342 10.9068 9.94203 10.9712C9.95275 11.062 9.06932 11.2874 8.82812 11.357C8.21455 11.534 6.60645 11.8913 6.59758 11.8932C6.60115 11.8935 7.06249 11.9257 7.65531 11.8735C7.89632 11.8522 8.81142 11.7624 9.49557 11.6123C9.49557 11.6123 10.3297 11.4338 10.7759 11.2693C11.2429 11.0973 11.497 10.9512 11.6113 10.7443C11.6063 10.7019 11.6465 10.5516 11.4313 10.4613C10.8807 10.2304 10.2423 10.2721 8.9789 10.2453C7.57789 10.1972 7.11184 9.9626 6.86356 9.77373C6.62548 9.58212 6.74518 9.05204 7.76528 8.5851C8.27917 8.33646 10.2935 7.87759 10.2935 7.87759C9.61511 7.54227 8.35014 6.95284 8.09005 6.82552C7.86199 6.71388 7.49701 6.54572 7.4179 6.34233C7.32824 6.14709 7.6297 5.97888 7.79813 5.9307C8.34057 5.77424 9.10635 5.67701 9.8033 5.66609C10.1536 5.66061 10.2105 5.63806 10.2105 5.63806C10.6939 5.55787 11.0121 5.22722 10.8795 4.70348Z" fill="black"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_2716_663">
|
||||
<rect width="12" height="12" fill="white" transform="translate(2.5 2)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.9 KiB |
3
assets/icons/editor_jet_brains.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3.6725 13.9985C3.36161 13.9982 3.06354 13.8746 2.84371 13.6548C2.62388 13.435 2.50026 13.1369 2.5 12.826V7.494C2.5 6.8325 2.7675 6.185 3.2365 5.7165L6.219 2.736C6.45192 2.50247 6.72867 2.31724 7.03335 2.19094C7.33804 2.06464 7.66467 1.99975 7.9945 2H13.3275C13.6384 2.00027 13.9365 2.12388 14.1563 2.34371C14.3761 2.56354 14.4997 2.86162 14.5 3.1725V8.5045C14.4983 9.17074 14.2336 9.80936 13.7635 10.2815L10.781 13.264C10.5477 13.4976 10.2706 13.6829 9.96561 13.8092C9.66059 13.9355 9.33364 14.0003 9.0035 14V13.9985H3.6725ZM8.157 10.5715H5.243V11.257H8.157V10.5715ZM4.4815 5.257H11.243V12.0165L13.3715 9.888C13.7373 9.52036 13.9433 9.02316 13.9445 8.5045V3.1725C13.9445 2.8335 13.6685 2.5555 13.3275 2.5555H7.9945C7.73753 2.55499 7.483 2.6053 7.24556 2.70356C7.00813 2.80181 6.79246 2.94606 6.611 3.128L4.4815 5.257ZM4.3855 5.353L3.628 6.11C3.26258 6.47809 3.0569 6.97533 3.0555 7.494V12.826C3.0555 13.165 3.3315 13.443 3.6725 13.443H9.0055C9.26249 13.4434 9.51701 13.3929 9.75445 13.2946C9.99188 13.1963 10.2075 13.052 10.389 12.87L11.145 12.1145H4.3855V5.353Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.2 KiB |
5
assets/icons/editor_sublime.svg
Normal file
@@ -0,0 +1,5 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M13.0945 8.01611C13.0945 7.87619 12.9911 7.79551 12.8642 7.8356L4.13456 10.6038C4.00742 10.6441 3.90427 10.7904 3.90427 10.9301V13.7593C3.90427 13.8992 4.00742 13.9801 4.13456 13.9398L12.8642 11.1719C12.9911 11.1315 13.0945 10.9852 13.0945 10.8453V8.01611Z" fill="black"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.90427 7.92597C3.90427 8.06588 4.00742 8.21218 4.13456 8.25252L12.8655 11.0209C12.9926 11.0613 13.0958 10.9803 13.0958 10.8407V8.01124C13.0958 7.87158 12.9926 7.72529 12.8655 7.68494L4.13456 4.91652C4.00742 4.87618 3.90427 4.95686 3.90427 5.09677V7.92597Z" fill="black"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M13.0945 2.20248C13.0945 2.06256 12.9911 1.98163 12.8642 2.02197L4.13456 4.78988C4.00742 4.83022 3.90427 4.97652 3.90427 5.11644V7.94563C3.90427 8.08554 4.00742 8.16622 4.13456 8.12614L12.8642 5.35797C12.9911 5.31763 13.0945 5.17133 13.0945 5.03167V2.20248Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.0 KiB |
3
assets/icons/editor_vs_code.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.0094 13.9181C11.1984 13.9917 11.4139 13.987 11.6047 13.8952L14.0753 12.7064C14.3349 12.5814 14.5 12.3187 14.5 12.0305V3.9696C14.5 3.68136 14.3349 3.41862 14.0753 3.2937L11.6047 2.10485C11.3543 1.98438 11.0614 2.01389 10.8416 2.17363C10.8102 2.19645 10.7803 2.22193 10.7523 2.25001L6.02261 6.56498L3.96246 5.00115C3.77068 4.85558 3.50244 4.86751 3.32432 5.02953L2.66356 5.63059C2.44569 5.82877 2.44544 6.17152 2.66302 6.37004L4.44965 8.00001L2.66302 9.62998C2.44544 9.82849 2.44569 10.1713 2.66356 10.3694L3.32432 10.9705C3.50244 11.1325 3.77068 11.1444 3.96246 10.9989L6.02261 9.43504L10.7523 13.75C10.8271 13.8249 10.915 13.8812 11.0094 13.9181ZM11.5018 5.27587L7.91309 8.00001L11.5018 10.7241V5.27587Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 876 B |
4
assets/icons/shield_check.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M13.0001 8.62505C13.0001 11.75 10.8126 13.3125 8.21266 14.2187C8.07651 14.2648 7.92862 14.2626 7.79392 14.2125C5.18771 13.3125 3.00024 11.75 3.00024 8.62505V4.25012C3.00024 4.08436 3.06609 3.92539 3.1833 3.80818C3.30051 3.69098 3.45948 3.62513 3.62523 3.62513C4.87521 3.62513 6.43769 2.87514 7.52517 1.92516C7.65758 1.81203 7.82601 1.74988 8.00016 1.74988C8.17431 1.74988 8.34275 1.81203 8.47515 1.92516C9.56889 2.88139 11.1251 3.62513 12.3751 3.62513C12.5408 3.62513 12.6998 3.69098 12.817 3.80818C12.9342 3.92539 13.0001 4.08436 13.0001 4.25012V8.62505Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M6 8.00002L7.33333 9.33335L10 6.66669" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 883 B |
@@ -598,6 +598,7 @@
|
||||
"ctrl-shift-t": "pane::ReopenClosedItem",
|
||||
"ctrl-k ctrl-s": "zed::OpenKeymapEditor",
|
||||
"ctrl-k ctrl-t": "theme_selector::Toggle",
|
||||
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
|
||||
"ctrl-t": "project_symbols::Toggle",
|
||||
"ctrl-p": "file_finder::Toggle",
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
@@ -1167,5 +1168,14 @@
|
||||
"up": "menu::SelectPrevious",
|
||||
"down": "menu::SelectNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Onboarding",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-1": "onboarding::ActivateBasicsPage",
|
||||
"ctrl-2": "onboarding::ActivateEditingPage",
|
||||
"ctrl-3": "onboarding::ActivateAISetupPage"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -665,6 +665,7 @@
|
||||
"cmd-shift-t": "pane::ReopenClosedItem",
|
||||
"cmd-k cmd-s": "zed::OpenKeymapEditor",
|
||||
"cmd-k cmd-t": "theme_selector::Toggle",
|
||||
"ctrl-alt-cmd-p": "settings_profile_selector::Toggle",
|
||||
"cmd-t": "project_symbols::Toggle",
|
||||
"cmd-p": "file_finder::Toggle",
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
@@ -1269,5 +1270,14 @@
|
||||
"up": "menu::SelectPrevious",
|
||||
"down": "menu::SelectNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Onboarding",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-1": "onboarding::ActivateBasicsPage",
|
||||
"cmd-2": "onboarding::ActivateEditingPage",
|
||||
"cmd-3": "onboarding::ActivateAISetupPage"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -95,7 +95,7 @@
|
||||
"ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }],
|
||||
"alt-shift-f10": "task::Spawn",
|
||||
"ctrl-e": "file_finder::Toggle",
|
||||
"ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
// "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
"ctrl-shift-n": "file_finder::Toggle",
|
||||
"ctrl-shift-a": "command_palette::Toggle",
|
||||
"shift shift": "command_palette::Toggle",
|
||||
|
||||
@@ -97,7 +97,7 @@
|
||||
"cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }],
|
||||
"ctrl-alt-r": "task::Spawn",
|
||||
"cmd-e": "file_finder::Toggle",
|
||||
"cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
// "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
"cmd-shift-o": "file_finder::Toggle",
|
||||
"cmd-shift-a": "command_palette::Toggle",
|
||||
"shift shift": "command_palette::Toggle",
|
||||
|
||||
@@ -1878,7 +1878,24 @@
|
||||
"dock": "bottom",
|
||||
"button": true
|
||||
},
|
||||
// Configures any number of settings profiles that are temporarily applied
|
||||
// when selected from `settings profile selector: toggle`.
|
||||
// Configures any number of settings profiles that are temporarily applied on
|
||||
// top of your existing user settings when selected from
|
||||
// `settings profile selector: toggle`.
|
||||
// Examples:
|
||||
// "profiles": {
|
||||
// "Presenting": {
|
||||
// "agent_font_size": 20.0,
|
||||
// "buffer_font_size": 20.0,
|
||||
// "theme": "One Light",
|
||||
// "ui_font_size": 20.0
|
||||
// },
|
||||
// "Python (ty)": {
|
||||
// "languages": {
|
||||
// "Python": {
|
||||
// "language_servers": ["ty"]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
"profiles": []
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
markdown.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
@@ -391,7 +391,7 @@ impl ToolCallContent {
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
match content {
|
||||
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
|
||||
acp::ToolCallContent::Content { content } => Self::ContentBlock {
|
||||
content: ContentBlock::new(content, &language_registry, cx),
|
||||
},
|
||||
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
||||
@@ -580,6 +580,9 @@ pub struct AcpThread {
|
||||
pub enum AcpThreadEvent {
|
||||
NewEntry,
|
||||
EntryUpdated(usize),
|
||||
ToolAuthorizationRequired,
|
||||
Stopped,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
||||
@@ -616,6 +619,7 @@ impl Error for LoadError {}
|
||||
|
||||
impl AcpThread {
|
||||
pub fn new(
|
||||
title: impl Into<SharedString>,
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
project: Entity<Project>,
|
||||
session_id: acp::SessionId,
|
||||
@@ -628,7 +632,7 @@ impl AcpThread {
|
||||
shared_buffers: Default::default(),
|
||||
entries: Default::default(),
|
||||
plan: Default::default(),
|
||||
title: connection.name().into(),
|
||||
title: title.into(),
|
||||
project,
|
||||
send_task: None,
|
||||
connection,
|
||||
@@ -652,6 +656,10 @@ impl AcpThread {
|
||||
&self.entries
|
||||
}
|
||||
|
||||
pub fn session_id(&self) -> &acp::SessionId {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
pub fn status(&self) -> ThreadStatus {
|
||||
if self.send_task.is_some() {
|
||||
if self.waiting_for_tool_confirmation() {
|
||||
@@ -676,20 +684,32 @@ impl AcpThread {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn used_tools_since_last_user_message(&self) -> bool {
|
||||
for entry in self.entries.iter().rev() {
|
||||
match entry {
|
||||
AgentThreadEntry::UserMessage(..) => return false,
|
||||
AgentThreadEntry::AssistantMessage(..) => continue,
|
||||
AgentThreadEntry::ToolCall(..) => return true,
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn handle_session_update(
|
||||
&mut self,
|
||||
update: acp::SessionUpdate,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
match update {
|
||||
acp::SessionUpdate::UserMessage(content_block) => {
|
||||
self.push_user_content_block(content_block, cx);
|
||||
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||
self.push_user_content_block(content, cx);
|
||||
}
|
||||
acp::SessionUpdate::AgentMessageChunk(content_block) => {
|
||||
self.push_assistant_content_block(content_block, false, cx);
|
||||
acp::SessionUpdate::AgentMessageChunk { content } => {
|
||||
self.push_assistant_content_block(content, false, cx);
|
||||
}
|
||||
acp::SessionUpdate::AgentThoughtChunk(content_block) => {
|
||||
self.push_assistant_content_block(content_block, true, cx);
|
||||
acp::SessionUpdate::AgentThoughtChunk { content } => {
|
||||
self.push_assistant_content_block(content, true, cx);
|
||||
}
|
||||
acp::SessionUpdate::ToolCall(tool_call) => {
|
||||
self.upsert_tool_call(tool_call, cx);
|
||||
@@ -879,6 +899,7 @@ impl AcpThread {
|
||||
};
|
||||
|
||||
self.upsert_tool_call_inner(tool_call, status, cx);
|
||||
cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
|
||||
rx
|
||||
}
|
||||
|
||||
@@ -957,10 +978,6 @@ impl AcpThread {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.connection.authenticate(cx)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn send_raw(
|
||||
&mut self,
|
||||
@@ -1002,7 +1019,7 @@ impl AcpThread {
|
||||
let result = this
|
||||
.update(cx, |this, cx| {
|
||||
this.connection.prompt(
|
||||
acp::PromptArguments {
|
||||
acp::PromptRequest {
|
||||
prompt: message,
|
||||
session_id: this.session_id.clone(),
|
||||
},
|
||||
@@ -1018,12 +1035,18 @@ impl AcpThread {
|
||||
.log_err();
|
||||
}));
|
||||
|
||||
async move {
|
||||
match rx.await {
|
||||
Ok(Err(e)) => Err(e)?,
|
||||
_ => Ok(()),
|
||||
cx.spawn(async move |this, cx| match rx.await {
|
||||
Ok(Err(e)) => {
|
||||
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
|
||||
.log_err();
|
||||
Err(e)?
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
|
||||
.log_err();
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
@@ -1598,9 +1621,15 @@ mod tests {
|
||||
connection,
|
||||
child_status: io_task,
|
||||
current_thread: thread_rc,
|
||||
auth_methods: [acp::AuthMethod {
|
||||
id: acp::AuthMethodId("acp-old-no-id".into()),
|
||||
label: "Log in".into(),
|
||||
description: None,
|
||||
}],
|
||||
};
|
||||
|
||||
AcpThread::new(
|
||||
"test",
|
||||
Rc::new(connection),
|
||||
project,
|
||||
acp::SessionId("test".into()),
|
||||
|
||||
@@ -1,16 +1,62 @@
|
||||
use std::{path::Path, rc::Rc};
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language_model::LanguageModel;
|
||||
use project::Project;
|
||||
use ui::App;
|
||||
|
||||
use crate::AcpThread;
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn name(&self) -> &'static str;
|
||||
/// Trait for agents that support listing, selecting, and querying language models.
|
||||
///
|
||||
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
|
||||
pub trait ModelSelector: 'static {
|
||||
/// Lists all available language models for this agent.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `cx`: The GPUI app context for async operations and global access.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the list of models or an error (e.g., if no models are configured).
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
|
||||
|
||||
/// Selects a model for a specific session (thread).
|
||||
///
|
||||
/// This sets the default model for future interactions in the session.
|
||||
/// If the session doesn't exist or the model is invalid, it returns an error.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `session_id`: The ID of the session (thread) to apply the model to.
|
||||
/// - `model`: The model to select (should be one from [list_models]).
|
||||
/// - `cx`: The GPUI app context.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to `Ok(())` on success or an error.
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>>;
|
||||
|
||||
/// Retrieves the currently selected model for a specific session (thread).
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `session_id`: The ID of the session (thread) to query.
|
||||
/// - `cx`: The GPUI app context.
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the selected model (always set) or an error (e.g., session not found).
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>>;
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
@@ -18,9 +64,29 @@ pub trait AgentConnection {
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
/// This allows sharing the selector in UI components.
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
None // Default impl for agents that don't support it
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,11 @@ use anyhow::{Context as _, Result};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
|
||||
use std::{cell::RefCell, path::Path, rc::Rc};
|
||||
use ui::App;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AcpThread, AgentConnection};
|
||||
use crate::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OldAcpClientDelegate {
|
||||
@@ -351,28 +351,15 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Unauthenticated;
|
||||
|
||||
impl Error for Unauthenticated {}
|
||||
impl fmt::Display for Unauthenticated {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Unauthenticated")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OldAcpAgentConnection {
|
||||
pub name: &'static str,
|
||||
pub connection: acp_old::AgentConnection,
|
||||
pub child_status: Task<Result<()>>,
|
||||
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||
pub auth_methods: [acp::AuthMethod; 1],
|
||||
}
|
||||
|
||||
impl AgentConnection for OldAcpAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
@@ -391,13 +378,13 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||
|
||||
if !result.is_authenticated {
|
||||
anyhow::bail!(Unauthenticated)
|
||||
anyhow::bail!(AuthRequired)
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||
AcpThread::new(self.clone(), project, session_id, cx)
|
||||
AcpThread::new("Gemini", self.clone(), project, session_id, cx)
|
||||
});
|
||||
current_thread.replace(thread.downgrade());
|
||||
thread
|
||||
@@ -405,7 +392,11 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&self.auth_methods
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
let task = self
|
||||
.connection
|
||||
.request_any(acp_old::AuthenticateParams.into_any());
|
||||
@@ -415,7 +406,7 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let chunks = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
|
||||
@@ -47,7 +47,6 @@ paths.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
ref-cast.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
|
||||
@@ -13,7 +13,7 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
|
||||
use collections::HashMap;
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::{FutureExt, StreamExt as _, future::Shared};
|
||||
@@ -37,7 +37,6 @@ use project::{
|
||||
git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
|
||||
};
|
||||
use prompt_store::{ModelContext, PromptBuilder};
|
||||
use proto::Plan;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
@@ -3255,8 +3254,10 @@ impl Thread {
|
||||
}
|
||||
|
||||
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
|
||||
self.project.update(cx, |project, cx| {
|
||||
project.user_store().update(cx, |user_store, cx| {
|
||||
self.project
|
||||
.read(cx)
|
||||
.user_store()
|
||||
.update(cx, |user_store, cx| {
|
||||
user_store.update_model_request_usage(
|
||||
ModelRequestUsage(RequestUsage {
|
||||
amount: amount as i32,
|
||||
@@ -3264,8 +3265,7 @@ impl Thread {
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn deny_tool_use(
|
||||
|
||||
56
crates/agent2/Cargo.toml
Normal file
@@ -0,0 +1,56 @@
|
||||
[package]
|
||||
name = "agent2"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "GPL-3.0-or-later"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
path = "src/agent2.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_servers.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
thiserror.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
worktree.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
gpui = { workspace = true, "features" = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
language_model = { workspace = true, "features" = ["test-support"] }
|
||||
project = { workspace = true, "features" = ["test-support"] }
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, "features" = ["test-support"] }
|
||||
worktree = { workspace = true, "features" = ["test-support"] }
|
||||
1
crates/agent2/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
377
crates/agent2/src/agent.rs
Normal file
@@ -0,0 +1,377 @@
|
||||
use acp_thread::ModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use project::Project;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{templates::Templates, Thread};
|
||||
|
||||
/// Holds both the internal Thread and the AcpThread for a session
|
||||
#[derive(Clone)]
|
||||
struct Session {
|
||||
/// The internal thread that processes messages
|
||||
thread: Entity<Thread>,
|
||||
/// The ACP thread that handles protocol communication
|
||||
acp_thread: Entity<acp_thread::AcpThread>,
|
||||
}
|
||||
|
||||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl NativeAgent {
|
||||
pub fn new(templates: Arc<Templates>) -> Self {
|
||||
log::info!("Creating new NativeAgent");
|
||||
Self {
|
||||
sessions: HashMap::new(),
|
||||
templates,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl ModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
cx.spawn(async move |cx| {
|
||||
cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
log::info!("Found {} available models", models.len());
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!(
|
||||
"Setting model for session {}: {:?}",
|
||||
session_id,
|
||||
model.name()
|
||||
);
|
||||
let agent = self.0.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
agent.update(cx, |agent, cx| {
|
||||
if let Some(session) = agent.sessions.get(&session_id) {
|
||||
session.thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Session not found"))
|
||||
}
|
||||
})?
|
||||
})
|
||||
}
|
||||
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||
let agent = self.0.clone();
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let session = agent
|
||||
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||
let selected = session
|
||||
.thread
|
||||
.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
Ok(selected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
|
||||
let agent = self.0.clone();
|
||||
log::info!("Creating new thread for project at: {:?}", cwd);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Starting thread creation in async context");
|
||||
// Create Thread
|
||||
let (session_id, thread) = agent.update(
|
||||
cx,
|
||||
|agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
|
||||
// Fetch default model from registry settings
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
|
||||
// Log available models for debugging
|
||||
let available_count = registry.available_models(cx).count();
|
||||
log::debug!("Total available models: {}", available_count);
|
||||
|
||||
let default_model = registry
|
||||
.default_model()
|
||||
.map(|configured| {
|
||||
log::info!(
|
||||
"Using configured default model: {:?} from provider: {:?}",
|
||||
configured.model.name(),
|
||||
configured.provider.name()
|
||||
);
|
||||
configured.model
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
log::warn!("No default model configured in settings");
|
||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
|
||||
|
||||
// Generate session ID
|
||||
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
|
||||
log::info!("Created session with ID: {}", session_id);
|
||||
Ok((session_id, thread))
|
||||
},
|
||||
)??;
|
||||
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
|
||||
})
|
||||
})?;
|
||||
|
||||
// Store the session
|
||||
agent.update(cx, |agent, _cx| {
|
||||
agent.sessions.insert(
|
||||
session_id,
|
||||
Session {
|
||||
thread,
|
||||
acp_thread: acp_thread.clone(),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
|
||||
Ok(acp_thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[] // No auth for in-process
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
log::debug!("Prompt blocks count: {}", params.prompt.len());
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Get session
|
||||
let session = agent
|
||||
.read_with(cx, |agent, _| {
|
||||
agent.sessions.get(&session_id).map(|s| Session {
|
||||
thread: s.thread.clone(),
|
||||
acp_thread: s.acp_thread.clone(),
|
||||
})
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
log::error!("Session not found: {}", session_id);
|
||||
anyhow::anyhow!("Session not found")
|
||||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
// Convert prompt to message
|
||||
let message = convert_prompt_to_message(params.prompt);
|
||||
log::info!("Converted prompt to message: {} chars", message.len());
|
||||
log::debug!("Message content: {}", message);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
let model = session
|
||||
.thread
|
||||
.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let response_stream = session
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.send(model, message, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
let acp_thread = session.acp_thread.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
use futures::StreamExt;
|
||||
use language_model::LanguageModelCompletionEvent;
|
||||
|
||||
let mut response_stream = response_stream;
|
||||
|
||||
while let Some(result) = response_stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::Text(text) => {
|
||||
// Send text chunk as agent message
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentMessageChunk {
|
||||
content: acp::ContentBlock::Text(
|
||||
acp::TextContent {
|
||||
text: text.into(),
|
||||
annotations: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
// Convert LanguageModelToolUse to ACP ToolCall
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::ToolCall(acp::ToolCall {
|
||||
id: acp::ToolCallId(tool_use.id.to_string().into()),
|
||||
label: tool_use.name.to_string(),
|
||||
kind: acp::ToolKind::Other,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
locations: vec![],
|
||||
raw_input: Some(tool_use.input),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
log::debug!("Started new assistant message");
|
||||
}
|
||||
LanguageModelCompletionEvent::UsageUpdate(usage) => {
|
||||
log::debug!("Token usage update: {:?}", usage);
|
||||
}
|
||||
LanguageModelCompletionEvent::Thinking { text, .. } => {
|
||||
// Send thinking text as agent thought chunk
|
||||
acp_thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(
|
||||
acp::SessionUpdate::AgentThoughtChunk {
|
||||
content: acp::ContentBlock::Text(
|
||||
acp::TextContent {
|
||||
text: text.into(),
|
||||
annotations: None,
|
||||
},
|
||||
),
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})??;
|
||||
}
|
||||
LanguageModelCompletionEvent::StatusUpdate(status) => {
|
||||
log::trace!("Status update: {:?}", status);
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(stop_reason) => {
|
||||
log::debug!("Assistant message complete: {:?}", stop_reason);
|
||||
}
|
||||
LanguageModelCompletionEvent::RedactedThinking { .. } => {
|
||||
log::trace!("Redacted thinking event");
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input,
|
||||
json_parse_error,
|
||||
} => {
|
||||
log::error!(
|
||||
"Tool use JSON parse error for tool '{}' (id: {}): {} - input: {}",
|
||||
tool_name,
|
||||
id,
|
||||
json_parse_error,
|
||||
raw_input
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Error in model response stream: {:?}", e);
|
||||
// TODO: Consider sending an error message to the UI
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("Response stream completed");
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach();
|
||||
|
||||
log::info!("Successfully sent prompt to thread and started response handler");
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
log::info!("Cancelling session: {}", session_id);
|
||||
self.0.update(cx, |agent, _cx| {
|
||||
agent.sessions.remove(session_id);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert ACP content blocks to a message string
|
||||
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
||||
log::debug!("Converting {} content blocks to message", blocks.len());
|
||||
let mut message = String::new();
|
||||
|
||||
for block in blocks {
|
||||
match block {
|
||||
acp::ContentBlock::Text(text) => {
|
||||
log::trace!("Processing text block: {} chars", text.text.len());
|
||||
message.push_str(&text.text);
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(link) => {
|
||||
log::trace!("Processing resource link: {}", link.uri);
|
||||
message.push_str(&format!(" @{} ", link.uri));
|
||||
}
|
||||
acp::ContentBlock::Image(_) => {
|
||||
log::trace!("Processing image block");
|
||||
message.push_str(" [image] ");
|
||||
}
|
||||
acp::ContentBlock::Audio(_) => {
|
||||
log::trace!("Processing audio block");
|
||||
message.push_str(" [audio] ");
|
||||
}
|
||||
acp::ContentBlock::Resource(resource) => {
|
||||
log::trace!("Processing resource block: {:?}", resource.resource);
|
||||
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
13
crates/agent2/src/agent2.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
mod agent;
|
||||
mod native_agent_server;
|
||||
mod prompts;
|
||||
mod templates;
|
||||
mod thread;
|
||||
mod tools;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use agent::*;
|
||||
pub use native_agent_server::NativeAgentServer;
|
||||
pub use thread::*;
|
||||
58
crates/agent2/src/native_agent_server.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use project::Project;
|
||||
|
||||
use crate::{templates::Templates, NativeAgent, NativeAgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentServer;
|
||||
|
||||
impl AgentServer for NativeAgentServer {
|
||||
fn name(&self) -> &'static str {
|
||||
"Native Agent"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Native Agent"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"How can I help you today?"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
// Using the ZedAssistant icon as it's the native built-in agent
|
||||
ui::IconName::ZedAssistant
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
_project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
|
||||
log::info!(
|
||||
"NativeAgentServer::connect called for path: {:?}",
|
||||
_root_dir
|
||||
);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
// Create templates (you might want to load these from files or resources)
|
||||
let templates = Templates::new();
|
||||
|
||||
// Create the native agent
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = cx.update(|cx| cx.new(|_| NativeAgent::new(templates)))?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
log::info!("NativeAgentServer connection established successfully");
|
||||
|
||||
Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
|
||||
})
|
||||
}
|
||||
}
|
||||
30
crates/agent2/src/prompts.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use crate::{
|
||||
templates::{BaseTemplate, Template, Templates, WorktreeData},
|
||||
thread::Prompt,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use gpui::{App, Entity};
|
||||
use project::Project;
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct BasePrompt {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl Prompt for BasePrompt {
|
||||
fn render(&self, templates: &Templates, cx: &App) -> Result<String> {
|
||||
BaseTemplate {
|
||||
os: std::env::consts::OS.to_string(),
|
||||
shell: util::get_system_shell(),
|
||||
worktrees: self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| WorktreeData {
|
||||
root_name: worktree.read(cx).root_name().to_string(),
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
.render(templates)
|
||||
}
|
||||
}
|
||||
57
crates/agent2/src/templates.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct BaseTemplate {
|
||||
pub os: String,
|
||||
pub shell: String,
|
||||
pub worktrees: Vec<WorktreeData>,
|
||||
}
|
||||
|
||||
impl Template for BaseTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "base.hbs";
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct WorktreeData {
|
||||
pub root_name: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GlobTemplate {
|
||||
pub project_roots: String,
|
||||
}
|
||||
|
||||
impl Template for GlobTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "glob.hbs";
|
||||
}
|
||||
56
crates/agent2/src/templates/base.hbs
Normal file
@@ -0,0 +1,56 @@
|
||||
You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices.
|
||||
|
||||
## Communication
|
||||
|
||||
1. Be conversational but professional.
|
||||
2. Refer to the USER in the second person and yourself in the first person.
|
||||
3. Format your responses in markdown. Use backticks to format file, directory, function, and class names.
|
||||
4. NEVER lie or make things up.
|
||||
5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing.
|
||||
|
||||
## Tool Use
|
||||
|
||||
1. Make sure to adhere to the tools schema.
|
||||
2. Provide every required argument.
|
||||
3. DO NOT use tools to access items that are already available in the context section.
|
||||
4. Use only the tools that are currently available.
|
||||
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
|
||||
|
||||
## Searching and Reading
|
||||
|
||||
If you are unsure how to fulfill the user's request, gather more information with tool calls and/or clarifying questions.
|
||||
|
||||
If appropriate, use tool calls to explore the current project, which contains the following root directories:
|
||||
|
||||
{{#each worktrees}}
|
||||
- `{{root_name}}`
|
||||
{{/each}}
|
||||
|
||||
- When providing paths to tools, the path should always begin with a path that starts with a project root directory listed above.
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.
|
||||
2. Never simplify code you've written just to solve diagnostics. Complete, mostly correct code is more valuable than perfect code that doesn't solve the problem.
|
||||
|
||||
## Debugging
|
||||
|
||||
When debugging, only make code changes if you are certain that you can solve the problem.
|
||||
Otherwise, follow debugging best practices:
|
||||
1. Address the root cause instead of the symptoms.
|
||||
2. Add descriptive logging statements and error messages to track variable and code state.
|
||||
3. Add test functions and statements to isolate the problem.
|
||||
|
||||
## Calling External APIs
|
||||
|
||||
1. Unless explicitly requested by the user, use the best suited external APIs and packages to solve the task. There is no need to ask the user for permission.
|
||||
2. When selecting which version of an API or package to use, choose one that is compatible with the user's dependency management file. If no such file exists or if the package is not present, use the latest version that is in your training data.
|
||||
3. If an external API requires an API Key, be sure to point this out to the user. Adhere to best security practices (e.g. DO NOT hardcode an API key in a place where it can be exposed)
|
||||
|
||||
## System Information
|
||||
|
||||
Operating System: {{os}}
|
||||
Default Shell: {{shell}}
|
||||
8
crates/agent2/src/templates/glob.hbs
Normal file
@@ -0,0 +1,8 @@
|
||||
Find paths on disk with glob patterns.
|
||||
|
||||
Assume that all glob patterns are matched in a project directory with the following entries.
|
||||
|
||||
{{project_roots}}
|
||||
|
||||
When searching with patterns that begin with literal path components, e.g. `foo/bar/**/*.rs`, be
|
||||
sure to anchor them with one of the directories listed above.
|
||||
378
crates/agent2/src/tests/mod.rs
Normal file
@@ -0,0 +1,378 @@
|
||||
use super::*;
|
||||
use crate::templates::Templates;
|
||||
use acp_thread::AgentConnection as _;
|
||||
use agent_client_protocol as acp;
|
||||
use client::{Client, UserStore};
|
||||
use gpui::{AppContext, Entity, TestAppContext};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, MessageContent, StopReason,
|
||||
};
|
||||
use project::Project;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
|
||||
|
||||
mod test_tools;
|
||||
use test_tools::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_echo(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
|
||||
// Test a tool calls that's likely to complete *after* streaming stops.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.remove_tool(&AgentTool::name(&EchoTool));
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
assert_eq!(
|
||||
stop_events(events),
|
||||
vec![StopReason::ToolUse, StopReason::EndTurn]
|
||||
);
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert!(thread
|
||||
.messages()
|
||||
.last()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(WordListTool);
|
||||
thread.send(model.clone(), "Test the word_list tool.", cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
while let Some(event) = events.next().await {
|
||||
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use_event)) = event {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Look for a tool use in the thread's last message
|
||||
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_use_event.is_input_complete {
|
||||
last_tool_use
|
||||
.input
|
||||
.get("a")
|
||||
.expect("'a' has streamed because input is now complete");
|
||||
last_tool_use
|
||||
.input
|
||||
.get("g")
|
||||
.expect("'g' has streamed because input is now complete");
|
||||
} else {
|
||||
if !last_tool_use.is_input_complete
|
||||
&& last_tool_use.input.get("g").is_none()
|
||||
{
|
||||
saw_partial_tool_use = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
panic!("last content should be a tool use");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
saw_partial_tool_use,
|
||||
"should see at least one partially streamed tool use in the history"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx).await;
|
||||
|
||||
// Test concurrent tool calls with different delay times
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
model.clone(),
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let stop_reasons = stop_events(events);
|
||||
if stop_reasons.len() == 2 {
|
||||
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
|
||||
} else if stop_reasons.len() == 3 {
|
||||
assert_eq!(
|
||||
stop_reasons,
|
||||
vec![
|
||||
StopReason::ToolUse,
|
||||
StopReason::ToolUse,
|
||||
StopReason::EndTurn
|
||||
]
|
||||
);
|
||||
} else {
|
||||
panic!("Expected either 1 or 2 tool uses followed by end turn");
|
||||
}
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
let text = last_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
|
||||
assert!(text.contains("Ding"));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let templates = Templates::new();
|
||||
|
||||
// Initialize language model system with test provider
|
||||
cx.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
// Initialize project settings
|
||||
Project::init_settings(cx);
|
||||
|
||||
// Use test registry with fake provider
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
|
||||
// Create agent and connection
|
||||
let agent = cx.new(|_| NativeAgent::new(templates.clone()));
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
let selector_opt = connection.model_selector();
|
||||
assert!(
|
||||
selector_opt.is_some(),
|
||||
"agent2 should always support ModelSelector"
|
||||
);
|
||||
let selector = selector_opt.unwrap();
|
||||
|
||||
// Test list_models
|
||||
let listed_models = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.list_models(&mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
|
||||
// Create a thread using new_thread
|
||||
let cwd = Path::new("/test");
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
connection_rc.new_thread(project, cwd, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
// Get the session_id from the AcpThread
|
||||
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||
|
||||
// Test selected_model returns the default
|
||||
let selected = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.await
|
||||
.expect("selected_model should succeed");
|
||||
assert_eq!(selected.id().0, "fake", "should return default model");
|
||||
|
||||
// The thread was created via prompt with the default model
|
||||
// We can verify it through selected_model
|
||||
|
||||
// Test prompt uses the selected model
|
||||
let prompt_request = acp::PromptRequest {
|
||||
session_id: session_id.clone(),
|
||||
prompt: vec![acp::ContentBlock::Text(acp::TextContent {
|
||||
text: "Test prompt".into(),
|
||||
annotations: None,
|
||||
})],
|
||||
};
|
||||
|
||||
cx.update(|cx| connection.prompt(prompt_request, cx))
|
||||
.await
|
||||
.expect("prompt should succeed");
|
||||
|
||||
// The prompt was sent successfully
|
||||
|
||||
// Test cancel
|
||||
cx.update(|cx| connection.cancel(&session_id, cx));
|
||||
|
||||
// After cancel, selected_model should fail
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_err(), "selected_model should fail after cancel");
|
||||
|
||||
// Test error case: invalid session
|
||||
let invalid_session = acp::SessionId("invalid".into());
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&invalid_session, &mut async_cx)
|
||||
})
|
||||
.await;
|
||||
assert!(result.is_err(), "should fail for invalid session");
|
||||
if let Err(e) = result {
|
||||
assert!(
|
||||
e.to_string().contains("Session not found"),
|
||||
"should have correct error message"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Filters out the stop events for asserting against in tests
|
||||
fn stop_events(
|
||||
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
) -> Vec<StopReason> {
|
||||
result_events
|
||||
.into_iter()
|
||||
.filter_map(|event| match event.unwrap() {
|
||||
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
}
|
||||
|
||||
async fn setup(cx: &mut TestAppContext) -> ThreadTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
let templates = Templates::new();
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
.await;
|
||||
|
||||
let thread = cx.new(|_| Thread::new(templates, model.clone()));
|
||||
|
||||
ThreadTest { model, thread }
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
||||
85
crates/agent2/src/tests/test_tools.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use super::*;
|
||||
use anyhow::Result;
|
||||
use gpui::{App, SharedString, Task};
|
||||
|
||||
/// A tool that echoes its input
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct EchoToolInput {
|
||||
/// The text to echo.
|
||||
text: String,
|
||||
}
|
||||
|
||||
pub struct EchoTool;
|
||||
|
||||
impl AgentTool for EchoTool {
|
||||
type Input = EchoToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"echo".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok(input.text))
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that waits for a specified delay
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct DelayToolInput {
|
||||
/// The delay in milliseconds.
|
||||
ms: u64,
|
||||
}
|
||||
|
||||
pub struct DelayTool;
|
||||
|
||||
impl AgentTool for DelayTool {
|
||||
type Input = DelayToolInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"delay".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
cx.foreground_executor().spawn(async move {
|
||||
smol::Timer::after(Duration::from_millis(input.ms)).await;
|
||||
Ok("Ding".to_string())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool that takes an object with map from letters to random words starting with that letter.
|
||||
/// All fiealds are required! Pass a word for every letter!
|
||||
#[derive(JsonSchema, Serialize, Deserialize)]
|
||||
pub struct WordListInput {
|
||||
/// Provide a random word that starts with A.
|
||||
a: Option<String>,
|
||||
/// Provide a random word that starts with B.
|
||||
b: Option<String>,
|
||||
/// Provide a random word that starts with C.
|
||||
c: Option<String>,
|
||||
/// Provide a random word that starts with D.
|
||||
d: Option<String>,
|
||||
/// Provide a random word that starts with E.
|
||||
e: Option<String>,
|
||||
/// Provide a random word that starts with F.
|
||||
f: Option<String>,
|
||||
/// Provide a random word that starts with G.
|
||||
g: Option<String>,
|
||||
}
|
||||
|
||||
pub struct WordListTool;
|
||||
|
||||
impl AgentTool for WordListTool {
|
||||
type Input = WordListInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"word_list".into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, _input: Self::Input, _cx: &mut App) -> Task<Result<String>> {
|
||||
Task::ready(Ok("ok".to_string()))
|
||||
}
|
||||
}
|
||||
493
crates/agent2/src/thread.rs
Normal file
@@ -0,0 +1,493 @@
|
||||
use crate::templates::Templates;
|
||||
use anyhow::{anyhow, Result};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use futures::{channel::mpsc, future};
|
||||
use gpui::{App, Context, SharedString, Task};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, MessageContent, Role, StopReason,
|
||||
};
|
||||
use log;
|
||||
use schemars::{JsonSchema, Schema};
|
||||
use serde::Deserialize;
|
||||
use smol::stream::StreamExt;
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AgentMessage {
|
||||
pub role: Role,
|
||||
pub content: Vec<MessageContent>,
|
||||
}
|
||||
|
||||
pub type AgentResponseEvent = LanguageModelCompletionEvent;
|
||||
|
||||
pub trait Prompt {
|
||||
fn render(&self, prompts: &Templates, cx: &App) -> Result<String>;
|
||||
}
|
||||
|
||||
pub struct Thread {
|
||||
messages: Vec<AgentMessage>,
|
||||
completion_mode: CompletionMode,
|
||||
/// Holds the task that handles agent interaction until the end of the turn.
|
||||
/// Survives across multiple requests as the model performs tool calls and
|
||||
/// we run tools, report their results.
|
||||
running_turn: Option<Task<()>>,
|
||||
system_prompts: Vec<Arc<dyn Prompt>>,
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
templates: Arc<Templates>,
|
||||
pub selected_model: Arc<dyn LanguageModel>,
|
||||
// action_log: Entity<ActionLog>,
|
||||
}
|
||||
|
||||
impl Thread {
|
||||
pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
completion_mode: CompletionMode::Normal,
|
||||
system_prompts: Vec::new(),
|
||||
running_turn: None,
|
||||
tools: BTreeMap::default(),
|
||||
templates,
|
||||
selected_model: default_model,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_mode(&mut self, mode: CompletionMode) {
|
||||
self.completion_mode = mode;
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> &[AgentMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
pub fn add_tool(&mut self, tool: impl AgentTool) {
|
||||
self.tools.insert(tool.name(), tool.erase());
|
||||
}
|
||||
|
||||
pub fn remove_tool(&mut self, name: &str) -> bool {
|
||||
self.tools.remove(name).is_some()
|
||||
}
|
||||
|
||||
/// Sending a message results in the model streaming a response, which could include tool calls.
|
||||
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
|
||||
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
|
||||
pub fn send(
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
content: impl Into<MessageContent>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
|
||||
let content = content.into();
|
||||
log::info!("Thread::send called with model: {:?}", model.name());
|
||||
log::debug!("Thread::send content: {:?}", content);
|
||||
|
||||
cx.notify();
|
||||
let (events_tx, events_rx) =
|
||||
mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
|
||||
|
||||
let system_message = self.build_system_message(cx);
|
||||
log::debug!(
|
||||
"System messages count: {}",
|
||||
if system_message.is_some() { 1 } else { 0 }
|
||||
);
|
||||
self.messages.extend(system_message);
|
||||
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: vec![content],
|
||||
});
|
||||
log::info!("Total messages in thread: {}", self.messages.len());
|
||||
self.running_turn = Some(cx.spawn(async move |thread, cx| {
|
||||
log::info!("Starting agent turn execution");
|
||||
let turn_result = async {
|
||||
// Perform one request, then keep looping if the model makes tool calls.
|
||||
let mut completion_intent = CompletionIntent::UserPrompt;
|
||||
loop {
|
||||
log::debug!(
|
||||
"Building completion request with intent: {:?}",
|
||||
completion_intent
|
||||
);
|
||||
let request = thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(completion_intent, cx)
|
||||
})?;
|
||||
|
||||
// println!(
|
||||
// "request: {}",
|
||||
// serde_json::to_string_pretty(&request).unwrap()
|
||||
// );
|
||||
|
||||
// Stream events, appending to messages and collecting up tool uses.
|
||||
log::info!("Calling model.stream_completion");
|
||||
let mut events = model.stream_completion(request, cx).await?;
|
||||
log::debug!("Stream completion started successfully");
|
||||
let mut tool_uses = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
log::trace!("Received completion event: {:?}", event);
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
tool_uses.extend(thread.handle_streamed_completion_event(
|
||||
event,
|
||||
events_tx.clone(),
|
||||
cx,
|
||||
));
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(error) => {
|
||||
log::error!("Error in completion stream: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no tool uses, the turn is done.
|
||||
if tool_uses.is_empty() {
|
||||
log::info!("No tool uses found, completing turn");
|
||||
break;
|
||||
}
|
||||
log::info!("Found {} tool uses to execute", tool_uses.len());
|
||||
|
||||
// If there are tool uses, wait for their results to be
|
||||
// computed, then send them together in a single message on
|
||||
// the next loop iteration.
|
||||
let tool_results = future::join_all(tool_uses).await;
|
||||
log::debug!("Tool execution completed, {} results", tool_results.len());
|
||||
thread
|
||||
.update(cx, |thread, _cx| {
|
||||
thread.messages.push(AgentMessage {
|
||||
role: Role::User,
|
||||
content: tool_results
|
||||
.into_iter()
|
||||
.map(MessageContent::ToolResult)
|
||||
.collect(),
|
||||
});
|
||||
})
|
||||
.ok();
|
||||
completion_intent = CompletionIntent::ToolResults;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(error) = turn_result {
|
||||
log::error!("Turn execution failed: {:?}", error);
|
||||
events_tx.unbounded_send(Err(error)).ok();
|
||||
} else {
|
||||
log::info!("Turn execution completed successfully");
|
||||
}
|
||||
}));
|
||||
events_rx
|
||||
}
|
||||
|
||||
pub fn build_system_message(&mut self, cx: &App) -> Option<AgentMessage> {
|
||||
log::debug!("Building system message");
|
||||
let mut system_message = AgentMessage {
|
||||
role: Role::System,
|
||||
content: Vec::new(),
|
||||
};
|
||||
|
||||
for prompt in &self.system_prompts {
|
||||
if let Some(rendered_prompt) = prompt.render(&self.templates, cx).log_err() {
|
||||
system_message
|
||||
.content
|
||||
.push(MessageContent::Text(rendered_prompt));
|
||||
}
|
||||
}
|
||||
|
||||
let result = (!system_message.content.is_empty()).then_some(system_message);
|
||||
log::debug!("System message built: {}", result.is_some());
|
||||
result
|
||||
}
|
||||
|
||||
/// A helper method that's called on every streamed completion event.
|
||||
/// Returns an optional tool result task, which the main agentic loop in
|
||||
/// send will send back to the model when it resolves.
|
||||
fn handle_streamed_completion_event(
|
||||
&mut self,
|
||||
event: LanguageModelCompletionEvent,
|
||||
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
log::trace!("Handling streamed completion event: {:?}", event);
|
||||
use LanguageModelCompletionEvent::*;
|
||||
events_tx.unbounded_send(Ok(event.clone())).ok();
|
||||
|
||||
match event {
|
||||
Text(new_text) => self.handle_text_event(new_text, cx),
|
||||
Thinking {
|
||||
text: _text,
|
||||
signature: _signature,
|
||||
} => {
|
||||
todo!()
|
||||
}
|
||||
ToolUse(tool_use) => {
|
||||
return self.handle_tool_use_event(tool_use, cx);
|
||||
}
|
||||
StartMessage { .. } => {
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
UsageUpdate(_) => {}
|
||||
Stop(stop_reason) => self.handle_stop_event(stop_reason),
|
||||
StatusUpdate(_completion_request_status) => {}
|
||||
RedactedThinking { data: _data } => todo!(),
|
||||
ToolUseJsonParseError {
|
||||
id: _id,
|
||||
tool_name: _tool_name,
|
||||
raw_input: _raw_input,
|
||||
json_parse_error: _json_parse_error,
|
||||
} => todo!(),
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_stop_event(&mut self, stop_reason: StopReason) {
|
||||
match stop_reason {
|
||||
StopReason::EndTurn | StopReason::ToolUse => {}
|
||||
StopReason::MaxTokens => todo!(),
|
||||
StopReason::Refusal => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
|
||||
let last_message = self.last_assistant_message();
|
||||
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
|
||||
text.push_str(&new_text);
|
||||
} else {
|
||||
last_message.content.push(MessageContent::Text(new_text));
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_tool_use_event(
|
||||
&mut self,
|
||||
tool_use: LanguageModelToolUse,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<LanguageModelToolResult>> {
|
||||
cx.notify();
|
||||
|
||||
let last_message = self.last_assistant_message();
|
||||
|
||||
// Ensure the last message ends in the current tool use
|
||||
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
|
||||
if let MessageContent::ToolUse(last_tool_use) = content {
|
||||
if last_tool_use.id == tool_use.id {
|
||||
*last_tool_use = tool_use.clone();
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if push_new_tool_use {
|
||||
last_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
}
|
||||
|
||||
if !tool_use.is_input_complete {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(tool) = self.tools.get(tool_use.name.as_ref()) {
|
||||
let pending_tool_result = tool.clone().run(tool_use.input, cx);
|
||||
|
||||
Some(cx.foreground_executor().spawn(async move {
|
||||
match pending_tool_result.await {
|
||||
Ok(tool_output) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: false,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(tool_output)),
|
||||
output: None,
|
||||
},
|
||||
Err(error) => LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())),
|
||||
output: None,
|
||||
},
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
Some(Task::ready(LanguageModelToolResult {
|
||||
content: LanguageModelToolResultContent::Text(Arc::from(format!(
|
||||
"No tool named {} exists",
|
||||
tool_use.name
|
||||
))),
|
||||
tool_use_id: tool_use.id,
|
||||
tool_name: tool_use.name,
|
||||
is_error: true,
|
||||
output: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Guarantees the last message is from the assistant and returns a mutable reference.
|
||||
fn last_assistant_message(&mut self) -> &mut AgentMessage {
|
||||
if self
|
||||
.messages
|
||||
.last()
|
||||
.map_or(true, |m| m.role != Role::Assistant)
|
||||
{
|
||||
self.messages.push(AgentMessage {
|
||||
role: Role::Assistant,
|
||||
content: Vec::new(),
|
||||
});
|
||||
}
|
||||
self.messages.last_mut().unwrap()
|
||||
}
|
||||
|
||||
fn build_completion_request(
|
||||
&self,
|
||||
completion_intent: CompletionIntent,
|
||||
cx: &mut App,
|
||||
) -> LanguageModelRequest {
|
||||
log::debug!("Building completion request");
|
||||
log::debug!("Completion intent: {:?}", completion_intent);
|
||||
log::debug!("Completion mode: {:?}", self.completion_mode);
|
||||
|
||||
let messages = self.build_request_messages();
|
||||
log::info!("Request will include {} messages", messages.len());
|
||||
|
||||
let tools: Vec<LanguageModelRequestTool> = self
|
||||
.tools
|
||||
.values()
|
||||
.filter_map(|tool| {
|
||||
let tool_name = tool.name().to_string();
|
||||
log::trace!("Including tool: {}", tool_name);
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool_name,
|
||||
description: tool.description(cx).to_string(),
|
||||
input_schema: tool
|
||||
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
|
||||
.log_err()?,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
log::info!("Request includes {} tools", tools.len());
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: Some(completion_intent),
|
||||
mode: Some(self.completion_mode),
|
||||
messages,
|
||||
tools,
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
thinking_allowed: false,
|
||||
};
|
||||
|
||||
log::debug!("Completion request built successfully");
|
||||
request
|
||||
}
|
||||
|
||||
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
|
||||
log::trace!(
|
||||
"Building request messages from {} thread messages",
|
||||
self.messages.len()
|
||||
);
|
||||
let messages = self
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
log::trace!(
|
||||
" - {} message with {} content items",
|
||||
match message.role {
|
||||
Role::System => "System",
|
||||
Role::User => "User",
|
||||
Role::Assistant => "Assistant",
|
||||
},
|
||||
message.content.len()
|
||||
);
|
||||
LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
content: message.content.clone(),
|
||||
cache: false,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentTool
|
||||
where
|
||||
Self: 'static + Sized,
|
||||
{
|
||||
type Input: for<'de> Deserialize<'de> + JsonSchema;
|
||||
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, _cx: &mut App) -> SharedString {
|
||||
let schema = schemars::schema_for!(Self::Input);
|
||||
SharedString::new(
|
||||
schema
|
||||
.get("description")
|
||||
.and_then(|description| description.as_str())
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the JSON schema that describes the tool's input.
|
||||
fn input_schema(&self, _format: LanguageModelToolSchemaFormat) -> Schema {
|
||||
schemars::schema_for!(Self::Input)
|
||||
}
|
||||
|
||||
/// Runs the tool with the provided input.
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>>;
|
||||
|
||||
fn erase(self) -> Arc<dyn AnyAgentTool> {
|
||||
Arc::new(Erased(Arc::new(self)))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Erased<T>(T);
|
||||
|
||||
pub trait AnyAgentTool {
|
||||
fn name(&self) -> SharedString;
|
||||
fn description(&self, cx: &mut App) -> SharedString;
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value>;
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
impl<T> AnyAgentTool for Erased<Arc<T>>
|
||||
where
|
||||
T: AgentTool,
|
||||
{
|
||||
fn name(&self) -> SharedString {
|
||||
self.0.name()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
self.0.description(cx)
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
Ok(serde_json::to_value(self.0.input_schema(format))?)
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: serde_json::Value, cx: &mut App) -> Task<Result<String>> {
|
||||
let parsed_input: Result<T::Input> = serde_json::from_value(input).map_err(Into::into);
|
||||
match parsed_input {
|
||||
Ok(input) => self.0.clone().run(input, cx),
|
||||
Err(error) => Task::ready(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
||||
1
crates/agent2/src/tools.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod glob;
|
||||
76
crates/agent2/src/tools/glob.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use gpui::{App, AppContext, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::paths::PathMatcher;
|
||||
use worktree::Snapshot as WorktreeSnapshot;
|
||||
|
||||
use crate::{
|
||||
templates::{GlobTemplate, Template, Templates},
|
||||
thread::AgentTool,
|
||||
};
|
||||
|
||||
// Description is dynamic, see `fn description` below
|
||||
#[derive(Deserialize, JsonSchema)]
|
||||
struct GlobInput {
|
||||
/// A POSIX glob pattern
|
||||
glob: SharedString,
|
||||
}
|
||||
|
||||
struct GlobTool {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl AgentTool for GlobTool {
|
||||
type Input = GlobInput;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"glob".into()
|
||||
}
|
||||
|
||||
fn description(&self, cx: &mut App) -> SharedString {
|
||||
let project_roots = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).root_name().into())
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
GlobTemplate { project_roots }
|
||||
.render(&self.templates)
|
||||
.expect("template failed to render")
|
||||
.into()
|
||||
}
|
||||
|
||||
fn run(self: Arc<Self>, input: Self::Input, cx: &mut App) -> Task<Result<String>> {
|
||||
let path_matcher = match PathMatcher::new([&input.glob]) {
|
||||
Ok(matcher) => matcher,
|
||||
Err(error) => return Task::ready(Err(anyhow!(error))),
|
||||
};
|
||||
|
||||
let snapshots: Vec<WorktreeSnapshot> = self
|
||||
.project
|
||||
.read(cx)
|
||||
.worktrees(cx)
|
||||
.map(|worktree| worktree.read(cx).snapshot())
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let paths = snapshots.iter().flat_map(|snapshot| {
|
||||
let root_name = PathBuf::from(snapshot.root_name());
|
||||
snapshot
|
||||
.entries(false, 0)
|
||||
.map(move |entry| root_name.join(&entry.path))
|
||||
.filter(|path| path_matcher.is_match(&path))
|
||||
});
|
||||
let output = paths
|
||||
.map(|path| format!("{}\n", path.display()))
|
||||
.collect::<String>();
|
||||
Ok(output)
|
||||
})
|
||||
}
|
||||
}
|
||||
0
crates/agent_servers/acp
Normal file
245
crates/agent_servers/src/acp_connection.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
use agent_client_protocol::{self as acp, Agent as _};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use project::Project;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::AgentServerCommand;
|
||||
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
pub struct AcpConnection {
|
||||
server_name: &'static str,
|
||||
connection: Rc<acp::ClientSideConnection>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
auth_methods: Vec<acp::AuthMethod>,
|
||||
_io_task: Task<Result<()>>,
|
||||
}
|
||||
|
||||
pub struct AcpSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
}
|
||||
|
||||
impl AcpConnection {
|
||||
pub async fn stdio(
|
||||
server_name: &'static str,
|
||||
command: AgentServerCommand,
|
||||
root_dir: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter().map(|arg| arg.as_str()))
|
||||
.envs(command.env.iter().flatten())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdout = child.stdout.take().expect("Failed to take stdout");
|
||||
let stdin = child.stdin.take().expect("Failed to take stdin");
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let client = ClientDelegate {
|
||||
sessions: sessions.clone(),
|
||||
cx: cx.clone(),
|
||||
};
|
||||
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
move |fut| {
|
||||
foreground_executor.spawn(fut).detach();
|
||||
}
|
||||
});
|
||||
|
||||
let io_task = cx.background_spawn(io_task);
|
||||
|
||||
let response = connection
|
||||
.initialize(acp::InitializeRequest {
|
||||
protocol_version: acp::VERSION,
|
||||
client_capabilities: acp::ClientCapabilities {
|
||||
fs: acp::FileSystemCapability {
|
||||
read_text_file: true,
|
||||
write_text_file: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
// todo! check version
|
||||
|
||||
Ok(Self {
|
||||
auth_methods: response.auth_methods,
|
||||
connection: connection.into(),
|
||||
server_name,
|
||||
sessions,
|
||||
_io_task: io_task,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for AcpConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let conn = self.connection.clone();
|
||||
let sessions = self.sessions.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let response = conn
|
||||
.new_session(acp::NewSessionRequest {
|
||||
// todo! Zed MCP server?
|
||||
mcp_servers: vec![],
|
||||
cwd,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let Some(session_id) = response.session_id else {
|
||||
anyhow::bail!(AuthRequired);
|
||||
};
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
self.server_name,
|
||||
self.clone(),
|
||||
project,
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let session = AcpSession {
|
||||
thread: thread.downgrade(),
|
||||
};
|
||||
sessions.borrow_mut().insert(session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&self.auth_methods
|
||||
}
|
||||
|
||||
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
let conn = self.connection.clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let result = conn
|
||||
.authenticate(acp::AuthenticateRequest {
|
||||
method_id: method_id.clone(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let conn = self.connection.clone();
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { Ok(conn.prompt(params).await?) })
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
let conn = self.connection.clone();
|
||||
let params = acp::CancelledNotification {
|
||||
session_id: session_id.clone(),
|
||||
};
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { conn.cancelled(params).await })
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientDelegate {
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
cx: AsyncApp,
|
||||
}
|
||||
|
||||
impl acp::Client for ClientDelegate {
|
||||
async fn request_permission(
|
||||
&self,
|
||||
arguments: acp::RequestPermissionRequest,
|
||||
) -> Result<acp::RequestPermissionResponse, acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let result = self
|
||||
.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
|
||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
|
||||
};
|
||||
|
||||
Ok(acp::RequestPermissionResponse { outcome })
|
||||
}
|
||||
|
||||
async fn write_text_file(
|
||||
&self,
|
||||
arguments: acp::WriteTextFileRequest,
|
||||
) -> Result<(), acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
self.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(arguments.path, arguments.content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_text_file(
|
||||
&self,
|
||||
arguments: acp::ReadTextFileRequest,
|
||||
) -> Result<acp::ReadTextFileResponse, acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let content = self
|
||||
.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(acp::ReadTextFileResponse { content })
|
||||
}
|
||||
|
||||
async fn session_notification(
|
||||
&self,
|
||||
notification: acp::SessionNotification,
|
||||
) -> Result<(), acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let sessions = self.sessions.borrow();
|
||||
let session = sessions
|
||||
.get(¬ification.session_id)
|
||||
.context("Failed to get session")?;
|
||||
|
||||
session.thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(notification.update, cx)
|
||||
})??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,12 @@
|
||||
mod acp_connection;
|
||||
mod claude;
|
||||
mod codex;
|
||||
mod gemini;
|
||||
mod mcp_server;
|
||||
mod settings;
|
||||
|
||||
#[cfg(test)]
|
||||
mod e2e_tests;
|
||||
|
||||
pub use claude::*;
|
||||
pub use codex::*;
|
||||
pub use gemini::*;
|
||||
pub use settings::*;
|
||||
|
||||
@@ -38,7 +36,6 @@ pub trait AgentServer: Send {
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
// these will go away when old_acp is fully removed
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
|
||||
@@ -70,10 +70,6 @@ struct ClaudeAgentConnection {
|
||||
}
|
||||
|
||||
impl AgentConnection for ClaudeAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
ClaudeCode.name()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
@@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
}
|
||||
});
|
||||
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
|
||||
})?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
@@ -186,11 +183,15 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let sessions = self.sessions.borrow();
|
||||
let Some(session) = sessions.get(¶ms.session_id) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
|
||||
@@ -1,319 +0,0 @@
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use context_server::types::requests;
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::mcp_server::ZedMcpServer;
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
|
||||
use acp_thread::{AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Codex;
|
||||
|
||||
impl AgentServer for Codex {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Welcome to Codex"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"What can I help with?"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiOpenAi
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let project = project.clone();
|
||||
let working_directory = project.read(cx).active_project_directory(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find codex binary");
|
||||
};
|
||||
|
||||
let client: Arc<ContextServer> = ContextServer::stdio(
|
||||
ContextServerId("codex-mcp-server".into()),
|
||||
ContextServerCommand {
|
||||
path: command.path,
|
||||
args: command.args,
|
||||
env: command.env,
|
||||
},
|
||||
working_directory,
|
||||
)
|
||||
.into();
|
||||
ContextServer::start(client.clone(), cx).await?;
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
client
|
||||
.client()
|
||||
.context("Failed to subscribe")?
|
||||
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
|
||||
move |notification, _cx| {
|
||||
let notification_tx = notification_tx.clone();
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
|
||||
if let Some(notification) =
|
||||
serde_json::from_value::<acp::SessionNotification>(notification)
|
||||
.log_err()
|
||||
{
|
||||
notification_tx.unbounded_send(notification).ok();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let notification_handler_task = cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
async move |cx| {
|
||||
while let Some(notification) = notification_rx.next().await {
|
||||
CodexConnection::handle_session_notification(
|
||||
notification,
|
||||
sessions.clone(),
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let connection = CodexConnection {
|
||||
client,
|
||||
sessions,
|
||||
_notification_handler_task: notification_handler_task,
|
||||
};
|
||||
Ok(Rc::new(connection) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct CodexConnection {
|
||||
client: Arc<context_server::ContextServer>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
_notification_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
struct CodexSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
cancel_tx: Option<oneshot::Sender<()>>,
|
||||
_mcp_server: ZedMcpServer,
|
||||
}
|
||||
|
||||
impl AgentConnection for CodexConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||
|
||||
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
|
||||
|
||||
let response = client
|
||||
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
||||
name: acp::NEW_SESSION_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
|
||||
mcp_servers: [(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
mcp_server.server_config()?,
|
||||
)]
|
||||
.into(),
|
||||
client_tools: acp::ClientTools {
|
||||
request_permission: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
|
||||
}),
|
||||
read_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
|
||||
}),
|
||||
write_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
|
||||
}),
|
||||
},
|
||||
cwd,
|
||||
})?),
|
||||
meta: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
let result = serde_json::from_value::<acp::NewSessionOutput>(
|
||||
response.structured_content.context("Empty response")?,
|
||||
)?;
|
||||
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
let session = CodexSession {
|
||||
thread: thread.downgrade(),
|
||||
cancel_tx: None,
|
||||
_mcp_server: mcp_server,
|
||||
};
|
||||
sessions.borrow_mut().insert(result.session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
params: agent_client_protocol::PromptArguments,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
|
||||
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
||||
{
|
||||
let mut sessions = sessions.borrow_mut();
|
||||
let session = sessions
|
||||
.get_mut(¶ms.session_id)
|
||||
.context("Session not found")?;
|
||||
session.cancel_tx.replace(new_cancel_tx);
|
||||
}
|
||||
|
||||
let result = client
|
||||
.request_with::<requests::CallTool>(
|
||||
context_server::types::CallToolParams {
|
||||
name: acp::PROMPT_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(params)?),
|
||||
meta: None,
|
||||
},
|
||||
Some(cancel_rx),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = &result
|
||||
&& err.is::<context_server::client::RequestCanceled>()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let response = result?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
|
||||
let mut sessions = self.sessions.borrow_mut();
|
||||
|
||||
if let Some(cancel_tx) = sessions
|
||||
.get_mut(session_id)
|
||||
.and_then(|session| session.cancel_tx.take())
|
||||
{
|
||||
cancel_tx.send(()).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CodexConnection {
|
||||
pub fn handle_session_notification(
|
||||
notification: acp::SessionNotification,
|
||||
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let threads = threads.borrow();
|
||||
let Some(thread) = threads
|
||||
.get(¬ification.session_id)
|
||||
.and_then(|session| session.thread.upgrade())
|
||||
else {
|
||||
log::error!(
|
||||
"Thread not found for session ID: {}",
|
||||
notification.session_id
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(notification.update, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CodexConnection {
|
||||
fn drop(&mut self) {
|
||||
self.client.stop().log_err();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::AgentServerCommand;
|
||||
use std::path::Path;
|
||||
|
||||
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../../codex/codex-rs/target/debug/codex");
|
||||
|
||||
AgentServerCommand {
|
||||
path: cli_path,
|
||||
args: vec![],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -375,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||
gemini: Some(AgentServerSettings {
|
||||
command: crate::gemini::tests::local_command(),
|
||||
}),
|
||||
codex: Some(AgentServerSettings {
|
||||
command: crate::codex::tests::local_command(),
|
||||
}),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
use anyhow::anyhow;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
||||
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
|
||||
use agentic_coding_protocol as acp_old;
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection};
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::Result;
|
||||
use gpui::{Entity, Task};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use ui::App;
|
||||
@@ -43,146 +39,27 @@ impl AgentServer for Gemini {
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let project = project.clone();
|
||||
let this = self.clone();
|
||||
let name = self.name();
|
||||
|
||||
let server_name = self.name();
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let command = this.command(&project, cx).await?;
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||
})?;
|
||||
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find gemini binary");
|
||||
};
|
||||
// todo! check supported version
|
||||
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
|
||||
|
||||
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
||||
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| foreground_executor.spawn(fut).detach(),
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
});
|
||||
|
||||
let child_status = cx.background_spawn(async move {
|
||||
let result = match child.status().await {
|
||||
Err(e) => Err(anyhow!(e)),
|
||||
Ok(result) if result.success() => Ok(()),
|
||||
Ok(result) => {
|
||||
if let Some(AgentServerVersion::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
}) = this.version(&command).await.log_err()
|
||||
{
|
||||
Err(anyhow!(LoadError::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command
|
||||
}))
|
||||
} else {
|
||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
||||
}
|
||||
}
|
||||
};
|
||||
drop(io_task);
|
||||
result
|
||||
});
|
||||
|
||||
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
|
||||
name,
|
||||
connection,
|
||||
child_status,
|
||||
current_thread: thread_rc,
|
||||
});
|
||||
|
||||
Ok(connection)
|
||||
let conn = AcpConnection::stdio(server_name, command, &root_dir, cx).await?;
|
||||
Ok(Rc::new(conn) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Gemini {
|
||||
async fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||
})?;
|
||||
|
||||
if let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
{
|
||||
return Ok(command);
|
||||
};
|
||||
|
||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
||||
(project.fs().clone(), project.node_runtime().cloned())
|
||||
})?;
|
||||
let node_runtime = node_runtime.context("gemini not found on path")?;
|
||||
|
||||
let directory = ::paths::agent_servers_dir().join("gemini");
|
||||
fs.create_dir(&directory).await?;
|
||||
node_runtime
|
||||
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
|
||||
.await?;
|
||||
let path = directory.join("node_modules/.bin/gemini");
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![ACP_ARG.into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
||||
let version_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--version")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let help_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--help")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
||||
|
||||
let current_version = String::from_utf8(version_output?.stdout)?;
|
||||
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
|
||||
|
||||
if supported {
|
||||
Ok(AgentServerVersion::Supported)
|
||||
} else {
|
||||
Ok(AgentServerVersion::Unsupported {
|
||||
error_message: format!(
|
||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
||||
current_version
|
||||
).into(),
|
||||
upgrade_message: "Upgrade Gemini to Latest".into(),
|
||||
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
@@ -199,7 +76,7 @@ pub(crate) mod tests {
|
||||
|
||||
AgentServerCommand {
|
||||
path: "node".into(),
|
||||
args: vec![cli_path, ACP_ARG.into()],
|
||||
args: vec![cli_path],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,207 +0,0 @@
|
||||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
use context_server::types::{
|
||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
||||
ToolsCapabilities, requests,
|
||||
};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use indoc::indoc;
|
||||
|
||||
pub struct ZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
|
||||
impl ZedMcpServer {
|
||||
pub async fn new(
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
|
||||
mcp_server.add_tool(RequestPermissionTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(ReadTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(WriteTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
|
||||
Ok(Self { server: mcp_server })
|
||||
}
|
||||
|
||||
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
|
||||
#[cfg(not(test))]
|
||||
let zed_path = anyhow::Context::context(
|
||||
std::env::current_exe(),
|
||||
"finding current executable path for use in mcp_server",
|
||||
)?;
|
||||
|
||||
#[cfg(test)]
|
||||
let zed_path = crate::e2e_tests::get_zed_path();
|
||||
|
||||
Ok(acp::McpServerConfig {
|
||||
command: zed_path,
|
||||
args: vec![
|
||||
"--nc".into(),
|
||||
self.server.socket_path().display().to_string(),
|
||||
],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(InitializeResponse {
|
||||
protocol_version: ProtocolVersion("2025-06-18".into()),
|
||||
capabilities: ServerCapabilities {
|
||||
experimental: None,
|
||||
logging: None,
|
||||
completions: None,
|
||||
prompts: None,
|
||||
resources: None,
|
||||
tools: Some(ToolsCapabilities {
|
||||
list_changed: Some(false),
|
||||
}),
|
||||
},
|
||||
server_info: Implementation {
|
||||
name: SERVER_NAME.into(),
|
||||
version: "0.1.0".into(),
|
||||
},
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Tools
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RequestPermissionTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for RequestPermissionTool {
|
||||
type Input = acp::RequestPermissionArguments;
|
||||
type Output = acp::RequestPermissionOutput;
|
||||
|
||||
const NAME: &'static str = "Confirmation";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
indoc! {"
|
||||
Request permission for tool calls.
|
||||
|
||||
This tool is meant to be called programmatically by the agent loop, not the LLM.
|
||||
"}
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let result = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(input.tool_call, input.options, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
|
||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
|
||||
};
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::RequestPermissionOutput { outcome },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReadTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for ReadTextFileTool {
|
||||
type Input = acp::ReadTextFileArguments;
|
||||
type Output = acp::ReadTextFileOutput;
|
||||
|
||||
const NAME: &'static str = "Read";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Reads the content of the given file in the project including unsaved changes."
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.path, input.line, input.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::ReadTextFileOutput { content },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WriteTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for WriteTextFileTool {
|
||||
type Input = acp::WriteTextFileArguments;
|
||||
type Output = ();
|
||||
|
||||
const NAME: &'static str = "Write";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Write to a file replacing its contents"
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(input.path, input.content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: (),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,6 @@ pub fn init(cx: &mut App) {
|
||||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<AgentServerSettings>,
|
||||
pub claude: Option<AgentServerSettings>,
|
||||
pub codex: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
@@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings {
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for AllAgentServersSettings {
|
||||
gemini,
|
||||
claude,
|
||||
codex,
|
||||
} in sources.defaults_and_customizations()
|
||||
{
|
||||
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
|
||||
if gemini.is_some() {
|
||||
settings.gemini = gemini.clone();
|
||||
}
|
||||
if claude.is_some() {
|
||||
settings.claude = claude.clone();
|
||||
}
|
||||
if codex.is_some() {
|
||||
settings.codex = codex.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
|
||||
@@ -19,6 +19,7 @@ test-support = ["gpui/test-support", "language/test-support"]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent.workspace = true
|
||||
agent2.workspace = true
|
||||
agent_servers.workspace = true
|
||||
agent_settings.workspace = true
|
||||
ai_onboarding.workspace = true
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use acp_thread::{AgentConnection, Plan};
|
||||
use agent_servers::AgentServer;
|
||||
use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
|
||||
use audio::{Audio, Sound};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
@@ -18,10 +20,10 @@ use editor::{
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement,
|
||||
Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
|
||||
Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*,
|
||||
pulsating_between,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
|
||||
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
|
||||
UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
|
||||
list, percentage, point, prelude::*, pulsating_between,
|
||||
};
|
||||
use language::language_settings::SoftWrap;
|
||||
use language::{Buffer, Language};
|
||||
@@ -45,7 +47,10 @@ use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSe
|
||||
use crate::acp::message_history::MessageHistory;
|
||||
use crate::agent_diff::AgentDiff;
|
||||
use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES};
|
||||
use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll};
|
||||
use crate::ui::{AgentNotification, AgentNotificationEvent};
|
||||
use crate::{
|
||||
AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll,
|
||||
};
|
||||
|
||||
const RESPONSE_PADDING_X: Pixels = px(19.);
|
||||
|
||||
@@ -59,6 +64,8 @@ pub struct AcpThreadView {
|
||||
message_set_from_history: bool,
|
||||
_message_editor_subscription: Subscription,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
notifications: Vec<WindowHandle<AgentNotification>>,
|
||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||
last_error: Option<Entity<Markdown>>,
|
||||
list_state: ListState,
|
||||
auth_task: Option<Task<()>>,
|
||||
@@ -174,6 +181,8 @@ impl AcpThreadView {
|
||||
message_set_from_history: false,
|
||||
_message_editor_subscription: message_editor_subscription,
|
||||
mention_set,
|
||||
notifications: Vec::new(),
|
||||
notification_subscriptions: HashMap::default(),
|
||||
diff_editors: Default::default(),
|
||||
list_state: list_state,
|
||||
last_error: None,
|
||||
@@ -223,7 +232,8 @@ impl AcpThreadView {
|
||||
{
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
||||
// todo! remove duplication
|
||||
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
@@ -381,7 +391,9 @@ impl AcpThreadView {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(thread) = self.thread() else { return };
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx));
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
@@ -564,6 +576,30 @@ impl AcpThreadView {
|
||||
self.sync_thread_entry_view(index, window, cx);
|
||||
self.list_state.splice(index..index + 1, 1);
|
||||
}
|
||||
AcpThreadEvent::ToolAuthorizationRequired => {
|
||||
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||
}
|
||||
AcpThreadEvent::Stopped => {
|
||||
let used_tools = thread.read(cx).used_tools_since_last_user_message();
|
||||
self.notify_with_sound(
|
||||
if used_tools {
|
||||
"Finished running tools"
|
||||
} else {
|
||||
"New message"
|
||||
},
|
||||
IconName::ZedAssistant,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
AcpThreadEvent::Error => {
|
||||
self.notify_with_sound(
|
||||
"Agent stopped due to an error",
|
||||
IconName::Warning,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
@@ -640,13 +676,18 @@ impl AcpThreadView {
|
||||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
fn authenticate(
|
||||
&mut self,
|
||||
method: acp::AuthMethodId,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.last_error.take();
|
||||
let authenticate = connection.authenticate(cx);
|
||||
let authenticate = connection.authenticate(method, cx);
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
let project = self.project.clone();
|
||||
let agent = self.agent.clone();
|
||||
@@ -2160,6 +2201,154 @@ impl AcpThreadView {
|
||||
self.list_state.scroll_to(ListOffset::default());
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn notify_with_sound(
|
||||
&mut self,
|
||||
caption: impl Into<SharedString>,
|
||||
icon: IconName,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.play_notification_sound(window, cx);
|
||||
self.show_notification(caption, icon, window, cx);
|
||||
}
|
||||
|
||||
fn play_notification_sound(&self, window: &Window, cx: &mut App) {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
if settings.play_sound_when_agent_done && !window.is_window_active() {
|
||||
Audio::play_sound(Sound::AgentDone, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn show_notification(
|
||||
&mut self,
|
||||
caption: impl Into<SharedString>,
|
||||
icon: IconName,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if window.is_window_active() || !self.notifications.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let title = self.title(cx);
|
||||
|
||||
match AgentSettings::get_global(cx).notify_when_agent_waiting {
|
||||
NotifyWhenAgentWaiting::PrimaryScreen => {
|
||||
if let Some(primary) = cx.primary_display() {
|
||||
self.pop_up(icon, caption.into(), title, window, primary, cx);
|
||||
}
|
||||
}
|
||||
NotifyWhenAgentWaiting::AllScreens => {
|
||||
let caption = caption.into();
|
||||
for screen in cx.displays() {
|
||||
self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx);
|
||||
}
|
||||
}
|
||||
NotifyWhenAgentWaiting::Never => {
|
||||
// Don't show anything
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn pop_up(
|
||||
&mut self,
|
||||
icon: IconName,
|
||||
caption: SharedString,
|
||||
title: SharedString,
|
||||
window: &mut Window,
|
||||
screen: Rc<dyn PlatformDisplay>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let options = AgentNotification::window_options(screen, cx);
|
||||
|
||||
let project_name = self.workspace.upgrade().and_then(|workspace| {
|
||||
workspace
|
||||
.read(cx)
|
||||
.project()
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.map(|worktree| worktree.read(cx).root_name().to_string())
|
||||
});
|
||||
|
||||
if let Some(screen_window) = cx
|
||||
.open_window(options, |_, cx| {
|
||||
cx.new(|_| {
|
||||
AgentNotification::new(title.clone(), caption.clone(), icon, project_name)
|
||||
})
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
if let Some(pop_up) = screen_window.entity(cx).log_err() {
|
||||
self.notification_subscriptions
|
||||
.entry(screen_window)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(cx.subscribe_in(&pop_up, window, {
|
||||
|this, _, event, window, cx| match event {
|
||||
AgentNotificationEvent::Accepted => {
|
||||
let handle = window.window_handle();
|
||||
cx.activate(true);
|
||||
|
||||
let workspace_handle = this.workspace.clone();
|
||||
|
||||
// If there are multiple Zed windows, activate the correct one.
|
||||
cx.defer(move |cx| {
|
||||
handle
|
||||
.update(cx, |_view, window, _cx| {
|
||||
window.activate_window();
|
||||
|
||||
if let Some(workspace) = workspace_handle.upgrade() {
|
||||
workspace.update(_cx, |workspace, cx| {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
});
|
||||
|
||||
this.dismiss_notifications(cx);
|
||||
}
|
||||
AgentNotificationEvent::Dismissed => {
|
||||
this.dismiss_notifications(cx);
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
self.notifications.push(screen_window);
|
||||
|
||||
// If the user manually refocuses the original window, dismiss the popup.
|
||||
self.notification_subscriptions
|
||||
.entry(screen_window)
|
||||
.or_insert_with(Vec::new)
|
||||
.push({
|
||||
let pop_up_weak = pop_up.downgrade();
|
||||
|
||||
cx.observe_window_activation(window, move |_, window, cx| {
|
||||
if window.is_window_active() {
|
||||
if let Some(pop_up) = pop_up_weak.upgrade() {
|
||||
pop_up.update(cx, |_, cx| {
|
||||
cx.emit(AgentNotificationEvent::Dismissed);
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dismiss_notifications(&mut self, cx: &mut Context<Self>) {
|
||||
for window in self.notifications.drain(..) {
|
||||
window
|
||||
.update(cx, |_, window, _| {
|
||||
window.remove_window();
|
||||
})
|
||||
.ok();
|
||||
|
||||
self.notification_subscriptions.remove(&window);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for AcpThreadView {
|
||||
@@ -2197,22 +2386,26 @@ impl Render for AcpThreadView {
|
||||
.on_action(cx.listener(Self::next_history_message))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Unauthenticated { .. } => {
|
||||
v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(
|
||||
h_flex().mt_1p5().justify_center().child(
|
||||
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.authenticate(window, cx)
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
ThreadState::Unauthenticated { connection } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.auth_methods().into_iter().map(|method| {
|
||||
Button::new(
|
||||
SharedString::from(method.id.0.clone()),
|
||||
method.label.clone(),
|
||||
)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
}),
|
||||
)),
|
||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||
ThreadState::LoadError(e) => v_flex()
|
||||
.p_2()
|
||||
@@ -2441,3 +2634,341 @@ fn plan_label_markdown_style(
|
||||
..default_md_style
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use agent_client_protocol::SessionId;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use rand::Rng;
|
||||
use settings::SettingsStore;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await;
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello", window, cx);
|
||||
});
|
||||
|
||||
cx.deactivate_window();
|
||||
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.chat(&Chat, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert!(
|
||||
cx.windows()
|
||||
.iter()
|
||||
.any(|window| window.downcast::<AgentNotification>().is_some())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_error(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let (thread_view, cx) =
|
||||
setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await;
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello", window, cx);
|
||||
});
|
||||
|
||||
cx.deactivate_window();
|
||||
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.chat(&Chat, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert!(
|
||||
cx.windows()
|
||||
.iter()
|
||||
.any(|window| window.downcast::<AgentNotification>().is_some())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let tool_call_id = acp::ToolCallId("1".into());
|
||||
let tool_call = acp::ToolCall {
|
||||
id: tool_call_id.clone(),
|
||||
label: "Label".into(),
|
||||
kind: acp::ToolKind::Edit,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec!["hi".into()],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
};
|
||||
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
|
||||
.with_permission_requests(HashMap::from_iter([(
|
||||
tool_call_id,
|
||||
vec![acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("1".into()),
|
||||
label: "Allow".into(),
|
||||
kind: acp::PermissionOptionKind::AllowOnce,
|
||||
}],
|
||||
)]));
|
||||
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello", window, cx);
|
||||
});
|
||||
|
||||
cx.deactivate_window();
|
||||
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.chat(&Chat, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert!(
|
||||
cx.windows()
|
||||
.iter()
|
||||
.any(|window| window.downcast::<AgentNotification>().is_some())
|
||||
);
|
||||
}
|
||||
|
||||
async fn setup_thread_view(
|
||||
agent: impl AgentServer + 'static,
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<AcpThreadView>, &mut VisualTestContext) {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let thread_view = cx.update(|window, cx| {
|
||||
cx.new(|cx| {
|
||||
AcpThreadView::new(
|
||||
Rc::new(agent),
|
||||
workspace.downgrade(),
|
||||
project,
|
||||
Rc::new(RefCell::new(MessageHistory::default())),
|
||||
1,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
cx.run_until_parked();
|
||||
(thread_view, cx)
|
||||
}
|
||||
|
||||
struct StubAgentServer<C> {
|
||||
connection: C,
|
||||
}
|
||||
|
||||
impl<C> StubAgentServer<C> {
|
||||
fn new(connection: C) -> Self {
|
||||
Self { connection }
|
||||
}
|
||||
}
|
||||
|
||||
impl StubAgentServer<StubAgentConnection> {
|
||||
fn default() -> Self {
|
||||
Self::new(StubAgentConnection::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AgentServer for StubAgentServer<C>
|
||||
where
|
||||
C: 'static + AgentConnection + Send + Clone,
|
||||
{
|
||||
fn logo(&self) -> ui::IconName {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
|
||||
Task::ready(Ok(Rc::new(self.connection.clone())))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct StubAgentConnection {
|
||||
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
updates: Vec<acp::SessionUpdate>,
|
||||
}
|
||||
|
||||
impl StubAgentConnection {
|
||||
fn new(updates: Vec<acp::SessionUpdate>) -> Self {
|
||||
Self {
|
||||
updates,
|
||||
permission_requests: HashMap::default(),
|
||||
sessions: Arc::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_permission_requests(
|
||||
mut self,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
) -> Self {
|
||||
self.permission_requests = permission_requests;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for StubAgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = SessionId(
|
||||
rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(7)
|
||||
.map(char::from)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
);
|
||||
let thread = cx
|
||||
.new(|cx| {
|
||||
AcpThread::new("New Thread", self.clone(), project, session_id.clone(), cx)
|
||||
})
|
||||
.unwrap();
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
let mut tasks = vec![];
|
||||
for update in &self.updates {
|
||||
let thread = thread.clone();
|
||||
let update = update.clone();
|
||||
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
|
||||
&& let Some(options) = self.permission_requests.get(&tool_call.id)
|
||||
{
|
||||
Some((tool_call.clone(), options.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let task = cx.spawn(async move |cx| {
|
||||
if let Some((tool_call, options)) = permission_request {
|
||||
let permission = thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(
|
||||
tool_call.clone(),
|
||||
options.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
permission.await?;
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
});
|
||||
tasks.push(task);
|
||||
}
|
||||
cx.spawn(async move |_| {
|
||||
try_join_all(tasks).await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SaboteurAgentConnection;
|
||||
|
||||
impl AgentConnection for SaboteurAgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
Task::ready(Ok(cx
|
||||
.new(|cx| AcpThread::new("New Thread", self, project, SessionId("test".into()), cx))
|
||||
.unwrap()))
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
Task::ready(Err(anyhow::anyhow!("Error prompting")))
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AgentSettings::register(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
EditorSettings::register(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use agent_settings::AgentSettings;
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use cloud_llm_client::Plan;
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerId;
|
||||
use extension::ExtensionManifest;
|
||||
@@ -25,7 +26,6 @@ use project::{
|
||||
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
|
||||
project_settings::{ContextServerSettings, ProjectSettings},
|
||||
};
|
||||
use proto::Plan;
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
|
||||
@@ -180,7 +180,7 @@ impl AgentConfiguration {
|
||||
let current_plan = if is_zed_provider {
|
||||
self.workspace
|
||||
.upgrade()
|
||||
.and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan())
|
||||
.and_then(|workspace| workspace.read(cx).user_store().read(cx).plan())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -406,7 +406,9 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"always-allow-tool-actions-switch",
|
||||
"Allow running commands without asking for confirmation",
|
||||
"The agent can perform potentially destructive actions without asking for your confirmation.",
|
||||
Some(
|
||||
"The agent can perform potentially destructive actions without asking for your confirmation.".into(),
|
||||
),
|
||||
always_allow_tool_actions,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -424,7 +426,7 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"single-file-review",
|
||||
"Enable single-file agent reviews",
|
||||
"Agent edits are also displayed in single-file editors for review.",
|
||||
Some("Agent edits are also displayed in single-file editors for review.".into()),
|
||||
single_file_review,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -442,7 +444,9 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"sound-notification",
|
||||
"Play sound when finished generating",
|
||||
"Hear a notification sound when the agent is done generating changes or needs your input.",
|
||||
Some(
|
||||
"Hear a notification sound when the agent is done generating changes or needs your input.".into(),
|
||||
),
|
||||
play_sound_when_agent_done,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -460,7 +464,9 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"modifier-send",
|
||||
"Use modifier to submit a message",
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.",
|
||||
Some(
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(),
|
||||
),
|
||||
use_modifier_to_send,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -502,7 +508,7 @@ impl AgentConfiguration {
|
||||
.blend(cx.theme().colors().text_accent.opacity(0.2));
|
||||
|
||||
let (plan_name, label_color, bg_color) = match plan {
|
||||
Plan::Free => ("Free", Color::Default, free_chip_bg),
|
||||
Plan::ZedFree => ("Free", Color::Default, free_chip_bg),
|
||||
Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg),
|
||||
Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg),
|
||||
};
|
||||
|
||||
@@ -1521,6 +1521,9 @@ impl AgentDiff {
|
||||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
}
|
||||
AcpThreadEvent::Stopped
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Error => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use client::{DisableAiSettings, UserStore, zed_urls};
|
||||
use cloud_llm_client::{CompletionIntent, UsageLimit};
|
||||
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
@@ -60,7 +60,6 @@ use language_model::{
|
||||
};
|
||||
use project::{Project, ProjectPath, Worktree};
|
||||
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
|
||||
use proto::Plan;
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
use search::{BufferSearchBar, buffer_search};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -579,7 +578,6 @@ impl AgentPanel {
|
||||
MessageEditor::new(
|
||||
fs.clone(),
|
||||
workspace.clone(),
|
||||
user_store.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
prompt_store.clone(),
|
||||
thread_store.downgrade(),
|
||||
@@ -848,7 +846,6 @@ impl AgentPanel {
|
||||
MessageEditor::new(
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
self.user_store.clone(),
|
||||
context_store.clone(),
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
@@ -1122,7 +1119,6 @@ impl AgentPanel {
|
||||
MessageEditor::new(
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
self.user_store.clone(),
|
||||
context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
@@ -1958,54 +1954,54 @@ impl AgentPanel {
|
||||
this
|
||||
}
|
||||
})
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.separator()
|
||||
.header("External Agents")
|
||||
.item(
|
||||
ContextMenuEntry::new("New Gemini Thread")
|
||||
.icon(IconName::AiGemini)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Claude Code Thread")
|
||||
.icon(IconName::AiClaude)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::ClaudeCode,
|
||||
),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Codex Thread")
|
||||
.icon(IconName::AiOpenAi)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Codex),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
});
|
||||
// Temporarily removed feature flag check for testing
|
||||
// .when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
// this
|
||||
.separator()
|
||||
.header("External Agents")
|
||||
.item(
|
||||
ContextMenuEntry::new("New Gemini Thread")
|
||||
.icon(IconName::AiGemini)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Claude Code Thread")
|
||||
.icon(IconName::AiClaude)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::ClaudeCode),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Native Agent Thread")
|
||||
.icon(IconName::ZedAssistant)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::NativeAgent),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
);
|
||||
// });
|
||||
menu
|
||||
}))
|
||||
}
|
||||
@@ -2293,10 +2289,10 @@ impl AgentPanel {
|
||||
| ActiveView::Configuration => return false,
|
||||
}
|
||||
|
||||
let plan = self.user_store.read(cx).current_plan();
|
||||
let plan = self.user_store.read(cx).plan();
|
||||
let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some();
|
||||
|
||||
matches!(plan, Some(Plan::Free)) && has_previous_trial
|
||||
matches!(plan, Some(Plan::ZedFree)) && has_previous_trial
|
||||
}
|
||||
|
||||
fn should_render_onboarding(&self, cx: &mut Context<Self>) -> bool {
|
||||
@@ -2612,82 +2608,87 @@ impl AgentPanel {
|
||||
),
|
||||
),
|
||||
)
|
||||
.when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
this.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-gemini-thread-btn",
|
||||
"New Gemini Thread",
|
||||
IconName::AiGemini,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::Gemini,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
// Temporarily removed feature flag check for testing
|
||||
// .when(cx.has_flag::<feature_flags::AcpFeatureFlag>(), |this| {
|
||||
// this
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-gemini-thread-btn",
|
||||
"New Gemini Thread",
|
||||
IconName::AiGemini,
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-claude-thread-btn",
|
||||
"New Claude Code Thread",
|
||||
IconName::AiClaude,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::ClaudeCode,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-codex-thread-btn",
|
||||
"New Codex Thread",
|
||||
IconName::AiOpenAi,
|
||||
)
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::Codex,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Gemini),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-claude-thread-btn",
|
||||
"New Claude Code Thread",
|
||||
IconName::AiClaude,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::ClaudeCode,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-native-agent-thread-btn",
|
||||
"New Native Agent Thread",
|
||||
IconName::ZedAssistant,
|
||||
)
|
||||
// .keybinding(KeyBinding::for_action_in(
|
||||
// &OpenHistory,
|
||||
// &self.focus_handle(cx),
|
||||
// window,
|
||||
// cx,
|
||||
// ))
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::NativeAgent,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
), // })
|
||||
)
|
||||
.when_some(configuration_error.as_ref(), |this, err| {
|
||||
this.child(self.render_configuration_error(err, &focus_handle, window, cx))
|
||||
@@ -2911,7 +2912,7 @@ impl AgentPanel {
|
||||
) -> AnyElement {
|
||||
let error_message = match plan {
|
||||
Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
|
||||
Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.",
|
||||
Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.",
|
||||
};
|
||||
|
||||
let icon = Icon::new(IconName::XCircle)
|
||||
|
||||
@@ -150,7 +150,7 @@ enum ExternalAgent {
|
||||
#[default]
|
||||
Gemini,
|
||||
ClaudeCode,
|
||||
Codex,
|
||||
NativeAgent,
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
@@ -158,7 +158,7 @@ impl ExternalAgent {
|
||||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
|
||||
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ use agent::{
|
||||
use agent_settings::{AgentSettings, CompletionMode};
|
||||
use ai_onboarding::ApiKeysWithProviders;
|
||||
use buffer_diff::BufferDiff;
|
||||
use client::UserStore;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
@@ -43,7 +42,6 @@ use language_model::{
|
||||
use multi_buffer;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use proto::Plan;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use theme::ThemeSettings;
|
||||
@@ -79,7 +77,6 @@ pub struct MessageEditor {
|
||||
editor: Entity<Editor>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
user_store: Entity<UserStore>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
history_store: Option<WeakEntity<HistoryStore>>,
|
||||
@@ -159,7 +156,6 @@ impl MessageEditor {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
user_store: Entity<UserStore>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
@@ -231,7 +227,6 @@ impl MessageEditor {
|
||||
Self {
|
||||
editor: editor.clone(),
|
||||
project: thread.read(cx).project().clone(),
|
||||
user_store,
|
||||
thread,
|
||||
incompatible_tools_state: incompatible_tools.clone(),
|
||||
workspace,
|
||||
@@ -1287,24 +1282,12 @@ impl MessageEditor {
|
||||
return None;
|
||||
}
|
||||
|
||||
let user_store = self.user_store.read(cx);
|
||||
|
||||
let ubb_enable = user_store
|
||||
.usage_based_billing_enabled()
|
||||
.map_or(false, |enabled| enabled);
|
||||
|
||||
if ubb_enable {
|
||||
let user_store = self.project.read(cx).user_store().read(cx);
|
||||
if user_store.is_usage_based_billing_enabled() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let plan = user_store
|
||||
.current_plan()
|
||||
.map(|plan| match plan {
|
||||
Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
})
|
||||
.unwrap_or(cloud_llm_client::Plan::ZedFree);
|
||||
let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree);
|
||||
|
||||
let usage = user_store.model_request_usage()?;
|
||||
|
||||
@@ -1769,7 +1752,6 @@ impl AgentPreview for MessageEditor {
|
||||
) -> Option<AnyElement> {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
let fs = workspace.read(cx).app_state().fs.clone();
|
||||
let user_store = workspace.read(cx).app_state().user_store.clone();
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let weak_project = project.downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
|
||||
@@ -1782,7 +1764,6 @@ impl AgentPreview for MessageEditor {
|
||||
MessageEditor::new(
|
||||
fs,
|
||||
workspace.downgrade(),
|
||||
user_store,
|
||||
context_store,
|
||||
None,
|
||||
thread_store.downgrade(),
|
||||
|
||||
@@ -16,10 +16,10 @@ default = []
|
||||
|
||||
[dependencies]
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
component.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
proto.workspace = true
|
||||
serde.workspace = true
|
||||
smallvec.workspace = true
|
||||
telemetry.workspace = true
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::{Client, UserStore};
|
||||
use cloud_llm_client::Plan;
|
||||
use gpui::{Entity, IntoElement, ParentElement};
|
||||
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use ui::prelude::*;
|
||||
@@ -56,15 +57,8 @@ impl AgentPanelOnboarding {
|
||||
|
||||
impl Render for AgentPanelOnboarding {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let enrolled_in_trial = matches!(
|
||||
self.user_store.read(cx).current_plan(),
|
||||
Some(proto::Plan::ZedProTrial)
|
||||
);
|
||||
|
||||
let is_pro_user = matches!(
|
||||
self.user_store.read(cx).current_plan(),
|
||||
Some(proto::Plan::ZedPro)
|
||||
);
|
||||
let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial);
|
||||
let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro);
|
||||
|
||||
AgentPanelOnboardingCard::new()
|
||||
.child(
|
||||
|
||||
@@ -9,6 +9,7 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider
|
||||
pub use agent_panel_onboarding_card::AgentPanelOnboardingCard;
|
||||
pub use agent_panel_onboarding_content::AgentPanelOnboarding;
|
||||
pub use ai_upsell_card::AiUpsellCard;
|
||||
use cloud_llm_client::Plan;
|
||||
pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
|
||||
pub use young_account_banner::YoungAccountBanner;
|
||||
|
||||
@@ -79,7 +80,7 @@ impl From<client::Status> for SignInStatus {
|
||||
pub struct ZedAiOnboarding {
|
||||
pub sign_in_status: SignInStatus,
|
||||
pub has_accepted_terms_of_service: bool,
|
||||
pub plan: Option<proto::Plan>,
|
||||
pub plan: Option<Plan>,
|
||||
pub account_too_young: bool,
|
||||
pub continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
@@ -99,8 +100,8 @@ impl ZedAiOnboarding {
|
||||
|
||||
Self {
|
||||
sign_in_status: status.into(),
|
||||
has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false),
|
||||
plan: store.current_plan(),
|
||||
has_accepted_terms_of_service: store.has_accepted_terms_of_service(),
|
||||
plan: store.plan(),
|
||||
account_too_young: store.account_too_young(),
|
||||
continue_with_zed_ai,
|
||||
accept_terms_of_service: Arc::new({
|
||||
@@ -113,11 +114,9 @@ impl ZedAiOnboarding {
|
||||
sign_in: Arc::new(move |_window, cx| {
|
||||
cx.spawn({
|
||||
let client = client.clone();
|
||||
async move |cx| {
|
||||
client.authenticate_and_connect(true, cx).await;
|
||||
}
|
||||
async move |cx| client.sign_in_with_optional_connect(true, cx).await
|
||||
})
|
||||
.detach();
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
dismiss_onboarding: None,
|
||||
}
|
||||
@@ -411,9 +410,9 @@ impl RenderOnce for ZedAiOnboarding {
|
||||
if matches!(self.sign_in_status, SignInStatus::SignedIn) {
|
||||
if self.has_accepted_terms_of_service {
|
||||
match self.plan {
|
||||
None | Some(proto::Plan::Free) => self.render_free_plan_state(cx),
|
||||
Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx),
|
||||
Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx),
|
||||
None | Some(Plan::ZedFree) => self.render_free_plan_state(cx),
|
||||
Some(Plan::ZedProTrial) => self.render_trial_state(cx),
|
||||
Some(Plan::ZedPro) => self.render_pro_plan_state(cx),
|
||||
}
|
||||
} else {
|
||||
self.render_accept_terms_of_service()
|
||||
@@ -433,7 +432,7 @@ impl Component for ZedAiOnboarding {
|
||||
fn onboarding(
|
||||
sign_in_status: SignInStatus,
|
||||
has_accepted_terms_of_service: bool,
|
||||
plan: Option<proto::Plan>,
|
||||
plan: Option<Plan>,
|
||||
account_too_young: bool,
|
||||
) -> AnyElement {
|
||||
ZedAiOnboarding {
|
||||
@@ -468,25 +467,15 @@ impl Component for ZedAiOnboarding {
|
||||
),
|
||||
single_example(
|
||||
"Free Plan",
|
||||
onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false),
|
||||
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false),
|
||||
),
|
||||
single_example(
|
||||
"Pro Trial",
|
||||
onboarding(
|
||||
SignInStatus::SignedIn,
|
||||
true,
|
||||
Some(proto::Plan::ZedProTrial),
|
||||
false,
|
||||
),
|
||||
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false),
|
||||
),
|
||||
single_example(
|
||||
"Pro Plan",
|
||||
onboarding(
|
||||
SignInStatus::SignedIn,
|
||||
true,
|
||||
Some(proto::Plan::ZedPro),
|
||||
false,
|
||||
),
|
||||
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false),
|
||||
),
|
||||
])
|
||||
.into_any_element(),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::{Client, zed_urls};
|
||||
use cloud_llm_client::Plan;
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Divider, List, Vector, VectorName, prelude::*};
|
||||
|
||||
@@ -10,22 +11,22 @@ use crate::{BulletItem, SignInStatus};
|
||||
pub struct AiUpsellCard {
|
||||
pub sign_in_status: SignInStatus,
|
||||
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
pub user_plan: Option<Plan>,
|
||||
}
|
||||
|
||||
impl AiUpsellCard {
|
||||
pub fn new(client: Arc<Client>) -> Self {
|
||||
pub fn new(client: Arc<Client>, user_plan: Option<Plan>) -> Self {
|
||||
let status = *client.status().borrow();
|
||||
|
||||
Self {
|
||||
user_plan,
|
||||
sign_in_status: status.into(),
|
||||
sign_in: Arc::new(move |_window, cx| {
|
||||
cx.spawn({
|
||||
let client = client.clone();
|
||||
async move |cx| {
|
||||
client.authenticate_and_connect(true, cx).await;
|
||||
}
|
||||
async move |cx| client.sign_in_with_optional_connect(true, cx).await
|
||||
})
|
||||
.detach();
|
||||
.detach_and_log_err(cx);
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -34,6 +35,7 @@ impl AiUpsellCard {
|
||||
impl RenderOnce for AiUpsellCard {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let pro_section = v_flex()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
@@ -56,6 +58,7 @@ impl RenderOnce for AiUpsellCard {
|
||||
);
|
||||
|
||||
let free_section = v_flex()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
@@ -71,7 +74,7 @@ impl RenderOnce for AiUpsellCard {
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts with the Claude models"))
|
||||
.child(BulletItem::new("50 prompts with Claude models"))
|
||||
.child(BulletItem::new("2,000 accepted edit predictions")),
|
||||
);
|
||||
|
||||
@@ -132,22 +135,28 @@ impl RenderOnce for AiUpsellCard {
|
||||
|
||||
v_flex()
|
||||
.relative()
|
||||
.p_6()
|
||||
.pt_4()
|
||||
.p_4()
|
||||
.pt_3()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.rounded_lg()
|
||||
.overflow_hidden()
|
||||
.child(grid_bg)
|
||||
.child(gradient_bg)
|
||||
.child(Headline::new("Try Zed AI"))
|
||||
.child(Label::new(DESCRIPTION).color(Color::Muted).mb_2())
|
||||
.child(Label::new("Try Zed AI").size(LabelSize::Large))
|
||||
.child(
|
||||
div()
|
||||
.max_w_3_4()
|
||||
.mb_2()
|
||||
.child(Label::new(DESCRIPTION).color(Color::Muted)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.mt_1p5()
|
||||
.mb_2p5()
|
||||
.items_start()
|
||||
.gap_12()
|
||||
.gap_6()
|
||||
.child(free_section)
|
||||
.child(pro_section),
|
||||
)
|
||||
@@ -183,6 +192,7 @@ impl Component for AiUpsellCard {
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedOut,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: None,
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
@@ -191,6 +201,7 @@ impl Component for AiUpsellCard {
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: None,
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
|
||||
@@ -126,7 +126,7 @@ impl ChannelMembership {
|
||||
proto::channel_member::Kind::Member => 0,
|
||||
proto::channel_member::Kind::Invitee => 1,
|
||||
},
|
||||
username_order: self.user.github_login.as_str(),
|
||||
username_order: &self.user.github_login,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,20 +259,6 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx);
|
||||
});
|
||||
|
||||
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
|
||||
assert_eq!(get_users.payload.user_ids, vec![5]);
|
||||
server.respond(
|
||||
get_users.receipt(),
|
||||
proto::UsersResponse {
|
||||
users: vec![proto::User {
|
||||
id: 5,
|
||||
github_login: "nathansobo".into(),
|
||||
avatar_url: "http://avatar.com/nathansobo".into(),
|
||||
name: None,
|
||||
}],
|
||||
},
|
||||
);
|
||||
|
||||
// Join a channel and populate its existing messages.
|
||||
let channel = channel_store.update(cx, |store, cx| {
|
||||
let channel_id = store.ordered_channels().next().unwrap().1.id;
|
||||
@@ -334,7 +320,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
&[
|
||||
("nathansobo".into(), "a".into()),
|
||||
("user-5".into(), "a".into()),
|
||||
("maxbrunsfeld".into(), "b".into())
|
||||
]
|
||||
);
|
||||
@@ -437,7 +423,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
&[
|
||||
("nathansobo".into(), "y".into()),
|
||||
("user-5".into(), "y".into()),
|
||||
("maxbrunsfeld".into(), "z".into())
|
||||
]
|
||||
);
|
||||
|
||||
@@ -17,7 +17,6 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-recursion = "0.3"
|
||||
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
|
||||
base64.workspace = true
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
mod cloud;
|
||||
mod proxy;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
pub mod zed_urls;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use async_recursion::async_recursion;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use async_tungstenite::tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
error::Error as WebsocketError,
|
||||
http::{HeaderValue, Request, StatusCode},
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use clock::SystemClock;
|
||||
use cloud_api_client::CloudApiClient;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
@@ -23,7 +20,7 @@ use futures::{
|
||||
channel::oneshot, future::BoxFuture,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl, http};
|
||||
use http_client::{HttpClient, HttpClientWithUrl, http};
|
||||
use parking_lot::RwLock;
|
||||
use postage::watch;
|
||||
use proxy::connect_proxy_stream;
|
||||
@@ -53,7 +50,6 @@ use tokio::net::TcpStream;
|
||||
use url::Url;
|
||||
use util::{ConnectionResult, ResultExt};
|
||||
|
||||
pub use cloud::*;
|
||||
pub use rpc::*;
|
||||
pub use telemetry_events::Event;
|
||||
pub use user::*;
|
||||
@@ -165,20 +161,8 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
|
||||
let client = client.clone();
|
||||
move |_: &SignIn, cx| {
|
||||
if let Some(client) = client.upgrade() {
|
||||
cx.spawn(
|
||||
async move |cx| match client.authenticate_and_connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("Initial authentication timed out");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::error!("Initial authentication connection reset");
|
||||
}
|
||||
ConnectionResult::Result(r) => {
|
||||
r.log_err();
|
||||
}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -287,6 +271,8 @@ pub enum Status {
|
||||
SignedOut,
|
||||
UpgradeRequired,
|
||||
Authenticating,
|
||||
Authenticated,
|
||||
AuthenticationError,
|
||||
Connecting,
|
||||
ConnectionError,
|
||||
Connected {
|
||||
@@ -713,7 +699,7 @@ impl Client {
|
||||
|
||||
let mut delay = INITIAL_RECONNECTION_DELAY;
|
||||
loop {
|
||||
match client.authenticate_and_connect(true, &cx).await {
|
||||
match client.connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("client connect attempt timed out")
|
||||
}
|
||||
@@ -883,17 +869,122 @@ impl Client {
|
||||
.is_some()
|
||||
}
|
||||
|
||||
#[async_recursion(?Send)]
|
||||
pub async fn authenticate_and_connect(
|
||||
pub async fn sign_in(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Credentials> {
|
||||
if self.status().borrow().is_signed_out() {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx);
|
||||
}
|
||||
|
||||
let mut credentials = None;
|
||||
|
||||
let old_credentials = self.state.read().credentials.clone();
|
||||
if let Some(old_credentials) = old_credentials {
|
||||
self.cloud_client.set_credentials(
|
||||
old_credentials.user_id as u32,
|
||||
old_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(old_credentials);
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() && try_provider {
|
||||
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
|
||||
self.cloud_client.set_credentials(
|
||||
stored_credentials.user_id as u32,
|
||||
stored_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the stored credentials, and
|
||||
// clear them from the credentials provider if that fails.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(stored_credentials);
|
||||
} else {
|
||||
self.credentials_provider
|
||||
.delete_credentials(cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => {
|
||||
if IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider
|
||||
.write_credentials(creds.user_id, creds.access_token.clone(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
credentials = Some(creds);
|
||||
},
|
||||
Err(err) => {
|
||||
self.set_status(Status::AuthenticationError, cx);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return Err(anyhow!("authentication canceled"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
self.set_status(Status::Authenticated, cx);
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
/// Performs a sign-in and also connects to Collab.
|
||||
///
|
||||
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
|
||||
/// to `sign_in` when we're ready to remove auto-connection to Collab.
|
||||
pub async fn sign_in_with_optional_connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<()> {
|
||||
let credentials = self.sign_in(try_provider, cx).await?;
|
||||
|
||||
let connect_result = match self.connect_with_credentials(credentials, cx).await {
|
||||
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
|
||||
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
|
||||
ConnectionResult::Result(result) => result.context("client auth and connect"),
|
||||
};
|
||||
connect_result.log_err();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let was_disconnected = match *self.status().borrow() {
|
||||
Status::SignedOut => true,
|
||||
Status::SignedOut | Status::Authenticated => true,
|
||||
Status::ConnectionError
|
||||
| Status::ConnectionLost
|
||||
| Status::Authenticating { .. }
|
||||
| Status::AuthenticationError
|
||||
| Status::Reauthenticating { .. }
|
||||
| Status::ReconnectionError { .. } => false,
|
||||
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
|
||||
@@ -906,41 +997,10 @@ impl Client {
|
||||
);
|
||||
}
|
||||
};
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx)
|
||||
}
|
||||
|
||||
let mut read_from_provider = false;
|
||||
let mut credentials = self.state.read().credentials.clone();
|
||||
if credentials.is_none() && try_provider {
|
||||
credentials = self.credentials_provider.read_credentials(cx).await;
|
||||
read_from_provider = credentials.is_some();
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => credentials = Some(creds),
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
return ConnectionResult::Result(Err(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
|
||||
}
|
||||
}
|
||||
}
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
let credentials = match self.sign_in(try_provider, cx).await {
|
||||
Ok(credentials) => credentials,
|
||||
Err(err) => return ConnectionResult::Result(Err(err)),
|
||||
};
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Connecting, cx);
|
||||
@@ -948,17 +1008,20 @@ impl Client {
|
||||
self.set_status(Status::Reconnecting, cx);
|
||||
}
|
||||
|
||||
self.connect_with_credentials(credentials, cx).await
|
||||
}
|
||||
|
||||
async fn connect_with_credentials(
|
||||
self: &Arc<Self>,
|
||||
credentials: Credentials,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let mut timeout =
|
||||
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
|
||||
futures::select_biased! {
|
||||
connection = self.establish_connection(&credentials, cx).fuse() => {
|
||||
match connection {
|
||||
Ok(conn) => {
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
|
||||
}
|
||||
|
||||
futures::select_biased! {
|
||||
result = self.set_connection(conn, cx).fuse() => {
|
||||
match result.context("client auth and connect") {
|
||||
@@ -976,15 +1039,8 @@ impl Client {
|
||||
}
|
||||
}
|
||||
Err(EstablishConnectionError::Unauthorized) => {
|
||||
self.state.write().credentials.take();
|
||||
if read_from_provider {
|
||||
self.credentials_provider.delete_credentials(cx).await.log_err();
|
||||
self.set_status(Status::SignedOut, cx);
|
||||
self.authenticate_and_connect(false, cx).await
|
||||
} else {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
Err(EstablishConnectionError::UpgradeRequired) => {
|
||||
self.set_status(Status::UpgradeRequired, cx);
|
||||
@@ -1379,96 +1435,31 @@ impl Client {
|
||||
self: &Arc<Self>,
|
||||
http: Arc<HttpClientWithUrl>,
|
||||
login: String,
|
||||
mut api_token: String,
|
||||
api_token: String,
|
||||
) -> Result<Credentials> {
|
||||
#[derive(Deserialize)]
|
||||
struct AuthenticatedUserResponse {
|
||||
user: User,
|
||||
#[derive(Serialize)]
|
||||
struct ImpersonateUserBody {
|
||||
github_login: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct User {
|
||||
id: u64,
|
||||
struct ImpersonateUserResponse {
|
||||
user_id: u64,
|
||||
access_token: String,
|
||||
}
|
||||
|
||||
let github_user = {
|
||||
#[derive(Deserialize)]
|
||||
struct GithubUser {
|
||||
id: i32,
|
||||
login: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
let request = {
|
||||
let mut request_builder =
|
||||
Request::get(&format!("https://api.github.com/users/{login}"));
|
||||
if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", github_token));
|
||||
}
|
||||
|
||||
request_builder.body(AsyncBody::empty())?
|
||||
};
|
||||
|
||||
let mut response = http
|
||||
.send(request)
|
||||
.await
|
||||
.context("error fetching GitHub user")?;
|
||||
|
||||
let mut body = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("error reading GitHub user")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let text = String::from_utf8_lossy(body.as_slice());
|
||||
bail!(
|
||||
"status error {}, response: {text:?}",
|
||||
response.status().as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
|
||||
log::error!("Error deserializing: {:?}", err);
|
||||
log::error!(
|
||||
"GitHub API response text: {:?}",
|
||||
String::from_utf8_lossy(body.as_slice())
|
||||
);
|
||||
anyhow!("error deserializing GitHub user")
|
||||
})?
|
||||
};
|
||||
|
||||
let query_params = [
|
||||
("github_login", &github_user.login),
|
||||
("github_user_id", &github_user.id.to_string()),
|
||||
(
|
||||
"github_user_created_at",
|
||||
&github_user.created_at.to_rfc3339(),
|
||||
),
|
||||
];
|
||||
|
||||
// Use the collab server's admin API to retrieve the ID
|
||||
// of the impersonated user.
|
||||
let mut url = self.rpc_url(http.clone(), None).await?;
|
||||
url.set_path("/user");
|
||||
url.set_query(Some(
|
||||
&query_params
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
format!(
|
||||
"{}={}",
|
||||
key,
|
||||
url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("&"),
|
||||
));
|
||||
let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
|
||||
.header("Authorization", format!("token {api_token}"))
|
||||
.body("".into())?;
|
||||
let url = self
|
||||
.http
|
||||
.build_zed_cloud_url("/internal/users/impersonate", &[])?;
|
||||
let request = Request::post(url.as_str())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {api_token}"))
|
||||
.body(
|
||||
serde_json::to_string(&ImpersonateUserBody {
|
||||
github_login: login,
|
||||
})?
|
||||
.into(),
|
||||
)?;
|
||||
|
||||
let mut response = http.send(request).await?;
|
||||
let mut body = String::new();
|
||||
@@ -1479,18 +1470,17 @@ impl Client {
|
||||
response.status().as_u16(),
|
||||
body,
|
||||
);
|
||||
let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
|
||||
let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
|
||||
|
||||
// Use the admin API token to authenticate as the impersonated user.
|
||||
api_token.insert_str(0, "ADMIN_TOKEN:");
|
||||
Ok(Credentials {
|
||||
user_id: response.user.id,
|
||||
access_token: api_token,
|
||||
user_id: response.user_id,
|
||||
access_token: response.access_token,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
|
||||
self.state.write().credentials = None;
|
||||
self.cloud_client.clear_credentials();
|
||||
self.disconnect(cx);
|
||||
|
||||
if self.has_credentials(cx).await {
|
||||
@@ -1800,7 +1790,7 @@ mod tests {
|
||||
});
|
||||
let auth_and_connect = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert!(matches!(status.next().await, Some(Status::Connecting)));
|
||||
@@ -1877,7 +1867,7 @@ mod tests {
|
||||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
move |cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
move |cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 1);
|
||||
@@ -1885,7 +1875,7 @@ mod tests {
|
||||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 2);
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
mod user_store;
|
||||
|
||||
pub use user_store::*;
|
||||
@@ -1,41 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use cloud_api_client::{AuthenticatedUser, CloudApiClient};
|
||||
use gpui::{Context, Task};
|
||||
use util::{ResultExt as _, maybe};
|
||||
|
||||
pub struct CloudUserStore {
|
||||
authenticated_user: Option<AuthenticatedUser>,
|
||||
_fetch_authenticated_user_task: Task<()>,
|
||||
}
|
||||
|
||||
impl CloudUserStore {
|
||||
pub fn new(cloud_client: Arc<CloudApiClient>, cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
authenticated_user: None,
|
||||
_fetch_authenticated_user_task: cx.spawn(async move |this, cx| {
|
||||
maybe!(async move {
|
||||
loop {
|
||||
if cloud_client.has_credentials() {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(100))
|
||||
.await;
|
||||
}
|
||||
|
||||
let response = cloud_client.get_authenticated_user().await?;
|
||||
this.update(cx, |this, _cx| {
|
||||
this.authenticated_user = Some(response.user);
|
||||
})
|
||||
})
|
||||
.await
|
||||
.context("failed to fetch authenticated user")
|
||||
.log_err();
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,11 @@
|
||||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::Duration;
|
||||
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
|
||||
use futures::{StreamExt, stream::BoxStream};
|
||||
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
|
||||
use http_client::{AsyncBody, Method, Request, http};
|
||||
use parking_lot::Mutex;
|
||||
use rpc::{
|
||||
ConnectionId, Peer, Receipt, TypedEnvelope,
|
||||
@@ -39,6 +42,44 @@ impl FakeServer {
|
||||
executor: cx.executor(),
|
||||
};
|
||||
|
||||
client.http_client().as_fake().replace_handler({
|
||||
let state = server.state.clone();
|
||||
move |old_handler, req| {
|
||||
let state = state.clone();
|
||||
let old_handler = old_handler.clone();
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::GET, "/client/users/me") => {
|
||||
let credentials = parse_authorization_header(&req);
|
||||
if credentials
|
||||
!= Some(Credentials {
|
||||
user_id: client_user_id,
|
||||
access_token: state.lock().access_token.to_string(),
|
||||
})
|
||||
{
|
||||
return Ok(http_client::Response::builder()
|
||||
.status(401)
|
||||
.body("Unauthorized".into())
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
Ok(http_client::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
serde_json::to_string(&make_get_authenticated_user_response(
|
||||
client_user_id as i32,
|
||||
format!("user-{client_user_id}"),
|
||||
))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
_ => old_handler(req).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
client
|
||||
.override_authenticate({
|
||||
let state = Arc::downgrade(&server.state);
|
||||
@@ -105,7 +146,7 @@ impl FakeServer {
|
||||
});
|
||||
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -223,3 +264,54 @@ impl Drop for FakeServer {
|
||||
self.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
|
||||
let mut auth_header = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)?
|
||||
.to_str()
|
||||
.ok()?
|
||||
.split_whitespace();
|
||||
let user_id = auth_header.next()?.parse().ok()?;
|
||||
let access_token = auth_header.next()?;
|
||||
Some(Credentials {
|
||||
user_id,
|
||||
access_token: access_token.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_get_authenticated_user_response(
|
||||
user_id: i32,
|
||||
github_login: String,
|
||||
) -> GetAuthenticatedUserResponse {
|
||||
GetAuthenticatedUserResponse {
|
||||
user: AuthenticatedUser {
|
||||
id: user_id,
|
||||
metrics_id: format!("metrics-id-{user_id}"),
|
||||
avatar_url: "".to_string(),
|
||||
github_login,
|
||||
name: None,
|
||||
is_staff: false,
|
||||
accepted_tos_at: None,
|
||||
},
|
||||
feature_flags: vec![],
|
||||
plan: PlanInfo {
|
||||
plan: Plan::ZedPro,
|
||||
subscription_period: None,
|
||||
usage: CurrentUsage {
|
||||
model_requests: UsageData {
|
||||
used: 0,
|
||||
limit: UsageLimit::Limited(500),
|
||||
},
|
||||
edit_predictions: UsageData {
|
||||
used: 250,
|
||||
limit: UsageLimit::Unlimited,
|
||||
},
|
||||
},
|
||||
trial_started_at: None,
|
||||
is_usage_based_billing_enabled: false,
|
||||
is_account_too_young: false,
|
||||
has_overdue_invoices: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use super::{Client, Status, TypedEnvelope, proto};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{
|
||||
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
|
||||
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||
@@ -20,7 +21,7 @@ use std::{
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
use text::ReplicaId;
|
||||
use util::{TryFutureExt as _, maybe};
|
||||
use util::{ResultExt, TryFutureExt as _};
|
||||
|
||||
pub type UserId = u64;
|
||||
|
||||
@@ -55,7 +56,7 @@ pub struct ParticipantIndex(pub u32);
|
||||
#[derive(Default, Debug)]
|
||||
pub struct User {
|
||||
pub id: UserId,
|
||||
pub github_login: String,
|
||||
pub github_login: SharedString,
|
||||
pub avatar_uri: SharedUri,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
@@ -107,19 +108,14 @@ pub enum ContactRequestStatus {
|
||||
|
||||
pub struct UserStore {
|
||||
users: HashMap<u64, Arc<User>>,
|
||||
by_github_login: HashMap<String, u64>,
|
||||
by_github_login: HashMap<SharedString, u64>,
|
||||
participant_indices: HashMap<u64, ParticipantIndex>,
|
||||
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
|
||||
current_plan: Option<proto::Plan>,
|
||||
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
|
||||
trial_started_at: Option<DateTime<Utc>>,
|
||||
model_request_usage: Option<ModelRequestUsage>,
|
||||
edit_prediction_usage: Option<EditPredictionUsage>,
|
||||
is_usage_based_billing_enabled: Option<bool>,
|
||||
account_too_young: Option<bool>,
|
||||
has_overdue_invoices: Option<bool>,
|
||||
plan_info: Option<PlanInfo>,
|
||||
current_user: watch::Receiver<Option<Arc<User>>>,
|
||||
accepted_tos_at: Option<Option<DateTime<Utc>>>,
|
||||
accepted_tos_at: Option<Option<cloud_api_client::Timestamp>>,
|
||||
contacts: Vec<Arc<Contact>>,
|
||||
incoming_contact_requests: Vec<Arc<User>>,
|
||||
outgoing_contact_requests: Vec<Arc<User>>,
|
||||
@@ -145,6 +141,7 @@ pub enum Event {
|
||||
ShowContacts,
|
||||
ParticipantIndicesChanged,
|
||||
PrivateUserInfoUpdated,
|
||||
PlanUpdated,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@@ -188,14 +185,9 @@ impl UserStore {
|
||||
users: Default::default(),
|
||||
by_github_login: Default::default(),
|
||||
current_user: current_user_rx,
|
||||
current_plan: None,
|
||||
subscription_period: None,
|
||||
trial_started_at: None,
|
||||
plan_info: None,
|
||||
model_request_usage: None,
|
||||
edit_prediction_usage: None,
|
||||
is_usage_based_billing_enabled: None,
|
||||
account_too_young: None,
|
||||
has_overdue_invoices: None,
|
||||
accepted_tos_at: None,
|
||||
contacts: Default::default(),
|
||||
incoming_contact_requests: Default::default(),
|
||||
@@ -225,53 +217,30 @@ impl UserStore {
|
||||
return Ok(());
|
||||
};
|
||||
match status {
|
||||
Status::Connected { .. } => {
|
||||
Status::Authenticated | Status::Connected { .. } => {
|
||||
if let Some(user_id) = client.user_id() {
|
||||
let fetch_user = if let Ok(fetch_user) =
|
||||
this.update(cx, |this, cx| this.get_user(user_id, cx).log_err())
|
||||
{
|
||||
fetch_user
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
let fetch_private_user_info =
|
||||
client.request(proto::GetPrivateUserInfo {}).log_err();
|
||||
let (user, info) =
|
||||
futures::join!(fetch_user, fetch_private_user_info);
|
||||
|
||||
let response = client.cloud_client().get_authenticated_user().await;
|
||||
let mut current_user = None;
|
||||
cx.update(|cx| {
|
||||
if let Some(info) = info {
|
||||
let staff =
|
||||
info.staff && !*feature_flags::ZED_DISABLE_STAFF;
|
||||
cx.update_flags(staff, info.flags);
|
||||
client.telemetry.set_authenticated_user_info(
|
||||
Some(info.metrics_id.clone()),
|
||||
staff,
|
||||
);
|
||||
|
||||
if let Some(response) = response.log_err() {
|
||||
let user = Arc::new(User {
|
||||
id: user_id,
|
||||
github_login: response.user.github_login.clone().into(),
|
||||
avatar_uri: response.user.avatar_url.clone().into(),
|
||||
name: response.user.name.clone(),
|
||||
});
|
||||
current_user = Some(user.clone());
|
||||
this.update(cx, |this, cx| {
|
||||
let accepted_tos_at = {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok()
|
||||
{
|
||||
None
|
||||
} else {
|
||||
info.accepted_tos_at
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
info.accepted_tos_at
|
||||
};
|
||||
|
||||
this.set_current_user_accepted_tos_at(accepted_tos_at);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
this.by_github_login
|
||||
.insert(user.github_login.clone(), user_id);
|
||||
this.users.insert(user_id, user);
|
||||
this.update_authenticated_user(response, cx)
|
||||
})
|
||||
} else {
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})??;
|
||||
|
||||
current_user_tx.send(user).await.ok();
|
||||
current_user_tx.send(current_user).await.ok();
|
||||
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
@@ -352,59 +321,22 @@ impl UserStore {
|
||||
|
||||
async fn handle_update_plan(
|
||||
this: Entity<Self>,
|
||||
message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
_message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
let client = this
|
||||
.read_with(&cx, |this, _| this.client.upgrade())?
|
||||
.context("client was dropped")?;
|
||||
|
||||
let response = client
|
||||
.cloud_client()
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.context("failed to fetch authenticated user")?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.current_plan = Some(message.payload.plan());
|
||||
this.subscription_period = maybe!({
|
||||
let period = message.payload.subscription_period?;
|
||||
let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?;
|
||||
let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?;
|
||||
|
||||
Some((started_at, ended_at))
|
||||
});
|
||||
this.trial_started_at = message
|
||||
.payload
|
||||
.trial_started_at
|
||||
.and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
|
||||
this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
|
||||
this.account_too_young = message.payload.account_too_young;
|
||||
this.has_overdue_invoices = message.payload.has_overdue_invoices;
|
||||
|
||||
if let Some(usage) = message.payload.usage {
|
||||
// limits are always present even though they are wrapped in Option
|
||||
this.model_request_usage = usage
|
||||
.model_requests_usage_limit
|
||||
.and_then(|limit| {
|
||||
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
|
||||
})
|
||||
.map(ModelRequestUsage);
|
||||
this.edit_prediction_usage = usage
|
||||
.edit_predictions_usage_limit
|
||||
.and_then(|limit| {
|
||||
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
|
||||
})
|
||||
.map(EditPredictionUsage);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
|
||||
self.model_request_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn update_edit_prediction_usage(
|
||||
&mut self,
|
||||
usage: EditPredictionUsage,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.edit_prediction_usage = Some(usage);
|
||||
cx.notify();
|
||||
this.update_authenticated_user(response, cx);
|
||||
})
|
||||
}
|
||||
|
||||
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
@@ -763,59 +695,131 @@ impl UserStore {
|
||||
self.current_user.borrow().clone()
|
||||
}
|
||||
|
||||
pub fn current_plan(&self) -> Option<proto::Plan> {
|
||||
pub fn plan(&self) -> Option<cloud_llm_client::Plan> {
|
||||
#[cfg(debug_assertions)]
|
||||
if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
|
||||
return match plan.as_str() {
|
||||
"free" => Some(proto::Plan::Free),
|
||||
"trial" => Some(proto::Plan::ZedProTrial),
|
||||
"pro" => Some(proto::Plan::ZedPro),
|
||||
"free" => Some(cloud_llm_client::Plan::ZedFree),
|
||||
"trial" => Some(cloud_llm_client::Plan::ZedProTrial),
|
||||
"pro" => Some(cloud_llm_client::Plan::ZedPro),
|
||||
_ => {
|
||||
panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
self.current_plan
|
||||
self.plan_info.as_ref().map(|info| info.plan)
|
||||
}
|
||||
|
||||
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
|
||||
self.subscription_period
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.subscription_period)
|
||||
.map(|subscription_period| {
|
||||
(
|
||||
subscription_period.started_at.0,
|
||||
subscription_period.ended_at.0,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
|
||||
self.trial_started_at
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.trial_started_at)
|
||||
.map(|trial_started_at| trial_started_at.0)
|
||||
}
|
||||
|
||||
pub fn usage_based_billing_enabled(&self) -> Option<bool> {
|
||||
self.is_usage_based_billing_enabled
|
||||
/// Returns whether the user's account is too new to use the service.
|
||||
pub fn account_too_young(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_account_too_young)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Returns whether the current user has overdue invoices and usage should be blocked.
|
||||
pub fn has_overdue_invoices(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.has_overdue_invoices)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn is_usage_based_billing_enabled(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_usage_based_billing_enabled)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
|
||||
self.model_request_usage
|
||||
}
|
||||
|
||||
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
|
||||
self.model_request_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
|
||||
self.edit_prediction_usage
|
||||
}
|
||||
|
||||
pub fn update_edit_prediction_usage(
|
||||
&mut self,
|
||||
usage: EditPredictionUsage,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.edit_prediction_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn update_authenticated_user(
|
||||
&mut self,
|
||||
response: GetAuthenticatedUserResponse,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
|
||||
cx.update_flags(staff, response.feature_flags);
|
||||
if let Some(client) = self.client.upgrade() {
|
||||
client
|
||||
.telemetry
|
||||
.set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
|
||||
}
|
||||
|
||||
let accepted_tos_at = {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() {
|
||||
None
|
||||
} else {
|
||||
response.user.accepted_tos_at
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
response.user.accepted_tos_at
|
||||
};
|
||||
|
||||
self.accepted_tos_at = Some(accepted_tos_at);
|
||||
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
|
||||
limit: response.plan.usage.model_requests.limit,
|
||||
amount: response.plan.usage.model_requests.used as i32,
|
||||
}));
|
||||
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
|
||||
limit: response.plan.usage.edit_predictions.limit,
|
||||
amount: response.plan.usage.edit_predictions.used as i32,
|
||||
}));
|
||||
self.plan_info = Some(response.plan);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
}
|
||||
|
||||
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
|
||||
self.current_user.clone()
|
||||
}
|
||||
|
||||
/// Returns whether the user's account is too new to use the service.
|
||||
pub fn account_too_young(&self) -> bool {
|
||||
self.account_too_young.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Returns whether the current user has overdue invoices and usage should be blocked.
|
||||
pub fn has_overdue_invoices(&self) -> bool {
|
||||
self.has_overdue_invoices.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
|
||||
pub fn has_accepted_terms_of_service(&self) -> bool {
|
||||
self.accepted_tos_at
|
||||
.map(|accepted_tos_at| accepted_tos_at.is_some())
|
||||
.map_or(false, |accepted_tos_at| accepted_tos_at.is_some())
|
||||
}
|
||||
|
||||
pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
@@ -827,23 +831,18 @@ impl UserStore {
|
||||
cx.spawn(async move |this, cx| -> anyhow::Result<()> {
|
||||
let client = client.upgrade().context("client not found")?;
|
||||
let response = client
|
||||
.request(proto::AcceptTermsOfService {})
|
||||
.cloud_client()
|
||||
.accept_terms_of_service()
|
||||
.await
|
||||
.context("error accepting tos")?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at));
|
||||
this.accepted_tos_at = Some(response.user.accepted_tos_at);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
})?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
|
||||
self.accepted_tos_at = Some(
|
||||
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
|
||||
);
|
||||
}
|
||||
|
||||
fn load_users(
|
||||
&self,
|
||||
request: impl RequestMessage<Response = UsersResponse>,
|
||||
@@ -902,7 +901,7 @@ impl UserStore {
|
||||
let mut missing_user_ids = Vec::new();
|
||||
for id in user_ids {
|
||||
if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) {
|
||||
ret.insert(id, github_login.into());
|
||||
ret.insert(id, github_login);
|
||||
} else {
|
||||
missing_user_ids.push(id)
|
||||
}
|
||||
@@ -923,7 +922,7 @@ impl User {
|
||||
fn new(message: proto::User) -> Arc<Self> {
|
||||
Arc::new(User {
|
||||
id: message.id,
|
||||
github_login: message.github_login,
|
||||
github_login: message.github_login.into(),
|
||||
avatar_uri: message.avatar_url.into(),
|
||||
name: message.name,
|
||||
})
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use anyhow::{Result, anyhow};
|
||||
pub use cloud_api_types::*;
|
||||
use futures::AsyncReadExt as _;
|
||||
use http_client::http::request;
|
||||
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
@@ -35,6 +36,10 @@ impl CloudApiClient {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn clear_credentials(&self) {
|
||||
*self.credentials.write() = None;
|
||||
}
|
||||
|
||||
fn authorization_header(&self) -> Result<String> {
|
||||
let guard = self.credentials.read();
|
||||
let credentials = guard
|
||||
@@ -47,17 +52,26 @@ impl CloudApiClient {
|
||||
))
|
||||
}
|
||||
|
||||
fn build_request(
|
||||
&self,
|
||||
req: request::Builder,
|
||||
body: impl Into<AsyncBody>,
|
||||
) -> Result<Request<AsyncBody>> {
|
||||
Ok(req
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", self.authorization_header()?)
|
||||
.body(body.into())?)
|
||||
}
|
||||
|
||||
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
|
||||
let request = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri(
|
||||
let request = self.build_request(
|
||||
Request::builder().method(Method::GET).uri(
|
||||
self.http_client
|
||||
.build_zed_cloud_url("/client/users/me", &[])?
|
||||
.as_ref(),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", self.authorization_header()?)
|
||||
.body(AsyncBody::default())?;
|
||||
),
|
||||
AsyncBody::default(),
|
||||
)?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
@@ -76,4 +90,66 @@ impl CloudApiClient {
|
||||
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
|
||||
pub async fn accept_terms_of_service(&self) -> Result<AcceptTermsOfServiceResponse> {
|
||||
let request = self.build_request(
|
||||
Request::builder().method(Method::POST).uri(
|
||||
self.http_client
|
||||
.build_zed_cloud_url("/client/terms_of_service/accept", &[])?
|
||||
.as_ref(),
|
||||
),
|
||||
AsyncBody::default(),
|
||||
)?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to accept terms of service.\nStatus: {:?}\nBody: {body}",
|
||||
response.status()
|
||||
)
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
|
||||
pub async fn create_llm_token(
|
||||
&self,
|
||||
system_id: Option<String>,
|
||||
) -> Result<CreateLlmTokenResponse> {
|
||||
let mut request_builder = Request::builder().method(Method::POST).uri(
|
||||
self.http_client
|
||||
.build_zed_cloud_url("/client/llm_tokens", &[])?
|
||||
.as_ref(),
|
||||
);
|
||||
|
||||
if let Some(system_id) = system_id {
|
||||
request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id);
|
||||
}
|
||||
|
||||
let request = self.build_request(request_builder, AsyncBody::default())?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to create LLM token.\nStatus: {:?}\nBody: {body}",
|
||||
response.status()
|
||||
)
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,5 +12,11 @@ workspace = true
|
||||
path = "src/cloud_api_types.rs"
|
||||
|
||||
[dependencies]
|
||||
chrono.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
serde.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -1,14 +1,55 @@
|
||||
mod timestamp;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use crate::timestamp::Timestamp;
|
||||
|
||||
pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id";
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct GetAuthenticatedUserResponse {
|
||||
pub user: AuthenticatedUser,
|
||||
pub feature_flags: Vec<String>,
|
||||
pub plan: PlanInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AuthenticatedUser {
|
||||
pub id: i32,
|
||||
pub metrics_id: String,
|
||||
pub avatar_url: String,
|
||||
pub github_login: String,
|
||||
pub name: Option<String>,
|
||||
pub is_staff: bool,
|
||||
pub accepted_tos_at: Option<Timestamp>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct PlanInfo {
|
||||
pub plan: cloud_llm_client::Plan,
|
||||
pub subscription_period: Option<SubscriptionPeriod>,
|
||||
pub usage: cloud_llm_client::CurrentUsage,
|
||||
pub trial_started_at: Option<Timestamp>,
|
||||
pub is_usage_based_billing_enabled: bool,
|
||||
pub is_account_too_young: bool,
|
||||
pub has_overdue_invoices: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct SubscriptionPeriod {
|
||||
pub started_at: Timestamp,
|
||||
pub ended_at: Timestamp,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AcceptTermsOfServiceResponse {
|
||||
pub user: AuthenticatedUser,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToken(pub String);
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateLlmTokenResponse {
|
||||
pub token: LlmToken,
|
||||
}
|
||||
|
||||
166
crates/cloud_api_types/src/timestamp.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
/// A timestamp with a serialized representation in RFC 3339 format.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub struct Timestamp(pub DateTime<Utc>);
|
||||
|
||||
impl Timestamp {
|
||||
pub fn new(datetime: DateTime<Utc>) -> Self {
|
||||
Self(datetime)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DateTime<Utc>> for Timestamp {
|
||||
fn from(value: DateTime<Utc>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NaiveDateTime> for Timestamp {
|
||||
fn from(value: NaiveDateTime) -> Self {
|
||||
Self(value.and_utc())
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Timestamp {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true);
|
||||
serializer.serialize_str(&rfc3339_string)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Timestamp {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = String::deserialize(deserializer)?;
|
||||
let datetime = DateTime::parse_from_rfc3339(&value)
|
||||
.map_err(serde::de::Error::custom)?
|
||||
.to_utc();
|
||||
Ok(Self(datetime))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::NaiveDate;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_serialization() {
|
||||
let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
let timestamp = Timestamp::new(datetime);
|
||||
|
||||
let json = serde_json::to_string(×tamp).unwrap();
|
||||
assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization() {
|
||||
let json = "\"2023-12-25T14:30:45.123Z\"";
|
||||
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
|
||||
|
||||
let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_roundtrip() {
|
||||
let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
|
||||
let timestamp = Timestamp::new(original);
|
||||
let json = serde_json::to_string(×tamp).unwrap();
|
||||
let deserialized: Timestamp = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.0, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_from_datetime_utc() {
|
||||
let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
|
||||
let timestamp = Timestamp::from(datetime);
|
||||
assert_eq!(timestamp.0, datetime);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_from_naive_datetime() {
|
||||
let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_milli_opt(14, 30, 45, 123)
|
||||
.unwrap();
|
||||
|
||||
let timestamp = Timestamp::from(naive_dt);
|
||||
let expected = naive_dt.and_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_serialization_with_microseconds() {
|
||||
// Test that microseconds are truncated to milliseconds
|
||||
let datetime = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_micro_opt(14, 30, 45, 123456)
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
|
||||
let timestamp = Timestamp::new(datetime);
|
||||
let json = serde_json::to_string(×tamp).unwrap();
|
||||
|
||||
// Should be truncated to milliseconds
|
||||
assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization_without_milliseconds() {
|
||||
let json = "\"2023-12-25T14:30:45Z\"";
|
||||
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
|
||||
|
||||
let expected = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_opt(14, 30, 45)
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization_with_timezone() {
|
||||
let json = "\"2023-12-25T14:30:45.123+05:30\"";
|
||||
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
|
||||
|
||||
// Should be converted to UTC
|
||||
let expected = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 + 5:30 = 20:00:45, but we want UTC so subtract 5:30
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization_with_invalid_format() {
|
||||
let json = "\"invalid-date\"";
|
||||
let result: Result<Timestamp, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -308,13 +308,13 @@ pub struct GetSubscriptionResponse {
|
||||
pub usage: Option<CurrentUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CurrentUsage {
|
||||
pub model_requests: UsageData,
|
||||
pub edit_predictions: UsageData,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct UsageData {
|
||||
pub used: u32,
|
||||
pub limit: UsageLimit,
|
||||
|
||||
@@ -100,7 +100,6 @@ impl std::fmt::Display for SystemIdHeader {
|
||||
|
||||
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/user", get(legacy_update_or_create_authenticated_user))
|
||||
.route("/users/look_up", get(look_up_user))
|
||||
.route("/users/:id/access_tokens", post(create_access_token))
|
||||
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
|
||||
@@ -145,51 +144,6 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
||||
Ok::<_, Error>(next.run(req).await)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AuthenticatedUserParams {
|
||||
github_user_id: i32,
|
||||
github_login: String,
|
||||
github_email: Option<String>,
|
||||
github_name: Option<String>,
|
||||
github_user_created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AuthenticatedUserResponse {
|
||||
user: User,
|
||||
metrics_id: String,
|
||||
feature_flags: Vec<String>,
|
||||
}
|
||||
|
||||
/// This is a legacy endpoint that is no longer used in production.
|
||||
///
|
||||
/// It currently only exists to be used when developing Collab locally.
|
||||
async fn legacy_update_or_create_authenticated_user(
|
||||
Query(params): Query<AuthenticatedUserParams>,
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
) -> Result<Json<AuthenticatedUserResponse>> {
|
||||
let initial_channel_id = app.config.auto_join_channel_id;
|
||||
|
||||
let user = app
|
||||
.db
|
||||
.update_or_create_user_by_github_account(
|
||||
¶ms.github_login,
|
||||
params.github_user_id,
|
||||
params.github_email.as_deref(),
|
||||
params.github_name.as_deref(),
|
||||
params.github_user_created_at,
|
||||
initial_channel_id,
|
||||
)
|
||||
.await?;
|
||||
let metrics_id = app.db.get_user_metrics_id(user.id).await?;
|
||||
let feature_flags = app.db.get_user_flags(user.id).await?;
|
||||
Ok(Json(AuthenticatedUserResponse {
|
||||
user,
|
||||
metrics_id,
|
||||
feature_flags,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LookUpUserParams {
|
||||
identifier: String,
|
||||
|
||||
@@ -42,7 +42,7 @@ use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::proto::split_repository_update;
|
||||
use rpc::proto::{MultiLspQuery, split_repository_update};
|
||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||
|
||||
use futures::{
|
||||
@@ -374,7 +374,7 @@ impl Server {
|
||||
.add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
|
||||
.add_request_handler(multi_lsp_query)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
|
||||
@@ -838,7 +838,7 @@ impl Server {
|
||||
// This arrangement ensures we will attempt to process earlier messages first, but fall
|
||||
// back to processing messages arrived later in the spirit of making progress.
|
||||
let mut foreground_message_handlers = FuturesUnordered::new();
|
||||
let concurrent_handlers = Arc::new(Semaphore::new(512));
|
||||
let concurrent_handlers = Arc::new(Semaphore::new(256));
|
||||
loop {
|
||||
let next_message = async {
|
||||
let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
|
||||
@@ -865,6 +865,7 @@ impl Server {
|
||||
user_id=field::Empty,
|
||||
login=field::Empty,
|
||||
impersonator=field::Empty,
|
||||
multi_lsp_query_request=field::Empty,
|
||||
);
|
||||
principal.update_span(&span);
|
||||
let span_enter = span.enter();
|
||||
@@ -2329,6 +2330,15 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn multi_lsp_query(
|
||||
request: MultiLspQuery,
|
||||
response: Response<MultiLspQuery>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
|
||||
forward_mutating_project_request(request, response, session).await
|
||||
}
|
||||
|
||||
/// Notify other participants that a new buffer has been created
|
||||
async fn create_buffer_for_peer(
|
||||
request: proto::CreateBufferForPeer,
|
||||
|
||||
@@ -38,12 +38,12 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
|
||||
let mut remote = room
|
||||
.remote_participants()
|
||||
.values()
|
||||
.map(|participant| participant.user.github_login.clone())
|
||||
.map(|participant| participant.user.github_login.clone().to_string())
|
||||
.collect::<Vec<_>>();
|
||||
let mut pending = room
|
||||
.pending_participants()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone())
|
||||
.map(|user| user.github_login.clone().to_string())
|
||||
.collect::<Vec<_>>();
|
||||
remote.sort();
|
||||
pending.sort();
|
||||
|
||||
@@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections(
|
||||
client_b1.disconnect(&cx_b1.to_async());
|
||||
executor.advance_clock(RECEIVE_TIMEOUT);
|
||||
client_b1
|
||||
.authenticate_and_connect(false, &cx_b1.to_async())
|
||||
.connect(false, &cx_b1.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -1667,7 +1667,7 @@ async fn test_project_reconnect(
|
||||
// Client A reconnects. Their project is re-shared, and client B re-joins it.
|
||||
server.allow_connections();
|
||||
client_a
|
||||
.authenticate_and_connect(false, &cx_a.to_async())
|
||||
.connect(false, &cx_a.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -1796,7 +1796,7 @@ async fn test_project_reconnect(
|
||||
// Client B reconnects. They re-join the room and the remaining shared project.
|
||||
server.allow_connections();
|
||||
client_b
|
||||
.authenticate_and_connect(false, &cx_b.to_async())
|
||||
.connect(false, &cx_b.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -1881,7 +1881,7 @@ async fn test_active_call_events(
|
||||
vec![room::Event::RemoteProjectShared {
|
||||
owner: Arc::new(User {
|
||||
id: client_a.user_id().unwrap(),
|
||||
github_login: "user_a".to_string(),
|
||||
github_login: "user_a".into(),
|
||||
avatar_uri: "avatar_a".into(),
|
||||
name: None,
|
||||
}),
|
||||
@@ -1900,7 +1900,7 @@ async fn test_active_call_events(
|
||||
vec![room::Event::RemoteProjectShared {
|
||||
owner: Arc::new(User {
|
||||
id: client_b.user_id().unwrap(),
|
||||
github_login: "user_b".to_string(),
|
||||
github_login: "user_b".into(),
|
||||
avatar_uri: "avatar_b".into(),
|
||||
name: None,
|
||||
}),
|
||||
@@ -5738,7 +5738,7 @@ async fn test_contacts(
|
||||
|
||||
server.allow_connections();
|
||||
client_c
|
||||
.authenticate_and_connect(false, &cx_c.to_async())
|
||||
.connect(false, &cx_c.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -6079,7 +6079,7 @@ async fn test_contacts(
|
||||
.iter()
|
||||
.map(|contact| {
|
||||
(
|
||||
contact.user.github_login.clone(),
|
||||
contact.user.github_login.clone().to_string(),
|
||||
if contact.online { "online" } else { "offline" },
|
||||
if contact.busy { "busy" } else { "free" },
|
||||
)
|
||||
@@ -6269,7 +6269,7 @@ async fn test_contact_requests(
|
||||
client.disconnect(&cx.to_async());
|
||||
client.clear_contacts(cx).await;
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use gpui::{BackgroundExecutor, TestAppContext};
|
||||
use notifications::NotificationEvent;
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::{Notification, proto};
|
||||
|
||||
use crate::tests::TestServer;
|
||||
@@ -17,6 +18,9 @@ async fn test_notifications(
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
// Wait for authentication/connection to Collab to be established.
|
||||
executor.run_until_parked();
|
||||
|
||||
let notification_events_a = Arc::new(Mutex::new(Vec::new()));
|
||||
let notification_events_b = Arc::new(Mutex::new(Vec::new()));
|
||||
client_a.notification_store().update(cx_a, |_, cx| {
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::{
|
||||
use anyhow::anyhow;
|
||||
use call::ActiveCall;
|
||||
use channel::{ChannelBuffer, ChannelStore};
|
||||
use client::CloudUserStore;
|
||||
use client::test::{make_get_authenticated_user_response, parse_authorization_header};
|
||||
use client::{
|
||||
self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
|
||||
proto::PeerId,
|
||||
@@ -21,7 +21,7 @@ use fs::FakeFs;
|
||||
use futures::{StreamExt as _, channel::oneshot};
|
||||
use git::GitHostingProviderRegistry;
|
||||
use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext};
|
||||
use http_client::FakeHttpClient;
|
||||
use http_client::{FakeHttpClient, Method};
|
||||
use language::LanguageRegistry;
|
||||
use node_runtime::NodeRuntime;
|
||||
use notifications::NotificationStore;
|
||||
@@ -162,6 +162,8 @@ impl TestServer {
|
||||
}
|
||||
|
||||
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
|
||||
const ACCESS_TOKEN: &str = "the-token";
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
cx.update(|cx| {
|
||||
@@ -176,7 +178,7 @@ impl TestServer {
|
||||
});
|
||||
|
||||
let clock = Arc::new(FakeSystemClock::new());
|
||||
let http = FakeHttpClient::with_404_response();
|
||||
|
||||
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
|
||||
{
|
||||
user.id
|
||||
@@ -198,6 +200,47 @@ impl TestServer {
|
||||
.expect("creating user failed")
|
||||
.user_id
|
||||
};
|
||||
|
||||
let http = FakeHttpClient::create({
|
||||
let name = name.to_string();
|
||||
move |req| {
|
||||
let name = name.clone();
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::GET, "/client/users/me") => {
|
||||
let credentials = parse_authorization_header(&req);
|
||||
if credentials
|
||||
!= Some(Credentials {
|
||||
user_id: user_id.to_proto(),
|
||||
access_token: ACCESS_TOKEN.into(),
|
||||
})
|
||||
{
|
||||
return Ok(http_client::Response::builder()
|
||||
.status(401)
|
||||
.body("Unauthorized".into())
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
Ok(http_client::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
serde_json::to_string(&make_get_authenticated_user_response(
|
||||
user_id.0, name,
|
||||
))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
_ => Ok(http_client::Response::builder()
|
||||
.status(404)
|
||||
.body("Not Found".into())
|
||||
.unwrap()),
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let client_name = name.to_string();
|
||||
let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx));
|
||||
let server = self.server.clone();
|
||||
@@ -209,11 +252,10 @@ impl TestServer {
|
||||
.unwrap()
|
||||
.set_id(user_id.to_proto())
|
||||
.override_authenticate(move |cx| {
|
||||
let access_token = "the-token".to_string();
|
||||
cx.spawn(async move |_| {
|
||||
Ok(Credentials {
|
||||
user_id: user_id.to_proto(),
|
||||
access_token,
|
||||
access_token: ACCESS_TOKEN.into(),
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -222,7 +264,7 @@ impl TestServer {
|
||||
credentials,
|
||||
&Credentials {
|
||||
user_id: user_id.0 as u64,
|
||||
access_token: "the-token".into()
|
||||
access_token: ACCESS_TOKEN.into(),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -282,14 +324,12 @@ impl TestServer {
|
||||
.register_hosting_provider(Arc::new(git_hosting_providers::Github::public_instance()));
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
let cloud_user_store = cx.new(|cx| CloudUserStore::new(client.cloud_client(), cx));
|
||||
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
|
||||
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let session = cx.new(|cx| AppSession::new(Session::test(), cx));
|
||||
let app_state = Arc::new(workspace::AppState {
|
||||
client: client.clone(),
|
||||
user_store: user_store.clone(),
|
||||
cloud_user_store,
|
||||
workspace_store,
|
||||
languages: language_registry,
|
||||
fs: fs.clone(),
|
||||
@@ -322,7 +362,7 @@ impl TestServer {
|
||||
});
|
||||
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -695,17 +735,17 @@ impl TestClient {
|
||||
current: store
|
||||
.contacts()
|
||||
.iter()
|
||||
.map(|contact| contact.user.github_login.clone())
|
||||
.map(|contact| contact.user.github_login.clone().to_string())
|
||||
.collect(),
|
||||
outgoing_requests: store
|
||||
.outgoing_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone())
|
||||
.map(|user| user.github_login.clone().to_string())
|
||||
.collect(),
|
||||
incoming_requests: store
|
||||
.incoming_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone())
|
||||
.map(|user| user.github_login.clone().to_string())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -940,7 +940,7 @@ impl CollabPanel {
|
||||
room.read(cx).local_participant().role == proto::ChannelRole::Admin
|
||||
});
|
||||
|
||||
ListItem::new(SharedString::from(user.github_login.clone()))
|
||||
ListItem::new(user.github_login.clone())
|
||||
.start_slot(Avatar::new(user.avatar_uri.clone()))
|
||||
.child(Label::new(user.github_login.clone()))
|
||||
.toggle_state(is_selected)
|
||||
@@ -2331,7 +2331,7 @@ impl CollabPanel {
|
||||
let client = this.client.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
client
|
||||
.authenticate_and_connect(true, &cx)
|
||||
.connect(true, &cx)
|
||||
.await
|
||||
.into_response()
|
||||
.notify_async_err(cx);
|
||||
@@ -2583,7 +2583,7 @@ impl CollabPanel {
|
||||
) -> impl IntoElement {
|
||||
let online = contact.online;
|
||||
let busy = contact.busy || calling;
|
||||
let github_login = SharedString::from(contact.user.github_login.clone());
|
||||
let github_login = contact.user.github_login.clone();
|
||||
let item = ListItem::new(github_login.clone())
|
||||
.indent_level(1)
|
||||
.indent_step_size(px(20.))
|
||||
@@ -2662,7 +2662,7 @@ impl CollabPanel {
|
||||
is_selected: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let github_login = SharedString::from(user.github_login.clone());
|
||||
let github_login = user.github_login.clone();
|
||||
let user_id = user.id;
|
||||
let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user);
|
||||
let color = if is_response_pending {
|
||||
|
||||
@@ -634,13 +634,13 @@ impl Render for NotificationPanel {
|
||||
.child(Icon::new(IconName::Envelope)),
|
||||
)
|
||||
.map(|this| {
|
||||
if self.client.user_id().is_none() {
|
||||
if !self.client.status().borrow().is_connected() {
|
||||
this.child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.p_4()
|
||||
.child(
|
||||
Button::new("sign_in_prompt_button", "Sign in")
|
||||
Button::new("connect_prompt_button", "Connect")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
@@ -652,10 +652,7 @@ impl Render for NotificationPanel {
|
||||
let client = client.clone();
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
match client
|
||||
.authenticate_and_connect(true, &cx)
|
||||
.await
|
||||
{
|
||||
match client.connect(true, &cx).await {
|
||||
util::ConnectionResult::Timeout => {
|
||||
log::error!("Connection timeout");
|
||||
}
|
||||
@@ -673,7 +670,7 @@ impl Render for NotificationPanel {
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Sign in to view notifications.")
|
||||
Label::new("Connect to view notifications.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
|
||||
@@ -441,14 +441,12 @@ impl Client {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
self.notification_handlers
|
||||
.lock()
|
||||
.insert(method, Box::new(f));
|
||||
pub fn on_notification(
|
||||
&self,
|
||||
method: &'static str,
|
||||
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||
) {
|
||||
self.notification_handlers.lock().insert(method, f);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -95,8 +95,28 @@ impl ContextServer {
|
||||
self.client.read().clone()
|
||||
}
|
||||
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
let client = match &self.configuration {
|
||||
pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
|
||||
self.initialize(self.new_client(cx)?).await
|
||||
}
|
||||
|
||||
/// Starts the context server, making sure handlers are registered before initialization happens
|
||||
pub async fn start_with_handlers(
|
||||
&self,
|
||||
notification_handlers: Vec<(
|
||||
&'static str,
|
||||
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
|
||||
)>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<()> {
|
||||
let client = self.new_client(cx)?;
|
||||
for (method, handler) in notification_handlers {
|
||||
client.on_notification(method, handler);
|
||||
}
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
|
||||
Ok(match &self.configuration {
|
||||
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
|
||||
client::ContextServerId(self.id.0.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
@@ -113,8 +133,7 @@ impl ContextServer {
|
||||
transport.clone(),
|
||||
cx.clone(),
|
||||
)?,
|
||||
};
|
||||
self.initialize(client).await
|
||||
})
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
|
||||
@@ -83,14 +83,18 @@ impl McpServer {
|
||||
}
|
||||
|
||||
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
|
||||
let output_schema = schemars::schema_for!(T::Output);
|
||||
let unit_schema = schemars::schema_for!(());
|
||||
let mut settings = schemars::generate::SchemaSettings::draft07();
|
||||
settings.inline_subschemas = true;
|
||||
let mut generator = settings.into_generator();
|
||||
|
||||
let output_schema = generator.root_schema_for::<T::Output>();
|
||||
let unit_schema = generator.root_schema_for::<T::Output>();
|
||||
|
||||
let registered_tool = RegisteredTool {
|
||||
tool: Tool {
|
||||
name: T::NAME.into(),
|
||||
description: Some(tool.description().into()),
|
||||
input_schema: schemars::schema_for!(T::Input).into(),
|
||||
input_schema: generator.root_schema_for::<T::Input>().into(),
|
||||
output_schema: if output_schema == unit_schema {
|
||||
None
|
||||
} else {
|
||||
|
||||
@@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
|
||||
self.inner.notify(T::METHOD, params)
|
||||
}
|
||||
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
pub fn on_notification(
|
||||
&self,
|
||||
method: &'static str,
|
||||
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||
) {
|
||||
self.inner.on_notification(method, f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -295,7 +295,7 @@ mod tests {
|
||||
request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
|
||||
},
|
||||
},
|
||||
Box::new(|_| panic!("Did not expect to hit this code path")),
|
||||
Box::new(|_| {}),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -883,6 +883,7 @@ impl FakeTransport {
|
||||
break Err(anyhow!("exit in response to request"));
|
||||
}
|
||||
};
|
||||
let success = response.success;
|
||||
let message =
|
||||
serde_json::to_string(&Message::Response(response)).unwrap();
|
||||
|
||||
@@ -893,6 +894,25 @@ impl FakeTransport {
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if request.command == dap_types::requests::Initialize::COMMAND
|
||||
&& success
|
||||
{
|
||||
let message = serde_json::to_string(&Message::Event(Box::new(
|
||||
dap_types::messages::Events::Initialized(Some(
|
||||
Default::default(),
|
||||
)),
|
||||
)))
|
||||
.unwrap();
|
||||
writer
|
||||
.write_all(
|
||||
TransportDelegate::build_rpc_message(message)
|
||||
.as_bytes(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
writer.flush().await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,9 @@ license = "GPL-3.0-or-later"
|
||||
anyhow.workspace = true
|
||||
command_palette.workspace = true
|
||||
gpui.workspace = true
|
||||
mdbook = "0.4.40"
|
||||
# We are specifically pinning this version of mdbook, as later versions introduce issues with double-nested subdirectories.
|
||||
# Ask @maxdeviant about this before bumping.
|
||||
mdbook = "= 0.4.40"
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -22,6 +22,7 @@ test-support = [
|
||||
"theme/test-support",
|
||||
"util/test-support",
|
||||
"workspace/test-support",
|
||||
"tree-sitter-c",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-typescript",
|
||||
"tree-sitter-html",
|
||||
@@ -76,6 +77,7 @@ telemetry.workspace = true
|
||||
text.workspace = true
|
||||
time.workspace = true
|
||||
theme.workspace = true
|
||||
tree-sitter-c = { workspace = true, optional = true }
|
||||
tree-sitter-html = { workspace = true, optional = true }
|
||||
tree-sitter-rust = { workspace = true, optional = true }
|
||||
tree-sitter-typescript = { workspace = true, optional = true }
|
||||
@@ -106,6 +108,7 @@ settings = { workspace = true, features = ["test-support"] }
|
||||
tempfile.workspace = true
|
||||
text = { workspace = true, features = ["test-support"] }
|
||||
theme = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-c.workspace = true
|
||||
tree-sitter-html.workspace = true
|
||||
tree-sitter-rust.workspace = true
|
||||
tree-sitter-typescript.workspace = true
|
||||
|
||||
@@ -1305,6 +1305,7 @@ impl Default for SelectionHistoryMode {
|
||||
///
|
||||
/// Similarly, you might want to disable scrolling if you don't want the viewport to
|
||||
/// move.
|
||||
#[derive(Clone)]
|
||||
pub struct SelectionEffects {
|
||||
nav_history: Option<bool>,
|
||||
completions: bool,
|
||||
@@ -2944,10 +2945,12 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
|
||||
let selection_anchors = self.selections.disjoint_anchors();
|
||||
|
||||
if self.focus_handle.is_focused(window) && self.leader_id.is_none() {
|
||||
self.buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_active_selections(
|
||||
&self.selections.disjoint_anchors(),
|
||||
&selection_anchors,
|
||||
self.selections.line_mode,
|
||||
self.cursor_shape,
|
||||
cx,
|
||||
@@ -2964,9 +2967,8 @@ impl Editor {
|
||||
self.select_next_state = None;
|
||||
self.select_prev_state = None;
|
||||
self.select_syntax_node_history.try_clear();
|
||||
self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), buffer);
|
||||
self.snippet_stack
|
||||
.invalidate(&self.selections.disjoint_anchors(), buffer);
|
||||
self.invalidate_autoclose_regions(&selection_anchors, buffer);
|
||||
self.snippet_stack.invalidate(&selection_anchors, buffer);
|
||||
self.take_rename(false, window, cx);
|
||||
|
||||
let newest_selection = self.selections.newest_anchor();
|
||||
@@ -4047,7 +4049,8 @@ impl Editor {
|
||||
// then don't insert that closing bracket again; just move the selection
|
||||
// past the closing bracket.
|
||||
let should_skip = selection.end == region.range.end.to_point(&snapshot)
|
||||
&& text.as_ref() == region.pair.end.as_str();
|
||||
&& text.as_ref() == region.pair.end.as_str()
|
||||
&& snapshot.contains_str_at(region.range.end, text.as_ref());
|
||||
if should_skip {
|
||||
let anchor = snapshot.anchor_after(selection.end);
|
||||
new_selections
|
||||
@@ -4973,13 +4976,17 @@ impl Editor {
|
||||
})
|
||||
}
|
||||
|
||||
/// Remove any autoclose regions that no longer contain their selection.
|
||||
/// Remove any autoclose regions that no longer contain their selection or have invalid anchors in ranges.
|
||||
fn invalidate_autoclose_regions(
|
||||
&mut self,
|
||||
mut selections: &[Selection<Anchor>],
|
||||
buffer: &MultiBufferSnapshot,
|
||||
) {
|
||||
self.autoclose_regions.retain(|state| {
|
||||
if !state.range.start.is_valid(buffer) || !state.range.end.is_valid(buffer) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut i = 0;
|
||||
while let Some(selection) = selections.get(i) {
|
||||
if selection.end.cmp(&state.range.start, buffer).is_lt() {
|
||||
@@ -5891,18 +5898,20 @@ impl Editor {
|
||||
text: new_text[common_prefix_len..].into(),
|
||||
});
|
||||
|
||||
self.transact(window, cx, |this, window, cx| {
|
||||
self.transact(window, cx, |editor, window, cx| {
|
||||
if let Some(mut snippet) = snippet {
|
||||
snippet.text = new_text.to_string();
|
||||
this.insert_snippet(&ranges, snippet, window, cx).log_err();
|
||||
editor
|
||||
.insert_snippet(&ranges, snippet, window, cx)
|
||||
.log_err();
|
||||
} else {
|
||||
this.buffer.update(cx, |buffer, cx| {
|
||||
editor.buffer.update(cx, |multi_buffer, cx| {
|
||||
let auto_indent = match completion.insert_text_mode {
|
||||
Some(InsertTextMode::AS_IS) => None,
|
||||
_ => this.autoindent_mode.clone(),
|
||||
_ => editor.autoindent_mode.clone(),
|
||||
};
|
||||
let edits = ranges.into_iter().map(|range| (range, new_text.as_str()));
|
||||
buffer.edit(edits, auto_indent, cx);
|
||||
multi_buffer.edit(edits, auto_indent, cx);
|
||||
});
|
||||
}
|
||||
for (buffer, edits) in linked_edits {
|
||||
@@ -5921,8 +5930,9 @@ impl Editor {
|
||||
})
|
||||
}
|
||||
|
||||
this.refresh_inline_completion(true, false, window, cx);
|
||||
editor.refresh_inline_completion(true, false, window, cx);
|
||||
});
|
||||
self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), &snapshot);
|
||||
|
||||
let show_new_completions_on_confirm = completion
|
||||
.confirm
|
||||
@@ -9562,27 +9572,46 @@ impl Editor {
|
||||
// Check whether the just-entered snippet ends with an auto-closable bracket.
|
||||
if self.autoclose_regions.is_empty() {
|
||||
let snapshot = self.buffer.read(cx).snapshot(cx);
|
||||
for selection in &mut self.selections.all::<Point>(cx) {
|
||||
let mut all_selections = self.selections.all::<Point>(cx);
|
||||
for selection in &mut all_selections {
|
||||
let selection_head = selection.head();
|
||||
let Some(scope) = snapshot.language_scope_at(selection_head) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let mut bracket_pair = None;
|
||||
let next_chars = snapshot.chars_at(selection_head).collect::<String>();
|
||||
let prev_chars = snapshot
|
||||
.reversed_chars_at(selection_head)
|
||||
.collect::<String>();
|
||||
for (pair, enabled) in scope.brackets() {
|
||||
if enabled
|
||||
&& pair.close
|
||||
&& prev_chars.starts_with(pair.start.as_str())
|
||||
&& next_chars.starts_with(pair.end.as_str())
|
||||
{
|
||||
bracket_pair = Some(pair.clone());
|
||||
break;
|
||||
let max_lookup_length = scope
|
||||
.brackets()
|
||||
.map(|(pair, _)| {
|
||||
pair.start
|
||||
.as_str()
|
||||
.chars()
|
||||
.count()
|
||||
.max(pair.end.as_str().chars().count())
|
||||
})
|
||||
.max();
|
||||
if let Some(max_lookup_length) = max_lookup_length {
|
||||
let next_text = snapshot
|
||||
.chars_at(selection_head)
|
||||
.take(max_lookup_length)
|
||||
.collect::<String>();
|
||||
let prev_text = snapshot
|
||||
.reversed_chars_at(selection_head)
|
||||
.take(max_lookup_length)
|
||||
.collect::<String>();
|
||||
|
||||
for (pair, enabled) in scope.brackets() {
|
||||
if enabled
|
||||
&& pair.close
|
||||
&& prev_text.starts_with(pair.start.as_str())
|
||||
&& next_text.starts_with(pair.end.as_str())
|
||||
{
|
||||
bracket_pair = Some(pair.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pair) = bracket_pair {
|
||||
let snapshot_settings = snapshot.language_settings_at(selection_head, cx);
|
||||
let autoclose_enabled =
|
||||
|
||||
@@ -8612,6 +8612,7 @@ async fn test_autoclose_with_embedded_language(cx: &mut TestAppContext) {
|
||||
|
||||
cx.language_registry().add(html_language.clone());
|
||||
cx.language_registry().add(javascript_language.clone());
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
cx.update_buffer(|buffer, cx| {
|
||||
buffer.set_language(Some(html_language), cx);
|
||||
@@ -13400,6 +13401,178 @@ async fn test_as_is_completions(cx: &mut TestAppContext) {
|
||||
cx.assert_editor_state("fn a() {}\n unsafeˇ");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_panic_during_c_completions(cx: &mut TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
let language =
|
||||
Arc::try_unwrap(languages::language("c", tree_sitter_c::LANGUAGE.into())).unwrap();
|
||||
let mut cx = EditorLspTestContext::new(
|
||||
language,
|
||||
lsp::ServerCapabilities {
|
||||
completion_provider: Some(lsp::CompletionOptions {
|
||||
..lsp::CompletionOptions::default()
|
||||
}),
|
||||
..lsp::ServerCapabilities::default()
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
cx.set_state(
|
||||
"#ifndef BAR_H
|
||||
#define BAR_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
int fn_branch(bool do_branch1, bool do_branch2);
|
||||
|
||||
#endif // BAR_H
|
||||
ˇ",
|
||||
);
|
||||
cx.executor().run_until_parked();
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("#", window, cx);
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("i", window, cx);
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("n", window, cx);
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.assert_editor_state(
|
||||
"#ifndef BAR_H
|
||||
#define BAR_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
int fn_branch(bool do_branch1, bool do_branch2);
|
||||
|
||||
#endif // BAR_H
|
||||
#inˇ",
|
||||
);
|
||||
|
||||
cx.lsp
|
||||
.set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move {
|
||||
Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList {
|
||||
is_incomplete: false,
|
||||
item_defaults: None,
|
||||
items: vec![lsp::CompletionItem {
|
||||
kind: Some(lsp::CompletionItemKind::SNIPPET),
|
||||
label_details: Some(lsp::CompletionItemLabelDetails {
|
||||
detail: Some("header".to_string()),
|
||||
description: None,
|
||||
}),
|
||||
label: " include".to_string(),
|
||||
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
|
||||
range: lsp::Range {
|
||||
start: lsp::Position {
|
||||
line: 8,
|
||||
character: 1,
|
||||
},
|
||||
end: lsp::Position {
|
||||
line: 8,
|
||||
character: 1,
|
||||
},
|
||||
},
|
||||
new_text: "include \"$0\"".to_string(),
|
||||
})),
|
||||
sort_text: Some("40b67681include".to_string()),
|
||||
insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
|
||||
filter_text: Some("include".to_string()),
|
||||
insert_text: Some("include \"$0\"".to_string()),
|
||||
..lsp::CompletionItem::default()
|
||||
}],
|
||||
})))
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.show_completions(&ShowCompletions { trigger: None }, window, cx);
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.confirm_completion(&ConfirmCompletion::default(), window, cx)
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.assert_editor_state(
|
||||
"#ifndef BAR_H
|
||||
#define BAR_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
int fn_branch(bool do_branch1, bool do_branch2);
|
||||
|
||||
#endif // BAR_H
|
||||
#include \"ˇ\"",
|
||||
);
|
||||
|
||||
cx.lsp
|
||||
.set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move {
|
||||
Ok(Some(lsp::CompletionResponse::List(lsp::CompletionList {
|
||||
is_incomplete: true,
|
||||
item_defaults: None,
|
||||
items: vec![lsp::CompletionItem {
|
||||
kind: Some(lsp::CompletionItemKind::FILE),
|
||||
label: "AGL/".to_string(),
|
||||
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
|
||||
range: lsp::Range {
|
||||
start: lsp::Position {
|
||||
line: 8,
|
||||
character: 10,
|
||||
},
|
||||
end: lsp::Position {
|
||||
line: 8,
|
||||
character: 11,
|
||||
},
|
||||
},
|
||||
new_text: "AGL/".to_string(),
|
||||
})),
|
||||
sort_text: Some("40b67681AGL/".to_string()),
|
||||
insert_text_format: Some(lsp::InsertTextFormat::PLAIN_TEXT),
|
||||
filter_text: Some("AGL/".to_string()),
|
||||
insert_text: Some("AGL/".to_string()),
|
||||
..lsp::CompletionItem::default()
|
||||
}],
|
||||
})))
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.show_completions(&ShowCompletions { trigger: None }, window, cx);
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.confirm_completion(&ConfirmCompletion::default(), window, cx)
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.assert_editor_state(
|
||||
r##"#ifndef BAR_H
|
||||
#define BAR_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
int fn_branch(bool do_branch1, bool do_branch2);
|
||||
|
||||
#endif // BAR_H
|
||||
#include "AGL/ˇ"##,
|
||||
);
|
||||
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("\"", window, cx);
|
||||
});
|
||||
cx.executor().run_until_parked();
|
||||
cx.assert_editor_state(
|
||||
r##"#ifndef BAR_H
|
||||
#define BAR_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
int fn_branch(bool do_branch1, bool do_branch2);
|
||||
|
||||
#endif // BAR_H
|
||||
#include "AGL/"ˇ"##,
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_duplicated_completion_requests(cx: &mut TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
@@ -18,7 +18,7 @@ use collections::{HashMap, HashSet};
|
||||
use extension::ExtensionHostProxy;
|
||||
use futures::future;
|
||||
use gpui::http_client::read_proxy_from_env;
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, UpdateGlobal};
|
||||
use gpui_tokio::Tokio;
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry, SelectedModel};
|
||||
@@ -337,7 +337,8 @@ pub struct AgentAppState {
|
||||
}
|
||||
|
||||
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
let app_version = AppVersion::global(cx);
|
||||
release_channel::init(app_version, cx);
|
||||
gpui_tokio::init(cx);
|
||||
|
||||
let mut settings_store = SettingsStore::new(cx);
|
||||
@@ -350,7 +351,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
// Set User-Agent so we can download language servers from GitHub
|
||||
let user_agent = format!(
|
||||
"Zed/{} ({}; {})",
|
||||
AppVersion::global(cx),
|
||||
app_version,
|
||||
std::env::consts::OS,
|
||||
std::env::consts::ARCH
|
||||
);
|
||||
|
||||
@@ -2416,7 +2416,7 @@ impl GitPanel {
|
||||
.committer_name
|
||||
.clone()
|
||||
.or_else(|| participant.user.name.clone())
|
||||
.unwrap_or_else(|| participant.user.github_login.clone());
|
||||
.unwrap_or_else(|| participant.user.github_login.clone().to_string());
|
||||
new_co_authors.push((name.clone(), email.clone()))
|
||||
}
|
||||
}
|
||||
@@ -2436,7 +2436,7 @@ impl GitPanel {
|
||||
.name
|
||||
.clone()
|
||||
.or_else(|| user.name.clone())
|
||||
.unwrap_or_else(|| user.github_login.clone());
|
||||
.unwrap_or_else(|| user.github_login.clone().to_string());
|
||||
Some((name, email))
|
||||
}
|
||||
|
||||
|
||||
@@ -310,6 +310,18 @@ mod windows {
|
||||
&rust_binding_path,
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
|
||||
.join("src/platform/windows/color_text_raster.hlsl");
|
||||
compile_shader_for_module(
|
||||
"emoji_rasterization",
|
||||
&out_dir,
|
||||
&fxc_path,
|
||||
shader_path.to_str().unwrap(),
|
||||
&rust_binding_path,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler.
|
||||
|
||||
@@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid {
|
||||
"χ", "ψ", "∂", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р",
|
||||
"У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*",
|
||||
"_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "¶", "µ",
|
||||
"❮", "<=", "!=", "==", "--", "++", "=>", "->",
|
||||
"❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎",
|
||||
];
|
||||
|
||||
let columns = 11;
|
||||
|
||||
@@ -35,6 +35,7 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) {
|
||||
|
||||
/// An RGBA color
|
||||
#[derive(PartialEq, Clone, Copy, Default)]
|
||||
#[repr(C)]
|
||||
pub struct Rgba {
|
||||
/// The red component of the color, in the range 0.0 to 1.0
|
||||
pub r: f32,
|
||||
|
||||
39
crates/gpui/src/platform/windows/color_text_raster.hlsl
Normal file
@@ -0,0 +1,39 @@
|
||||
struct RasterVertexOutput {
|
||||
float4 position : SV_Position;
|
||||
float2 texcoord : TEXCOORD0;
|
||||
};
|
||||
|
||||
RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID)
|
||||
{
|
||||
RasterVertexOutput output;
|
||||
output.texcoord = float2((vertexID << 1) & 2, vertexID & 2);
|
||||
output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f);
|
||||
output.position.y = -output.position.y;
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
struct PixelInput {
|
||||
float4 position: SV_Position;
|
||||
float2 texcoord : TEXCOORD0;
|
||||
};
|
||||
|
||||
struct Bounds {
|
||||
int2 origin;
|
||||
int2 size;
|
||||
};
|
||||
|
||||
Texture2D<float4> t_layer : register(t0);
|
||||
SamplerState s_layer : register(s0);
|
||||
|
||||
cbuffer GlyphLayerTextureParams : register(b0) {
|
||||
Bounds bounds;
|
||||
float4 run_color;
|
||||
};
|
||||
|
||||
float4 emoji_rasterization_fragment(PixelInput input): SV_Target {
|
||||
float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb;
|
||||
float alpha = (sampled.r + sampled.g + sampled.b) / 3;
|
||||
|
||||
return float4(run_color.rgb, alpha);
|
||||
}
|
||||
@@ -10,10 +10,11 @@ use windows::{
|
||||
Foundation::*,
|
||||
Globalization::GetUserDefaultLocaleName,
|
||||
Graphics::{
|
||||
Direct2D::{Common::*, *},
|
||||
Direct3D::D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP,
|
||||
Direct3D11::*,
|
||||
DirectWrite::*,
|
||||
Dxgi::Common::*,
|
||||
Gdi::LOGFONTW,
|
||||
Gdi::{IsRectEmpty, LOGFONTW},
|
||||
Imaging::*,
|
||||
},
|
||||
System::SystemServices::LOCALE_NAME_MAX_LENGTH,
|
||||
@@ -40,16 +41,21 @@ struct DirectWriteComponent {
|
||||
locale: String,
|
||||
factory: IDWriteFactory5,
|
||||
bitmap_factory: AgileReference<IWICImagingFactory>,
|
||||
d2d1_factory: ID2D1Factory,
|
||||
in_memory_loader: IDWriteInMemoryFontFileLoader,
|
||||
builder: IDWriteFontSetBuilder1,
|
||||
text_renderer: Arc<TextRendererWrapper>,
|
||||
render_context: GlyphRenderContext,
|
||||
|
||||
render_params: IDWriteRenderingParams3,
|
||||
gpu_state: GPUState,
|
||||
}
|
||||
|
||||
struct GlyphRenderContext {
|
||||
params: IDWriteRenderingParams3,
|
||||
dc_target: ID2D1DeviceContext4,
|
||||
struct GPUState {
|
||||
device: ID3D11Device,
|
||||
device_context: ID3D11DeviceContext,
|
||||
sampler: [Option<ID3D11SamplerState>; 1],
|
||||
blend_state: ID3D11BlendState,
|
||||
vertex_shader: ID3D11VertexShader,
|
||||
pixel_shader: ID3D11PixelShader,
|
||||
}
|
||||
|
||||
struct DirectWriteState {
|
||||
@@ -70,12 +76,11 @@ struct FontIdentifier {
|
||||
}
|
||||
|
||||
impl DirectWriteComponent {
|
||||
pub fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> {
|
||||
pub fn new(bitmap_factory: &IWICImagingFactory, gpu_context: &DirectXDevices) -> Result<Self> {
|
||||
// todo: ideally this would not be a large unsafe block but smaller isolated ones for easier auditing
|
||||
unsafe {
|
||||
let factory: IDWriteFactory5 = DWriteCreateFactory(DWRITE_FACTORY_TYPE_SHARED)?;
|
||||
let bitmap_factory = AgileReference::new(bitmap_factory)?;
|
||||
let d2d1_factory: ID2D1Factory =
|
||||
D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED, None)?;
|
||||
// The `IDWriteInMemoryFontFileLoader` here is supported starting from
|
||||
// Windows 10 Creators Update, which consequently requires the entire
|
||||
// `DirectWriteTextSystem` to run on `win10 1703`+.
|
||||
@@ -86,60 +91,132 @@ impl DirectWriteComponent {
|
||||
GetUserDefaultLocaleName(&mut locale_vec);
|
||||
let locale = String::from_utf16_lossy(&locale_vec);
|
||||
let text_renderer = Arc::new(TextRendererWrapper::new(&locale));
|
||||
let render_context = GlyphRenderContext::new(&factory, &d2d1_factory)?;
|
||||
|
||||
let render_params = {
|
||||
let default_params: IDWriteRenderingParams3 =
|
||||
factory.CreateRenderingParams()?.cast()?;
|
||||
let gamma = default_params.GetGamma();
|
||||
let enhanced_contrast = default_params.GetEnhancedContrast();
|
||||
let gray_contrast = default_params.GetGrayscaleEnhancedContrast();
|
||||
let cleartype_level = default_params.GetClearTypeLevel();
|
||||
let grid_fit_mode = default_params.GetGridFitMode();
|
||||
|
||||
factory.CreateCustomRenderingParams(
|
||||
gamma,
|
||||
enhanced_contrast,
|
||||
gray_contrast,
|
||||
cleartype_level,
|
||||
DWRITE_PIXEL_GEOMETRY_RGB,
|
||||
DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC,
|
||||
grid_fit_mode,
|
||||
)?
|
||||
};
|
||||
|
||||
let gpu_state = GPUState::new(gpu_context)?;
|
||||
|
||||
Ok(DirectWriteComponent {
|
||||
locale,
|
||||
factory,
|
||||
bitmap_factory,
|
||||
d2d1_factory,
|
||||
in_memory_loader,
|
||||
builder,
|
||||
text_renderer,
|
||||
render_context,
|
||||
render_params,
|
||||
gpu_state,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GlyphRenderContext {
|
||||
pub fn new(factory: &IDWriteFactory5, d2d1_factory: &ID2D1Factory) -> Result<Self> {
|
||||
unsafe {
|
||||
let default_params: IDWriteRenderingParams3 =
|
||||
factory.CreateRenderingParams()?.cast()?;
|
||||
let gamma = default_params.GetGamma();
|
||||
let enhanced_contrast = default_params.GetEnhancedContrast();
|
||||
let gray_contrast = default_params.GetGrayscaleEnhancedContrast();
|
||||
let cleartype_level = default_params.GetClearTypeLevel();
|
||||
let grid_fit_mode = default_params.GetGridFitMode();
|
||||
impl GPUState {
|
||||
fn new(gpu_context: &DirectXDevices) -> Result<Self> {
|
||||
let device = gpu_context.device.clone();
|
||||
let device_context = gpu_context.device_context.clone();
|
||||
|
||||
let params = factory.CreateCustomRenderingParams(
|
||||
gamma,
|
||||
enhanced_contrast,
|
||||
gray_contrast,
|
||||
cleartype_level,
|
||||
DWRITE_PIXEL_GEOMETRY_RGB,
|
||||
DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC,
|
||||
grid_fit_mode,
|
||||
)?;
|
||||
let dc_target = {
|
||||
let target = d2d1_factory.CreateDCRenderTarget(&get_render_target_property(
|
||||
DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
D2D1_ALPHA_MODE_PREMULTIPLIED,
|
||||
))?;
|
||||
let target = target.cast::<ID2D1DeviceContext4>()?;
|
||||
target.SetTextRenderingParams(¶ms);
|
||||
target
|
||||
let blend_state = {
|
||||
let mut blend_state = None;
|
||||
let desc = D3D11_BLEND_DESC {
|
||||
AlphaToCoverageEnable: false.into(),
|
||||
IndependentBlendEnable: false.into(),
|
||||
RenderTarget: [
|
||||
D3D11_RENDER_TARGET_BLEND_DESC {
|
||||
BlendEnable: true.into(),
|
||||
SrcBlend: D3D11_BLEND_SRC_ALPHA,
|
||||
DestBlend: D3D11_BLEND_INV_SRC_ALPHA,
|
||||
BlendOp: D3D11_BLEND_OP_ADD,
|
||||
SrcBlendAlpha: D3D11_BLEND_SRC_ALPHA,
|
||||
DestBlendAlpha: D3D11_BLEND_INV_SRC_ALPHA,
|
||||
BlendOpAlpha: D3D11_BLEND_OP_ADD,
|
||||
RenderTargetWriteMask: D3D11_COLOR_WRITE_ENABLE_ALL.0 as u8,
|
||||
},
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
],
|
||||
};
|
||||
unsafe { device.CreateBlendState(&desc, Some(&mut blend_state)) }?;
|
||||
blend_state.unwrap()
|
||||
};
|
||||
|
||||
Ok(Self { params, dc_target })
|
||||
}
|
||||
let sampler = {
|
||||
let mut sampler = None;
|
||||
let desc = D3D11_SAMPLER_DESC {
|
||||
Filter: D3D11_FILTER_MIN_MAG_MIP_POINT,
|
||||
AddressU: D3D11_TEXTURE_ADDRESS_BORDER,
|
||||
AddressV: D3D11_TEXTURE_ADDRESS_BORDER,
|
||||
AddressW: D3D11_TEXTURE_ADDRESS_BORDER,
|
||||
MipLODBias: 0.0,
|
||||
MaxAnisotropy: 1,
|
||||
ComparisonFunc: D3D11_COMPARISON_ALWAYS,
|
||||
BorderColor: [0.0, 0.0, 0.0, 0.0],
|
||||
MinLOD: 0.0,
|
||||
MaxLOD: 0.0,
|
||||
};
|
||||
unsafe { device.CreateSamplerState(&desc, Some(&mut sampler)) }?;
|
||||
[sampler]
|
||||
};
|
||||
|
||||
let vertex_shader = {
|
||||
let source = shader_resources::RawShaderBytes::new(
|
||||
shader_resources::ShaderModule::EmojiRasterization,
|
||||
shader_resources::ShaderTarget::Vertex,
|
||||
)?;
|
||||
let mut shader = None;
|
||||
unsafe { device.CreateVertexShader(source.as_bytes(), None, Some(&mut shader)) }?;
|
||||
shader.unwrap()
|
||||
};
|
||||
|
||||
let pixel_shader = {
|
||||
let source = shader_resources::RawShaderBytes::new(
|
||||
shader_resources::ShaderModule::EmojiRasterization,
|
||||
shader_resources::ShaderTarget::Fragment,
|
||||
)?;
|
||||
let mut shader = None;
|
||||
unsafe { device.CreatePixelShader(source.as_bytes(), None, Some(&mut shader)) }?;
|
||||
shader.unwrap()
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
device,
|
||||
device_context,
|
||||
sampler,
|
||||
blend_state,
|
||||
vertex_shader,
|
||||
pixel_shader,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl DirectWriteTextSystem {
|
||||
pub(crate) fn new(bitmap_factory: &IWICImagingFactory) -> Result<Self> {
|
||||
let components = DirectWriteComponent::new(bitmap_factory)?;
|
||||
pub(crate) fn new(
|
||||
gpu_context: &DirectXDevices,
|
||||
bitmap_factory: &IWICImagingFactory,
|
||||
) -> Result<Self> {
|
||||
let components = DirectWriteComponent::new(bitmap_factory, gpu_context)?;
|
||||
let system_font_collection = unsafe {
|
||||
let mut result = std::mem::zeroed();
|
||||
components
|
||||
@@ -648,15 +725,13 @@ impl DirectWriteState {
|
||||
}
|
||||
}
|
||||
|
||||
fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> {
|
||||
let render_target = &self.components.render_context.dc_target;
|
||||
unsafe {
|
||||
render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS);
|
||||
render_target.SetDpi(96.0 * params.scale_factor, 96.0 * params.scale_factor);
|
||||
}
|
||||
fn create_glyph_run_analysis(
|
||||
&self,
|
||||
params: &RenderGlyphParams,
|
||||
) -> Result<IDWriteGlyphRunAnalysis> {
|
||||
let font = &self.fonts[params.font_id.0];
|
||||
let glyph_id = [params.glyph_id.0 as u16];
|
||||
let advance = [0.0f32];
|
||||
let advance = [0.0];
|
||||
let offset = [DWRITE_GLYPH_OFFSET::default()];
|
||||
let glyph_run = DWRITE_GLYPH_RUN {
|
||||
fontFace: unsafe { std::mem::transmute_copy(&font.font_face) },
|
||||
@@ -668,44 +743,87 @@ impl DirectWriteState {
|
||||
isSideways: BOOL(0),
|
||||
bidiLevel: 0,
|
||||
};
|
||||
let bounds = unsafe {
|
||||
render_target.GetGlyphRunWorldBounds(
|
||||
Vector2 { X: 0.0, Y: 0.0 },
|
||||
&glyph_run,
|
||||
DWRITE_MEASURING_MODE_NATURAL,
|
||||
)?
|
||||
let transform = DWRITE_MATRIX {
|
||||
m11: params.scale_factor,
|
||||
m12: 0.0,
|
||||
m21: 0.0,
|
||||
m22: params.scale_factor,
|
||||
dx: 0.0,
|
||||
dy: 0.0,
|
||||
};
|
||||
// todo(windows)
|
||||
// This is a walkaround, deleted when figured out.
|
||||
let y_offset;
|
||||
let extra_height;
|
||||
if params.is_emoji {
|
||||
y_offset = 0;
|
||||
extra_height = 0;
|
||||
} else {
|
||||
// make some room for scaler.
|
||||
y_offset = -1;
|
||||
extra_height = 2;
|
||||
let subpixel_shift = params
|
||||
.subpixel_variant
|
||||
.map(|v| v as f32 / SUBPIXEL_VARIANTS as f32);
|
||||
let baseline_origin_x = subpixel_shift.x / params.scale_factor;
|
||||
let baseline_origin_y = subpixel_shift.y / params.scale_factor;
|
||||
|
||||
let mut rendering_mode = DWRITE_RENDERING_MODE1::default();
|
||||
let mut grid_fit_mode = DWRITE_GRID_FIT_MODE::default();
|
||||
unsafe {
|
||||
font.font_face.GetRecommendedRenderingMode(
|
||||
params.font_size.0,
|
||||
// The dpi here seems that it has the same effect with `Some(&transform)`
|
||||
1.0,
|
||||
1.0,
|
||||
Some(&transform),
|
||||
false,
|
||||
DWRITE_OUTLINE_THRESHOLD_ANTIALIASED,
|
||||
DWRITE_MEASURING_MODE_NATURAL,
|
||||
&self.components.render_params,
|
||||
&mut rendering_mode,
|
||||
&mut grid_fit_mode,
|
||||
)?;
|
||||
}
|
||||
|
||||
if bounds.right < bounds.left {
|
||||
let glyph_analysis = unsafe {
|
||||
self.components.factory.CreateGlyphRunAnalysis(
|
||||
&glyph_run,
|
||||
Some(&transform),
|
||||
rendering_mode,
|
||||
DWRITE_MEASURING_MODE_NATURAL,
|
||||
grid_fit_mode,
|
||||
// We're using cleartype not grayscale for monochrome is because it provides better quality
|
||||
DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE,
|
||||
baseline_origin_x,
|
||||
baseline_origin_y,
|
||||
)
|
||||
}?;
|
||||
Ok(glyph_analysis)
|
||||
}
|
||||
|
||||
fn raster_bounds(&self, params: &RenderGlyphParams) -> Result<Bounds<DevicePixels>> {
|
||||
let glyph_analysis = self.create_glyph_run_analysis(params)?;
|
||||
|
||||
let bounds = unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1)? };
|
||||
// Some glyphs cannot be drawn with ClearType, such as bitmap fonts. In that case
|
||||
// GetAlphaTextureBounds() supposedly returns an empty RECT, but I haven't tested that yet.
|
||||
if !unsafe { IsRectEmpty(&bounds) }.as_bool() {
|
||||
Ok(Bounds {
|
||||
origin: point(0.into(), 0.into()),
|
||||
size: size(0.into(), 0.into()),
|
||||
origin: point(bounds.left.into(), bounds.top.into()),
|
||||
size: size(
|
||||
(bounds.right - bounds.left).into(),
|
||||
(bounds.bottom - bounds.top).into(),
|
||||
),
|
||||
})
|
||||
} else {
|
||||
Ok(Bounds {
|
||||
origin: point(
|
||||
((bounds.left * params.scale_factor).ceil() as i32).into(),
|
||||
((bounds.top * params.scale_factor).ceil() as i32 + y_offset).into(),
|
||||
),
|
||||
size: size(
|
||||
(((bounds.right - bounds.left) * params.scale_factor).ceil() as i32).into(),
|
||||
(((bounds.bottom - bounds.top) * params.scale_factor).ceil() as i32
|
||||
+ extra_height)
|
||||
.into(),
|
||||
),
|
||||
})
|
||||
// If it's empty, retry with grayscale AA.
|
||||
let bounds =
|
||||
unsafe { glyph_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_ALIASED_1x1)? };
|
||||
|
||||
if bounds.right < bounds.left {
|
||||
Ok(Bounds {
|
||||
origin: point(0.into(), 0.into()),
|
||||
size: size(0.into(), 0.into()),
|
||||
})
|
||||
} else {
|
||||
Ok(Bounds {
|
||||
origin: point(bounds.left.into(), bounds.top.into()),
|
||||
size: size(
|
||||
(bounds.right - bounds.left).into(),
|
||||
(bounds.bottom - bounds.top).into(),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -731,7 +849,95 @@ impl DirectWriteState {
|
||||
anyhow::bail!("glyph bounds are empty");
|
||||
}
|
||||
|
||||
let font_info = &self.fonts[params.font_id.0];
|
||||
let bitmap_data = if params.is_emoji {
|
||||
if let Ok(color) = self.rasterize_color(¶ms, glyph_bounds) {
|
||||
color
|
||||
} else {
|
||||
let monochrome = self.rasterize_monochrome(params, glyph_bounds)?;
|
||||
monochrome
|
||||
.into_iter()
|
||||
.flat_map(|pixel| [0, 0, 0, pixel])
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
} else {
|
||||
self.rasterize_monochrome(params, glyph_bounds)?
|
||||
};
|
||||
|
||||
Ok((glyph_bounds.size, bitmap_data))
|
||||
}
|
||||
|
||||
fn rasterize_monochrome(
|
||||
&self,
|
||||
params: &RenderGlyphParams,
|
||||
glyph_bounds: Bounds<DevicePixels>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut bitmap_data =
|
||||
vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize * 3];
|
||||
|
||||
let glyph_analysis = self.create_glyph_run_analysis(params)?;
|
||||
unsafe {
|
||||
glyph_analysis.CreateAlphaTexture(
|
||||
// We're using cleartype not grayscale for monochrome is because it provides better quality
|
||||
DWRITE_TEXTURE_CLEARTYPE_3x1,
|
||||
&RECT {
|
||||
left: glyph_bounds.origin.x.0,
|
||||
top: glyph_bounds.origin.y.0,
|
||||
right: glyph_bounds.size.width.0 + glyph_bounds.origin.x.0,
|
||||
bottom: glyph_bounds.size.height.0 + glyph_bounds.origin.y.0,
|
||||
},
|
||||
&mut bitmap_data,
|
||||
)?;
|
||||
}
|
||||
|
||||
let bitmap_factory = self.components.bitmap_factory.resolve()?;
|
||||
let bitmap = unsafe {
|
||||
bitmap_factory.CreateBitmapFromMemory(
|
||||
glyph_bounds.size.width.0 as u32,
|
||||
glyph_bounds.size.height.0 as u32,
|
||||
&GUID_WICPixelFormat24bppRGB,
|
||||
glyph_bounds.size.width.0 as u32 * 3,
|
||||
&bitmap_data,
|
||||
)
|
||||
}?;
|
||||
|
||||
let grayscale_bitmap =
|
||||
unsafe { WICConvertBitmapSource(&GUID_WICPixelFormat8bppGray, &bitmap) }?;
|
||||
|
||||
let mut bitmap_data =
|
||||
vec![0u8; glyph_bounds.size.width.0 as usize * glyph_bounds.size.height.0 as usize];
|
||||
unsafe {
|
||||
grayscale_bitmap.CopyPixels(
|
||||
std::ptr::null() as _,
|
||||
glyph_bounds.size.width.0 as u32,
|
||||
&mut bitmap_data,
|
||||
)
|
||||
}?;
|
||||
|
||||
Ok(bitmap_data)
|
||||
}
|
||||
|
||||
fn rasterize_color(
|
||||
&self,
|
||||
params: &RenderGlyphParams,
|
||||
glyph_bounds: Bounds<DevicePixels>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let bitmap_size = glyph_bounds.size;
|
||||
let subpixel_shift = params
|
||||
.subpixel_variant
|
||||
.map(|v| v as f32 / SUBPIXEL_VARIANTS as f32);
|
||||
let baseline_origin_x = subpixel_shift.x / params.scale_factor;
|
||||
let baseline_origin_y = subpixel_shift.y / params.scale_factor;
|
||||
|
||||
let transform = DWRITE_MATRIX {
|
||||
m11: params.scale_factor,
|
||||
m12: 0.0,
|
||||
m21: 0.0,
|
||||
m22: params.scale_factor,
|
||||
dx: 0.0,
|
||||
dy: 0.0,
|
||||
};
|
||||
|
||||
let font = &self.fonts[params.font_id.0];
|
||||
let glyph_id = [params.glyph_id.0 as u16];
|
||||
let advance = [glyph_bounds.size.width.0 as f32];
|
||||
let offset = [DWRITE_GLYPH_OFFSET {
|
||||
@@ -739,7 +945,7 @@ impl DirectWriteState {
|
||||
ascenderOffset: glyph_bounds.origin.y.0 as f32 / params.scale_factor,
|
||||
}];
|
||||
let glyph_run = DWRITE_GLYPH_RUN {
|
||||
fontFace: unsafe { std::mem::transmute_copy(&font_info.font_face) },
|
||||
fontFace: unsafe { std::mem::transmute_copy(&font.font_face) },
|
||||
fontEmSize: params.font_size.0,
|
||||
glyphCount: 1,
|
||||
glyphIndices: glyph_id.as_ptr(),
|
||||
@@ -749,160 +955,254 @@ impl DirectWriteState {
|
||||
bidiLevel: 0,
|
||||
};
|
||||
|
||||
// Add an extra pixel when the subpixel variant isn't zero to make room for anti-aliasing.
|
||||
let mut bitmap_size = glyph_bounds.size;
|
||||
if params.subpixel_variant.x > 0 {
|
||||
bitmap_size.width += DevicePixels(1);
|
||||
}
|
||||
if params.subpixel_variant.y > 0 {
|
||||
bitmap_size.height += DevicePixels(1);
|
||||
}
|
||||
let bitmap_size = bitmap_size;
|
||||
// todo: support formats other than COLR
|
||||
let color_enumerator = unsafe {
|
||||
self.components.factory.TranslateColorGlyphRun(
|
||||
Vector2::new(baseline_origin_x, baseline_origin_y),
|
||||
&glyph_run,
|
||||
None,
|
||||
DWRITE_GLYPH_IMAGE_FORMATS_COLR,
|
||||
DWRITE_MEASURING_MODE_NATURAL,
|
||||
Some(&transform),
|
||||
0,
|
||||
)
|
||||
}?;
|
||||
|
||||
let total_bytes;
|
||||
let bitmap_format;
|
||||
let render_target_property;
|
||||
let bitmap_width;
|
||||
let bitmap_height;
|
||||
let bitmap_stride;
|
||||
let bitmap_dpi;
|
||||
if params.is_emoji {
|
||||
total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize * 4;
|
||||
bitmap_format = &GUID_WICPixelFormat32bppPBGRA;
|
||||
render_target_property = get_render_target_property(
|
||||
DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
D2D1_ALPHA_MODE_PREMULTIPLIED,
|
||||
);
|
||||
bitmap_width = bitmap_size.width.0 as u32;
|
||||
bitmap_height = bitmap_size.height.0 as u32;
|
||||
bitmap_stride = bitmap_size.width.0 as u32 * 4;
|
||||
bitmap_dpi = 96.0;
|
||||
} else {
|
||||
total_bytes = bitmap_size.height.0 as usize * bitmap_size.width.0 as usize;
|
||||
bitmap_format = &GUID_WICPixelFormat8bppAlpha;
|
||||
render_target_property =
|
||||
get_render_target_property(DXGI_FORMAT_A8_UNORM, D2D1_ALPHA_MODE_STRAIGHT);
|
||||
bitmap_width = bitmap_size.width.0 as u32 * 2;
|
||||
bitmap_height = bitmap_size.height.0 as u32 * 2;
|
||||
bitmap_stride = bitmap_size.width.0 as u32;
|
||||
bitmap_dpi = 192.0;
|
||||
let mut glyph_layers = Vec::new();
|
||||
loop {
|
||||
let color_run = unsafe { color_enumerator.GetCurrentRun() }?;
|
||||
let color_run = unsafe { &*color_run };
|
||||
let image_format = color_run.glyphImageFormat & !DWRITE_GLYPH_IMAGE_FORMATS_TRUETYPE;
|
||||
if image_format == DWRITE_GLYPH_IMAGE_FORMATS_COLR {
|
||||
let color_analysis = unsafe {
|
||||
self.components.factory.CreateGlyphRunAnalysis(
|
||||
&color_run.Base.glyphRun as *const _,
|
||||
Some(&transform),
|
||||
DWRITE_RENDERING_MODE1_NATURAL_SYMMETRIC,
|
||||
DWRITE_MEASURING_MODE_NATURAL,
|
||||
DWRITE_GRID_FIT_MODE_DEFAULT,
|
||||
DWRITE_TEXT_ANTIALIAS_MODE_CLEARTYPE,
|
||||
baseline_origin_x,
|
||||
baseline_origin_y,
|
||||
)
|
||||
}?;
|
||||
|
||||
let color_bounds =
|
||||
unsafe { color_analysis.GetAlphaTextureBounds(DWRITE_TEXTURE_CLEARTYPE_3x1) }?;
|
||||
|
||||
let color_size = size(
|
||||
color_bounds.right - color_bounds.left,
|
||||
color_bounds.bottom - color_bounds.top,
|
||||
);
|
||||
if color_size.width > 0 && color_size.height > 0 {
|
||||
let mut alpha_data =
|
||||
vec![0u8; (color_size.width * color_size.height * 3) as usize];
|
||||
unsafe {
|
||||
color_analysis.CreateAlphaTexture(
|
||||
DWRITE_TEXTURE_CLEARTYPE_3x1,
|
||||
&color_bounds,
|
||||
&mut alpha_data,
|
||||
)
|
||||
}?;
|
||||
|
||||
let run_color = {
|
||||
let run_color = color_run.Base.runColor;
|
||||
Rgba {
|
||||
r: run_color.r,
|
||||
g: run_color.g,
|
||||
b: run_color.b,
|
||||
a: run_color.a,
|
||||
}
|
||||
};
|
||||
let bounds = bounds(point(color_bounds.left, color_bounds.top), color_size);
|
||||
let alpha_data = alpha_data
|
||||
.chunks_exact(3)
|
||||
.flat_map(|chunk| [chunk[0], chunk[1], chunk[2], 255])
|
||||
.collect::<Vec<_>>();
|
||||
glyph_layers.push(GlyphLayerTexture::new(
|
||||
&self.components.gpu_state,
|
||||
run_color,
|
||||
bounds,
|
||||
&alpha_data,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
||||
let has_next = unsafe { color_enumerator.MoveNext() }
|
||||
.map(|e| e.as_bool())
|
||||
.unwrap_or(false);
|
||||
if !has_next {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let bitmap_factory = self.components.bitmap_factory.resolve()?;
|
||||
unsafe {
|
||||
let bitmap = bitmap_factory.CreateBitmap(
|
||||
bitmap_width,
|
||||
bitmap_height,
|
||||
bitmap_format,
|
||||
WICBitmapCacheOnLoad,
|
||||
)?;
|
||||
let render_target = self
|
||||
.components
|
||||
.d2d1_factory
|
||||
.CreateWicBitmapRenderTarget(&bitmap, &render_target_property)?;
|
||||
let brush = render_target.CreateSolidColorBrush(&BRUSH_COLOR, None)?;
|
||||
let subpixel_shift = params
|
||||
.subpixel_variant
|
||||
.map(|v| v as f32 / SUBPIXEL_VARIANTS as f32);
|
||||
let baseline_origin = Vector2 {
|
||||
X: subpixel_shift.x / params.scale_factor,
|
||||
Y: subpixel_shift.y / params.scale_factor,
|
||||
let gpu_state = &self.components.gpu_state;
|
||||
let params_buffer = {
|
||||
let desc = D3D11_BUFFER_DESC {
|
||||
ByteWidth: std::mem::size_of::<GlyphLayerTextureParams>() as u32,
|
||||
Usage: D3D11_USAGE_DYNAMIC,
|
||||
BindFlags: D3D11_BIND_CONSTANT_BUFFER.0 as u32,
|
||||
CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32,
|
||||
MiscFlags: 0,
|
||||
StructureByteStride: 0,
|
||||
};
|
||||
|
||||
// This `cast()` action here should never fail since we are running on Win10+, and
|
||||
// ID2D1DeviceContext4 requires Win8+
|
||||
let render_target = render_target.cast::<ID2D1DeviceContext4>().unwrap();
|
||||
render_target.SetUnitMode(D2D1_UNIT_MODE_DIPS);
|
||||
render_target.SetDpi(
|
||||
bitmap_dpi * params.scale_factor,
|
||||
bitmap_dpi * params.scale_factor,
|
||||
);
|
||||
render_target.SetTextRenderingParams(&self.components.render_context.params);
|
||||
render_target.BeginDraw();
|
||||
let mut buffer = None;
|
||||
unsafe {
|
||||
gpu_state
|
||||
.device
|
||||
.CreateBuffer(&desc, None, Some(&mut buffer))
|
||||
}?;
|
||||
[buffer]
|
||||
};
|
||||
|
||||
if params.is_emoji {
|
||||
// WARN: only DWRITE_GLYPH_IMAGE_FORMATS_COLR has been tested
|
||||
let enumerator = self.components.factory.TranslateColorGlyphRun(
|
||||
baseline_origin,
|
||||
&glyph_run as _,
|
||||
None,
|
||||
DWRITE_GLYPH_IMAGE_FORMATS_COLR
|
||||
| DWRITE_GLYPH_IMAGE_FORMATS_SVG
|
||||
| DWRITE_GLYPH_IMAGE_FORMATS_PNG
|
||||
| DWRITE_GLYPH_IMAGE_FORMATS_JPEG
|
||||
| DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8,
|
||||
DWRITE_MEASURING_MODE_NATURAL,
|
||||
None,
|
||||
let render_target_texture = {
|
||||
let mut texture = None;
|
||||
let desc = D3D11_TEXTURE2D_DESC {
|
||||
Width: bitmap_size.width.0 as u32,
|
||||
Height: bitmap_size.height.0 as u32,
|
||||
MipLevels: 1,
|
||||
ArraySize: 1,
|
||||
Format: DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
SampleDesc: DXGI_SAMPLE_DESC {
|
||||
Count: 1,
|
||||
Quality: 0,
|
||||
},
|
||||
Usage: D3D11_USAGE_DEFAULT,
|
||||
BindFlags: D3D11_BIND_RENDER_TARGET.0 as u32,
|
||||
CPUAccessFlags: 0,
|
||||
MiscFlags: 0,
|
||||
};
|
||||
unsafe {
|
||||
gpu_state
|
||||
.device
|
||||
.CreateTexture2D(&desc, None, Some(&mut texture))
|
||||
}?;
|
||||
texture.unwrap()
|
||||
};
|
||||
|
||||
let render_target_view = {
|
||||
let desc = D3D11_RENDER_TARGET_VIEW_DESC {
|
||||
Format: DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
ViewDimension: D3D11_RTV_DIMENSION_TEXTURE2D,
|
||||
Anonymous: D3D11_RENDER_TARGET_VIEW_DESC_0 {
|
||||
Texture2D: D3D11_TEX2D_RTV { MipSlice: 0 },
|
||||
},
|
||||
};
|
||||
let mut rtv = None;
|
||||
unsafe {
|
||||
gpu_state.device.CreateRenderTargetView(
|
||||
&render_target_texture,
|
||||
Some(&desc),
|
||||
Some(&mut rtv),
|
||||
)
|
||||
}?;
|
||||
[rtv]
|
||||
};
|
||||
|
||||
let staging_texture = {
|
||||
let mut texture = None;
|
||||
let desc = D3D11_TEXTURE2D_DESC {
|
||||
Width: bitmap_size.width.0 as u32,
|
||||
Height: bitmap_size.height.0 as u32,
|
||||
MipLevels: 1,
|
||||
ArraySize: 1,
|
||||
Format: DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
SampleDesc: DXGI_SAMPLE_DESC {
|
||||
Count: 1,
|
||||
Quality: 0,
|
||||
},
|
||||
Usage: D3D11_USAGE_STAGING,
|
||||
BindFlags: 0,
|
||||
CPUAccessFlags: D3D11_CPU_ACCESS_READ.0 as u32,
|
||||
MiscFlags: 0,
|
||||
};
|
||||
unsafe {
|
||||
gpu_state
|
||||
.device
|
||||
.CreateTexture2D(&desc, None, Some(&mut texture))
|
||||
}?;
|
||||
texture.unwrap()
|
||||
};
|
||||
|
||||
let device_context = &gpu_state.device_context;
|
||||
unsafe { device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP) };
|
||||
unsafe { device_context.VSSetShader(&gpu_state.vertex_shader, None) };
|
||||
unsafe { device_context.PSSetShader(&gpu_state.pixel_shader, None) };
|
||||
unsafe { device_context.VSSetConstantBuffers(0, Some(¶ms_buffer)) };
|
||||
unsafe { device_context.PSSetConstantBuffers(0, Some(¶ms_buffer)) };
|
||||
unsafe { device_context.OMSetRenderTargets(Some(&render_target_view), None) };
|
||||
unsafe { device_context.PSSetSamplers(0, Some(&gpu_state.sampler)) };
|
||||
unsafe { device_context.OMSetBlendState(&gpu_state.blend_state, None, 0xffffffff) };
|
||||
|
||||
for layer in glyph_layers {
|
||||
let params = GlyphLayerTextureParams {
|
||||
run_color: layer.run_color,
|
||||
bounds: layer.bounds,
|
||||
};
|
||||
unsafe {
|
||||
let mut dest = std::mem::zeroed();
|
||||
gpu_state.device_context.Map(
|
||||
params_buffer[0].as_ref().unwrap(),
|
||||
0,
|
||||
D3D11_MAP_WRITE_DISCARD,
|
||||
0,
|
||||
Some(&mut dest),
|
||||
)?;
|
||||
while enumerator.MoveNext().is_ok() {
|
||||
let Ok(color_glyph) = enumerator.GetCurrentRun() else {
|
||||
break;
|
||||
};
|
||||
let color_glyph = &*color_glyph;
|
||||
let brush_color = translate_color(&color_glyph.Base.runColor);
|
||||
brush.SetColor(&brush_color);
|
||||
match color_glyph.glyphImageFormat {
|
||||
DWRITE_GLYPH_IMAGE_FORMATS_PNG
|
||||
| DWRITE_GLYPH_IMAGE_FORMATS_JPEG
|
||||
| DWRITE_GLYPH_IMAGE_FORMATS_PREMULTIPLIED_B8G8R8A8 => render_target
|
||||
.DrawColorBitmapGlyphRun(
|
||||
color_glyph.glyphImageFormat,
|
||||
baseline_origin,
|
||||
&color_glyph.Base.glyphRun,
|
||||
color_glyph.measuringMode,
|
||||
D2D1_COLOR_BITMAP_GLYPH_SNAP_OPTION_DEFAULT,
|
||||
),
|
||||
DWRITE_GLYPH_IMAGE_FORMATS_SVG => render_target.DrawSvgGlyphRun(
|
||||
baseline_origin,
|
||||
&color_glyph.Base.glyphRun,
|
||||
&brush,
|
||||
None,
|
||||
color_glyph.Base.paletteIndex as u32,
|
||||
color_glyph.measuringMode,
|
||||
),
|
||||
_ => render_target.DrawGlyphRun(
|
||||
baseline_origin,
|
||||
&color_glyph.Base.glyphRun,
|
||||
Some(color_glyph.Base.glyphRunDescription as *const _),
|
||||
&brush,
|
||||
color_glyph.measuringMode,
|
||||
),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
render_target.DrawGlyphRun(
|
||||
baseline_origin,
|
||||
&glyph_run,
|
||||
None,
|
||||
&brush,
|
||||
DWRITE_MEASURING_MODE_NATURAL,
|
||||
);
|
||||
}
|
||||
render_target.EndDraw(None, None)?;
|
||||
std::ptr::copy_nonoverlapping(¶ms as *const _, dest.pData as *mut _, 1);
|
||||
gpu_state
|
||||
.device_context
|
||||
.Unmap(params_buffer[0].as_ref().unwrap(), 0);
|
||||
};
|
||||
|
||||
let mut raw_data = vec![0u8; total_bytes];
|
||||
if params.is_emoji {
|
||||
bitmap.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?;
|
||||
// Convert from BGRA with premultiplied alpha to BGRA with straight alpha.
|
||||
for pixel in raw_data.chunks_exact_mut(4) {
|
||||
let a = pixel[3] as f32 / 255.;
|
||||
pixel[0] = (pixel[0] as f32 / a) as u8;
|
||||
pixel[1] = (pixel[1] as f32 / a) as u8;
|
||||
pixel[2] = (pixel[2] as f32 / a) as u8;
|
||||
}
|
||||
} else {
|
||||
let scaler = bitmap_factory.CreateBitmapScaler()?;
|
||||
scaler.Initialize(
|
||||
&bitmap,
|
||||
bitmap_size.width.0 as u32,
|
||||
bitmap_size.height.0 as u32,
|
||||
WICBitmapInterpolationModeHighQualityCubic,
|
||||
)?;
|
||||
scaler.CopyPixels(std::ptr::null() as _, bitmap_stride, &mut raw_data)?;
|
||||
}
|
||||
Ok((bitmap_size, raw_data))
|
||||
let texture = [Some(layer.texture_view)];
|
||||
unsafe { device_context.PSSetShaderResources(0, Some(&texture)) };
|
||||
|
||||
let viewport = [D3D11_VIEWPORT {
|
||||
TopLeftX: layer.bounds.origin.x as f32,
|
||||
TopLeftY: layer.bounds.origin.y as f32,
|
||||
Width: layer.bounds.size.width as f32,
|
||||
Height: layer.bounds.size.height as f32,
|
||||
MinDepth: 0.0,
|
||||
MaxDepth: 1.0,
|
||||
}];
|
||||
unsafe { device_context.RSSetViewports(Some(&viewport)) };
|
||||
|
||||
unsafe { device_context.Draw(4, 0) };
|
||||
}
|
||||
|
||||
unsafe { device_context.CopyResource(&staging_texture, &render_target_texture) };
|
||||
|
||||
let mapped_data = {
|
||||
let mut mapped_data = D3D11_MAPPED_SUBRESOURCE::default();
|
||||
unsafe {
|
||||
device_context.Map(
|
||||
&staging_texture,
|
||||
0,
|
||||
D3D11_MAP_READ,
|
||||
0,
|
||||
Some(&mut mapped_data),
|
||||
)
|
||||
}?;
|
||||
mapped_data
|
||||
};
|
||||
let mut rasterized =
|
||||
vec![0u8; (bitmap_size.width.0 as u32 * bitmap_size.height.0 as u32 * 4) as usize];
|
||||
|
||||
for y in 0..bitmap_size.height.0 as usize {
|
||||
let width = bitmap_size.width.0 as usize;
|
||||
unsafe {
|
||||
std::ptr::copy_nonoverlapping::<u8>(
|
||||
(mapped_data.pData as *const u8).byte_add(mapped_data.RowPitch as usize * y),
|
||||
rasterized
|
||||
.as_mut_ptr()
|
||||
.byte_add(width * y * std::mem::size_of::<u32>()),
|
||||
width * std::mem::size_of::<u32>(),
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
Ok(rasterized)
|
||||
}
|
||||
|
||||
fn get_typographic_bounds(&self, font_id: FontId, glyph_id: GlyphId) -> Result<Bounds<f32>> {
|
||||
@@ -976,6 +1276,84 @@ impl Drop for DirectWriteState {
|
||||
}
|
||||
}
|
||||
|
||||
struct GlyphLayerTexture {
|
||||
run_color: Rgba,
|
||||
bounds: Bounds<i32>,
|
||||
texture_view: ID3D11ShaderResourceView,
|
||||
// holding on to the texture to not RAII drop it
|
||||
_texture: ID3D11Texture2D,
|
||||
}
|
||||
|
||||
impl GlyphLayerTexture {
|
||||
pub fn new(
|
||||
gpu_state: &GPUState,
|
||||
run_color: Rgba,
|
||||
bounds: Bounds<i32>,
|
||||
alpha_data: &[u8],
|
||||
) -> Result<Self> {
|
||||
let texture_size = bounds.size;
|
||||
|
||||
let desc = D3D11_TEXTURE2D_DESC {
|
||||
Width: texture_size.width as u32,
|
||||
Height: texture_size.height as u32,
|
||||
MipLevels: 1,
|
||||
ArraySize: 1,
|
||||
Format: DXGI_FORMAT_R8G8B8A8_UNORM,
|
||||
SampleDesc: DXGI_SAMPLE_DESC {
|
||||
Count: 1,
|
||||
Quality: 0,
|
||||
},
|
||||
Usage: D3D11_USAGE_DEFAULT,
|
||||
BindFlags: D3D11_BIND_SHADER_RESOURCE.0 as u32,
|
||||
CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32,
|
||||
MiscFlags: 0,
|
||||
};
|
||||
|
||||
let texture = {
|
||||
let mut texture: Option<ID3D11Texture2D> = None;
|
||||
unsafe {
|
||||
gpu_state
|
||||
.device
|
||||
.CreateTexture2D(&desc, None, Some(&mut texture))?
|
||||
};
|
||||
texture.unwrap()
|
||||
};
|
||||
let texture_view = {
|
||||
let mut view: Option<ID3D11ShaderResourceView> = None;
|
||||
unsafe {
|
||||
gpu_state
|
||||
.device
|
||||
.CreateShaderResourceView(&texture, None, Some(&mut view))?
|
||||
};
|
||||
view.unwrap()
|
||||
};
|
||||
|
||||
unsafe {
|
||||
gpu_state.device_context.UpdateSubresource(
|
||||
&texture,
|
||||
0,
|
||||
None,
|
||||
alpha_data.as_ptr() as _,
|
||||
(texture_size.width * 4) as u32,
|
||||
0,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(GlyphLayerTexture {
|
||||
run_color,
|
||||
bounds,
|
||||
texture_view,
|
||||
_texture: texture,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
struct GlyphLayerTextureParams {
|
||||
bounds: Bounds<i32>,
|
||||
run_color: Rgba,
|
||||
}
|
||||
|
||||
struct TextRendererWrapper(pub IDWriteTextRenderer);
|
||||
|
||||
impl TextRendererWrapper {
|
||||
@@ -1470,16 +1848,6 @@ fn get_name(string: IDWriteLocalizedStrings, locale: &str) -> Result<String> {
|
||||
Ok(String::from_utf16_lossy(&name_vec[..name_length]))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn translate_color(color: &DWRITE_COLOR_F) -> D2D1_COLOR_F {
|
||||
D2D1_COLOR_F {
|
||||
r: color.r,
|
||||
g: color.g,
|
||||
b: color.b,
|
||||
a: color.a,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_system_ui_font_name() -> SharedString {
|
||||
unsafe {
|
||||
let mut info: LOGFONTW = std::mem::zeroed();
|
||||
@@ -1504,24 +1872,6 @@ fn get_system_ui_font_name() -> SharedString {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_render_target_property(
|
||||
pixel_format: DXGI_FORMAT,
|
||||
alpha_mode: D2D1_ALPHA_MODE,
|
||||
) -> D2D1_RENDER_TARGET_PROPERTIES {
|
||||
D2D1_RENDER_TARGET_PROPERTIES {
|
||||
r#type: D2D1_RENDER_TARGET_TYPE_DEFAULT,
|
||||
pixelFormat: D2D1_PIXEL_FORMAT {
|
||||
format: pixel_format,
|
||||
alphaMode: alpha_mode,
|
||||
},
|
||||
dpiX: 96.0,
|
||||
dpiY: 96.0,
|
||||
usage: D2D1_RENDER_TARGET_USAGE_NONE,
|
||||
minLevel: D2D1_FEATURE_LEVEL_DEFAULT,
|
||||
}
|
||||
}
|
||||
|
||||
// One would think that with newer DirectWrite method: IDWriteFontFace4::GetGlyphImageFormats
|
||||
// but that doesn't seem to work for some glyphs, say ❤
|
||||
fn is_color_glyph(
|
||||
@@ -1561,12 +1911,6 @@ fn is_color_glyph(
|
||||
}
|
||||
|
||||
const DEFAULT_LOCALE_NAME: PCWSTR = windows::core::w!("en-US");
|
||||
const BRUSH_COLOR: D2D1_COLOR_F = D2D1_COLOR_F {
|
||||
r: 1.0,
|
||||
g: 1.0,
|
||||
b: 1.0,
|
||||
a: 1.0,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -7,7 +7,7 @@ use windows::Win32::Graphics::{
|
||||
D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView,
|
||||
ID3D11Texture2D,
|
||||
},
|
||||
Dxgi::Common::{DXGI_FORMAT_A8_UNORM, DXGI_FORMAT_B8G8R8A8_UNORM, DXGI_SAMPLE_DESC},
|
||||
Dxgi::Common::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -167,7 +167,7 @@ impl DirectXAtlasState {
|
||||
let bytes_per_pixel;
|
||||
match kind {
|
||||
AtlasTextureKind::Monochrome => {
|
||||
pixel_format = DXGI_FORMAT_A8_UNORM;
|
||||
pixel_format = DXGI_FORMAT_R8_UNORM;
|
||||
bind_flag = D3D11_BIND_SHADER_RESOURCE;
|
||||
bytes_per_pixel = 1;
|
||||
}
|
||||
|
||||
@@ -42,8 +42,8 @@ pub(crate) struct DirectXRenderer {
|
||||
pub(crate) struct DirectXDevices {
|
||||
adapter: IDXGIAdapter1,
|
||||
dxgi_factory: IDXGIFactory6,
|
||||
device: ID3D11Device,
|
||||
device_context: ID3D11DeviceContext,
|
||||
pub(crate) device: ID3D11Device,
|
||||
pub(crate) device_context: ID3D11DeviceContext,
|
||||
dxgi_device: Option<IDXGIDevice>,
|
||||
}
|
||||
|
||||
@@ -88,8 +88,11 @@ struct DirectComposition {
|
||||
|
||||
impl DirectXDevices {
|
||||
pub(crate) fn new(disable_direct_composition: bool) -> Result<ManuallyDrop<Self>> {
|
||||
let dxgi_factory = get_dxgi_factory().context("Creating DXGI factory")?;
|
||||
let adapter = get_adapter(&dxgi_factory).context("Getting DXGI adapter")?;
|
||||
let debug_layer_available = check_debug_layer_available();
|
||||
let dxgi_factory =
|
||||
get_dxgi_factory(debug_layer_available).context("Creating DXGI factory")?;
|
||||
let adapter =
|
||||
get_adapter(&dxgi_factory, debug_layer_available).context("Getting DXGI adapter")?;
|
||||
let (device, device_context) = {
|
||||
let mut device: Option<ID3D11Device> = None;
|
||||
let mut context: Option<ID3D11DeviceContext> = None;
|
||||
@@ -99,6 +102,7 @@ impl DirectXDevices {
|
||||
Some(&mut device),
|
||||
Some(&mut context),
|
||||
Some(&mut feature_level),
|
||||
debug_layer_available,
|
||||
)
|
||||
.context("Creating Direct3D device")?;
|
||||
match feature_level {
|
||||
@@ -183,7 +187,7 @@ impl DirectXRenderer {
|
||||
self.resources.viewport[0].Width,
|
||||
self.resources.viewport[0].Height,
|
||||
],
|
||||
..Default::default()
|
||||
_pad: 0,
|
||||
}],
|
||||
)?;
|
||||
unsafe {
|
||||
@@ -977,25 +981,34 @@ impl Drop for DirectXResources {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_dxgi_factory() -> Result<IDXGIFactory6> {
|
||||
fn check_debug_layer_available() -> bool {
|
||||
#[cfg(debug_assertions)]
|
||||
let factory_flag = if unsafe { DXGIGetDebugInterface1::<IDXGIInfoQueue>(0) }
|
||||
.log_err()
|
||||
.is_some()
|
||||
{
|
||||
unsafe { DXGIGetDebugInterface1::<IDXGIInfoQueue>(0) }
|
||||
.log_err()
|
||||
.is_some()
|
||||
}
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_dxgi_factory(debug_layer_available: bool) -> Result<IDXGIFactory6> {
|
||||
let factory_flag = if debug_layer_available {
|
||||
DXGI_CREATE_FACTORY_DEBUG
|
||||
} else {
|
||||
#[cfg(debug_assertions)]
|
||||
log::warn!(
|
||||
"Failed to get DXGI debug interface. DirectX debugging features will be disabled."
|
||||
);
|
||||
DXGI_CREATE_FACTORY_FLAGS::default()
|
||||
};
|
||||
#[cfg(not(debug_assertions))]
|
||||
let factory_flag = DXGI_CREATE_FACTORY_FLAGS::default();
|
||||
unsafe { Ok(CreateDXGIFactory2(factory_flag)?) }
|
||||
}
|
||||
|
||||
fn get_adapter(dxgi_factory: &IDXGIFactory6) -> Result<IDXGIAdapter1> {
|
||||
fn get_adapter(dxgi_factory: &IDXGIFactory6, debug_layer_available: bool) -> Result<IDXGIAdapter1> {
|
||||
for adapter_index in 0.. {
|
||||
let adapter: IDXGIAdapter1 = unsafe {
|
||||
dxgi_factory
|
||||
@@ -1009,7 +1022,10 @@ fn get_adapter(dxgi_factory: &IDXGIFactory6) -> Result<IDXGIAdapter1> {
|
||||
}
|
||||
// Check to see whether the adapter supports Direct3D 11, but don't
|
||||
// create the actual device yet.
|
||||
if get_device(&adapter, None, None, None).log_err().is_some() {
|
||||
if get_device(&adapter, None, None, None, debug_layer_available)
|
||||
.log_err()
|
||||
.is_some()
|
||||
{
|
||||
return Ok(adapter);
|
||||
}
|
||||
}
|
||||
@@ -1022,11 +1038,13 @@ fn get_device(
|
||||
device: Option<*mut Option<ID3D11Device>>,
|
||||
context: Option<*mut Option<ID3D11DeviceContext>>,
|
||||
feature_level: Option<*mut D3D_FEATURE_LEVEL>,
|
||||
debug_layer_available: bool,
|
||||
) -> Result<()> {
|
||||
#[cfg(debug_assertions)]
|
||||
let device_flags = D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG;
|
||||
#[cfg(not(debug_assertions))]
|
||||
let device_flags = D3D11_CREATE_DEVICE_BGRA_SUPPORT;
|
||||
let device_flags = if debug_layer_available {
|
||||
D3D11_CREATE_DEVICE_BGRA_SUPPORT | D3D11_CREATE_DEVICE_DEBUG
|
||||
} else {
|
||||
D3D11_CREATE_DEVICE_BGRA_SUPPORT
|
||||
};
|
||||
unsafe {
|
||||
D3D11CreateDevice(
|
||||
adapter,
|
||||
@@ -1423,7 +1441,7 @@ fn report_live_objects(device: &ID3D11Device) -> Result<()> {
|
||||
|
||||
const BUFFER_COUNT: usize = 3;
|
||||
|
||||
mod shader_resources {
|
||||
pub(crate) mod shader_resources {
|
||||
use anyhow::Result;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -1436,7 +1454,7 @@ mod shader_resources {
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub(super) enum ShaderModule {
|
||||
pub(crate) enum ShaderModule {
|
||||
Quad,
|
||||
Shadow,
|
||||
Underline,
|
||||
@@ -1444,15 +1462,16 @@ mod shader_resources {
|
||||
PathSprite,
|
||||
MonochromeSprite,
|
||||
PolychromeSprite,
|
||||
EmojiRasterization,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub(super) enum ShaderTarget {
|
||||
pub(crate) enum ShaderTarget {
|
||||
Vertex,
|
||||
Fragment,
|
||||
}
|
||||
|
||||
pub(super) struct RawShaderBytes<'t> {
|
||||
pub(crate) struct RawShaderBytes<'t> {
|
||||
inner: &'t [u8],
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -1460,7 +1479,7 @@ mod shader_resources {
|
||||
}
|
||||
|
||||
impl<'t> RawShaderBytes<'t> {
|
||||
pub(super) fn new(module: ShaderModule, target: ShaderTarget) -> Result<Self> {
|
||||
pub(crate) fn new(module: ShaderModule, target: ShaderTarget) -> Result<Self> {
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
Ok(Self::from_bytes(module, target))
|
||||
@@ -1478,7 +1497,7 @@ mod shader_resources {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn as_bytes(&'t self) -> &'t [u8] {
|
||||
pub(crate) fn as_bytes(&'t self) -> &'t [u8] {
|
||||
self.inner
|
||||
}
|
||||
|
||||
@@ -1513,6 +1532,10 @@ mod shader_resources {
|
||||
ShaderTarget::Vertex => POLYCHROME_SPRITE_VERTEX_BYTES,
|
||||
ShaderTarget::Fragment => POLYCHROME_SPRITE_FRAGMENT_BYTES,
|
||||
},
|
||||
ShaderModule::EmojiRasterization => match target {
|
||||
ShaderTarget::Vertex => EMOJI_RASTERIZATION_VERTEX_BYTES,
|
||||
ShaderTarget::Fragment => EMOJI_RASTERIZATION_FRAGMENT_BYTES,
|
||||
},
|
||||
};
|
||||
Self { inner: bytes }
|
||||
}
|
||||
@@ -1521,6 +1544,12 @@ mod shader_resources {
|
||||
#[cfg(debug_assertions)]
|
||||
pub(super) fn build_shader_blob(entry: ShaderModule, target: ShaderTarget) -> Result<ID3DBlob> {
|
||||
unsafe {
|
||||
let shader_name = if matches!(entry, ShaderModule::EmojiRasterization) {
|
||||
"color_text_raster.hlsl"
|
||||
} else {
|
||||
"shaders.hlsl"
|
||||
};
|
||||
|
||||
let entry = format!(
|
||||
"{}_{}\0",
|
||||
entry.as_str(),
|
||||
@@ -1537,7 +1566,7 @@ mod shader_resources {
|
||||
let mut compile_blob = None;
|
||||
let mut error_blob = None;
|
||||
let shader_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("src/platform/windows/shaders.hlsl")
|
||||
.join(&format!("src/platform/windows/{}", shader_name))
|
||||
.canonicalize()?;
|
||||
|
||||
let entry_point = PCSTR::from_raw(entry.as_ptr());
|
||||
@@ -1583,6 +1612,7 @@ mod shader_resources {
|
||||
ShaderModule::PathSprite => "path_sprite",
|
||||
ShaderModule::MonochromeSprite => "monochrome_sprite",
|
||||
ShaderModule::PolychromeSprite => "polychrome_sprite",
|
||||
ShaderModule::EmojiRasterization => "emoji_rasterization",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ pub(crate) struct WindowsPlatform {
|
||||
drop_target_helper: IDropTargetHelper,
|
||||
validation_number: usize,
|
||||
main_thread_id_win32: u32,
|
||||
disable_direct_composition: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct WindowsPlatformState {
|
||||
@@ -93,14 +94,18 @@ impl WindowsPlatform {
|
||||
main_thread_id_win32,
|
||||
validation_number,
|
||||
));
|
||||
let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION)
|
||||
.is_ok_and(|value| value == "true" || value == "1");
|
||||
let background_executor = BackgroundExecutor::new(dispatcher.clone());
|
||||
let foreground_executor = ForegroundExecutor::new(dispatcher);
|
||||
let directx_devices = DirectXDevices::new(disable_direct_composition)
|
||||
.context("Unable to init directx devices.")?;
|
||||
let bitmap_factory = ManuallyDrop::new(unsafe {
|
||||
CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER)
|
||||
.context("Error creating bitmap factory.")?
|
||||
});
|
||||
let text_system = Arc::new(
|
||||
DirectWriteTextSystem::new(&bitmap_factory)
|
||||
DirectWriteTextSystem::new(&directx_devices, &bitmap_factory)
|
||||
.context("Error creating DirectWriteTextSystem")?,
|
||||
);
|
||||
let drop_target_helper: IDropTargetHelper = unsafe {
|
||||
@@ -120,6 +125,7 @@ impl WindowsPlatform {
|
||||
background_executor,
|
||||
foreground_executor,
|
||||
text_system,
|
||||
disable_direct_composition,
|
||||
windows_version,
|
||||
bitmap_factory,
|
||||
drop_target_helper,
|
||||
@@ -184,6 +190,7 @@ impl WindowsPlatform {
|
||||
validation_number: self.validation_number,
|
||||
main_receiver: self.main_receiver.clone(),
|
||||
main_thread_id_win32: self.main_thread_id_win32,
|
||||
disable_direct_composition: self.disable_direct_composition,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,6 +722,7 @@ pub(crate) struct WindowCreationInfo {
|
||||
pub(crate) validation_number: usize,
|
||||
pub(crate) main_receiver: flume::Receiver<Runnable>,
|
||||
pub(crate) main_thread_id_win32: u32,
|
||||
pub(crate) disable_direct_composition: bool,
|
||||
}
|
||||
|
||||
fn open_target(target: &str) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
cbuffer GlobalParams: register(b0) {
|
||||
float2 global_viewport_size;
|
||||
uint2 _global_pad;
|
||||
uint2 _pad;
|
||||
};
|
||||
|
||||
Texture2D<float4> t_sprite: register(t0);
|
||||
@@ -1069,6 +1069,7 @@ struct MonochromeSpriteFragmentInput {
|
||||
float4 position: SV_Position;
|
||||
float2 tile_position: POSITION;
|
||||
nointerpolation float4 color: COLOR;
|
||||
float4 clip_distance: SV_ClipDistance;
|
||||
};
|
||||
|
||||
StructuredBuffer<MonochromeSprite> mono_sprites: register(t1);
|
||||
@@ -1091,10 +1092,8 @@ MonochromeSpriteVertexOutput monochrome_sprite_vertex(uint vertex_id: SV_VertexI
|
||||
}
|
||||
|
||||
float4 monochrome_sprite_fragment(MonochromeSpriteFragmentInput input): SV_Target {
|
||||
float4 sample = t_sprite.Sample(s_sprite, input.tile_position);
|
||||
float4 color = input.color;
|
||||
color.a *= sample.a;
|
||||
return color;
|
||||
float sample = t_sprite.Sample(s_sprite, input.tile_position).r;
|
||||
return float4(input.color.rgb, input.color.a * sample);
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -360,6 +360,7 @@ impl WindowsWindow {
|
||||
validation_number,
|
||||
main_receiver,
|
||||
main_thread_id_win32,
|
||||
disable_direct_composition,
|
||||
} = creation_info;
|
||||
let classname = register_wnd_class(icon);
|
||||
let hide_title_bar = params
|
||||
@@ -375,8 +376,6 @@ impl WindowsWindow {
|
||||
.map(|title| title.as_ref())
|
||||
.unwrap_or(""),
|
||||
);
|
||||
let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION)
|
||||
.is_ok_and(|value| value == "true" || value == "1");
|
||||
|
||||
let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp {
|
||||
(WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0))
|
||||
|
||||
@@ -23,6 +23,7 @@ futures.workspace = true
|
||||
http.workspace = true
|
||||
http-body.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
url.workspace = true
|
||||
|
||||
@@ -9,12 +9,10 @@ pub use http::{self, Method, Request, Response, StatusCode, Uri};
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
use http::request::Builder;
|
||||
use parking_lot::Mutex;
|
||||
#[cfg(feature = "test-support")]
|
||||
use std::fmt;
|
||||
use std::{
|
||||
any::type_name,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use std::{any::type_name, sync::Arc};
|
||||
pub use url::Url;
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
|
||||
@@ -86,6 +84,11 @@ pub trait HttpClient: 'static + Send + Sync {
|
||||
}
|
||||
|
||||
fn proxy(&self) -> Option<&Url>;
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
fn as_fake(&self) -> &FakeHttpClient {
|
||||
panic!("called as_fake on {}", type_name::<Self>())
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`HttpClient`] that may have a proxy.
|
||||
@@ -132,6 +135,11 @@ impl HttpClient for HttpClientWithProxy {
|
||||
fn type_name(&self) -> &'static str {
|
||||
self.client.type_name()
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
fn as_fake(&self) -> &FakeHttpClient {
|
||||
self.client.as_fake()
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for Arc<HttpClientWithProxy> {
|
||||
@@ -153,6 +161,11 @@ impl HttpClient for Arc<HttpClientWithProxy> {
|
||||
fn type_name(&self) -> &'static str {
|
||||
self.client.type_name()
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
fn as_fake(&self) -> &FakeHttpClient {
|
||||
self.client.as_fake()
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`HttpClient`] that has a base URL.
|
||||
@@ -199,20 +212,13 @@ impl HttpClientWithUrl {
|
||||
|
||||
/// Returns the base URL.
|
||||
pub fn base_url(&self) -> String {
|
||||
self.base_url
|
||||
.lock()
|
||||
.map_or_else(|_| Default::default(), |url| url.clone())
|
||||
self.base_url.lock().clone()
|
||||
}
|
||||
|
||||
/// Sets the base URL.
|
||||
pub fn set_base_url(&self, base_url: impl Into<String>) {
|
||||
let base_url = base_url.into();
|
||||
self.base_url
|
||||
.lock()
|
||||
.map(|mut url| {
|
||||
*url = base_url;
|
||||
})
|
||||
.ok();
|
||||
*self.base_url.lock() = base_url;
|
||||
}
|
||||
|
||||
/// Builds a URL using the given path.
|
||||
@@ -288,6 +294,11 @@ impl HttpClient for Arc<HttpClientWithUrl> {
|
||||
fn type_name(&self) -> &'static str {
|
||||
self.client.type_name()
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
fn as_fake(&self) -> &FakeHttpClient {
|
||||
self.client.as_fake()
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for HttpClientWithUrl {
|
||||
@@ -309,6 +320,11 @@ impl HttpClient for HttpClientWithUrl {
|
||||
fn type_name(&self) -> &'static str {
|
||||
self.client.type_name()
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
fn as_fake(&self) -> &FakeHttpClient {
|
||||
self.client.as_fake()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_proxy_from_env() -> Option<Url> {
|
||||
@@ -360,10 +376,15 @@ impl HttpClient for BlockedHttpClient {
|
||||
fn type_name(&self) -> &'static str {
|
||||
type_name::<Self>()
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
fn as_fake(&self) -> &FakeHttpClient {
|
||||
panic!("called as_fake on {}", type_name::<Self>())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
type FakeHttpHandler = Box<
|
||||
type FakeHttpHandler = Arc<
|
||||
dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>>
|
||||
+ Send
|
||||
+ Sync
|
||||
@@ -372,7 +393,7 @@ type FakeHttpHandler = Box<
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
pub struct FakeHttpClient {
|
||||
handler: FakeHttpHandler,
|
||||
handler: Mutex<Option<FakeHttpHandler>>,
|
||||
user_agent: HeaderValue,
|
||||
}
|
||||
|
||||
@@ -387,7 +408,7 @@ impl FakeHttpClient {
|
||||
base_url: Mutex::new("http://test.example".into()),
|
||||
client: HttpClientWithProxy {
|
||||
client: Arc::new(Self {
|
||||
handler: Box::new(move |req| Box::pin(handler(req))),
|
||||
handler: Mutex::new(Some(Arc::new(move |req| Box::pin(handler(req))))),
|
||||
user_agent: HeaderValue::from_static(type_name::<Self>()),
|
||||
}),
|
||||
proxy: None,
|
||||
@@ -412,6 +433,18 @@ impl FakeHttpClient {
|
||||
.unwrap())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn replace_handler<Fut, F>(&self, new_handler: F)
|
||||
where
|
||||
Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static,
|
||||
F: Fn(FakeHttpHandler, Request<AsyncBody>) -> Fut + Send + Sync + 'static,
|
||||
{
|
||||
let mut handler = self.handler.lock();
|
||||
let old_handler = handler.take().unwrap();
|
||||
*handler = Some(Arc::new(move |req| {
|
||||
Box::pin(new_handler(old_handler.clone(), req))
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "test-support")]
|
||||
@@ -427,7 +460,7 @@ impl HttpClient for FakeHttpClient {
|
||||
&self,
|
||||
req: Request<AsyncBody>,
|
||||
) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
|
||||
let future = (self.handler)(req);
|
||||
let future = (self.handler.lock().as_ref().unwrap())(req);
|
||||
future
|
||||
}
|
||||
|
||||
@@ -442,4 +475,8 @@ impl HttpClient for FakeHttpClient {
|
||||
fn type_name(&self) -> &'static str {
|
||||
type_name::<Self>()
|
||||
}
|
||||
|
||||
fn as_fake(&self) -> &FakeHttpClient {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||