Compare commits

..

207 Commits

Author SHA1 Message Date
Max Brunsfeld
4a3d56749e Remove debugging code, commented code 2025-07-30 16:04:59 -07:00
Max Brunsfeld
3519e8fd7c Restore text example 2025-07-30 15:59:47 -07:00
Max Brunsfeld
b3372e7eac Read direct_composition env var once, pass it everywhere 2025-07-30 15:51:26 -07:00
Max Brunsfeld
4f7bb14acf Merge branch 'windows/dx11' into windows/remove-d2d 2025-07-30 15:51:03 -07:00
Kate
99c5b72b3d finally the font looks nice 2025-07-31 00:00:35 +02:00
Junkui Zhang
c995dd2016 more context 2025-07-31 01:18:29 +08:00
Junkui Zhang
80be0e29b9 add more context for errors 2025-07-31 01:15:39 +08:00
Junkui Zhang
b75f6e2210 use if-else 2025-07-31 00:59:46 +08:00
Junkui Zhang
b8f85be372 use log::warn instead 2025-07-31 00:58:23 +08:00
Junkui Zhang
34e433ad90 check debug layer before creating 2025-07-31 00:05:56 +08:00
Junkui Zhang
6a91ac26d7 remove unneeded change 2025-07-30 23:29:32 +08:00
Junkui Zhang
345fd526fc log the feature level we are using 2025-07-30 23:17:15 +08:00
Junkui Zhang
7c8074ce5c update feature level 2025-07-30 22:56:28 +08:00
Junkui Zhang
0a0803e2a7 remove enable-renderdoc 2025-07-30 18:22:30 +08:00
Junkui Zhang
8c8b91470a remove unused 2025-07-30 18:10:04 +08:00
Junkui Zhang
98692cc928 fix linux 2025-07-30 17:42:47 +08:00
Junkui Zhang
d194bf4f52 QOL improvement when device lost happens 2025-07-30 17:36:14 +08:00
Junkui Zhang
cc763729a0 add GPUI_DISABLE_DIRECT_COMPOSITION env 2025-07-30 14:53:36 +08:00
Junkui Zhang
e370c3d601 misc 2025-07-30 09:29:09 +08:00
Kate
1d60984cb6 get the blending state a bit closer to vulkan
still looks bad :(
2025-07-29 21:41:11 +02:00
Junkui Zhang
eb3bb95c91 fix error msg 2025-07-29 22:45:58 +08:00
Junkui Zhang
554b36fd3c fix where fxc.exe 2025-07-29 22:44:05 +08:00
Junkui Zhang
89a863d012 update workspace-hack 2025-07-29 20:01:39 +08:00
Junkui Zhang
441731de2e fix build.rs 2025-07-29 19:59:04 +08:00
Junkui Zhang
ead7a1e1f0 remove blade 2025-07-29 19:54:34 +08:00
Junkui Zhang
73eaee8f6f use none instead of stretch 2025-07-29 17:44:54 +08:00
Junkui Zhang
98f31172ab fix atlas sometime fails 2025-07-29 16:56:01 +08:00
Junkui Zhang
181f324473 fix device lost 2025-07-29 16:50:10 +08:00
Junkui Zhang
a1f03ee42c log unknown vendor id 2025-07-29 16:03:59 +08:00
Junkui Zhang
741b38f906 remove unused repr(c) 2025-07-29 14:45:28 +08:00
Junkui Zhang
599b82fc9d remove unused 2025-07-29 14:41:52 +08:00
Junkui Zhang
64b3b050e3 fix 2025-07-29 14:24:43 +08:00
Junkui Zhang
62d1b7e36f fix PathRasterization pipeline 2025-07-29 14:05:38 +08:00
Junkui Zhang
d7b14d8dc5 rename pipeline 2025-07-29 13:24:56 +08:00
Junkui Zhang
ce67ce1482 Revert "Use pre-multiplied alpha for path rasterization"
This reverts commit 8eea9aad40.
2025-07-29 13:07:00 +08:00
Max Brunsfeld
8eea9aad40 Use pre-multiplied alpha for path rasterization 2025-07-28 17:52:55 -07:00
Max Brunsfeld
e9697e4639 Start work on doing path MSAA using intermediate texture 2025-07-28 17:45:35 -07:00
Max Brunsfeld
92b0a7e760 Merge branch 'main' into windows/dx11 2025-07-28 15:41:04 -07:00
Kate
d96bafb1e5 Merge branch 'windows/dx11' into windows/remove-d2d 2025-07-28 14:55:51 +02:00
Junkui Zhang
5f3a1bdbd1 trigger ci 2025-07-28 19:34:15 +08:00
Junkui Zhang
1ee81a507b bundle ags.dll 2025-07-28 18:50:25 +08:00
Junkui Zhang
89d34e1513 clippy 2025-07-28 17:44:40 +08:00
Junkui Zhang
a9058346bf remove static linking of ags 2025-07-28 17:28:58 +08:00
Junkui Zhang
ac1ea0f96d revert idle 2025-07-28 17:10:17 +08:00
Junkui Zhang
0065e5fd76 handle WM_DEVICECHANGE 2025-07-28 17:04:28 +08:00
Junkui Zhang
ca6aa25d1e remove walkaround for close animation 2025-07-28 14:59:30 +08:00
Junkui Zhang
6964cecc14 ensure app is idle 2025-07-27 17:51:17 +08:00
Junkui Zhang
63daf44693 remove debug print 2025-07-27 17:51:17 +08:00
Junkui Zhang
4de2ebf954 acctually enable vsync 2025-07-27 17:51:17 +08:00
Junkui Zhang
3277640f55 fix 2025-07-27 17:51:17 +08:00
Junkui Zhang
d192ac6b7f ags is not typo 2025-07-27 17:51:17 +08:00
Junkui Zhang
c67ddd7572 AGS is not typo 2025-07-27 17:51:17 +08:00
Junkui Zhang
b54eaecbbc clippy 2025-07-27 17:51:17 +08:00
Junkui Zhang
2744e6cb65 Revert "initial removal attempt"
This reverts commit 6928488aad.
2025-07-27 17:51:17 +08:00
Junkui Zhang
18937f5756 Revert "make it not crash"
This reverts commit a7e34ab0bc.
2025-07-27 17:51:17 +08:00
Junkui Zhang
347b863ac6 Revert "more fixes and debugging"
This reverts commit 2fb31a9157.
2025-07-27 17:51:17 +08:00
Junkui Zhang
9d8ef8156d Revert "Translate rasterized glyphs from texture to bitmap"
This reverts commit 6fc8d7746f.
2025-07-27 17:51:17 +08:00
Junkui Zhang
9dbbee0334 Revert "Add emojis to text example"
This reverts commit 34d5926ebd.
2025-07-27 17:51:17 +08:00
Junkui Zhang
32f2505fbf use ? 2025-07-27 17:51:17 +08:00
Junkui Zhang
2711d8823c use vsync 2025-07-27 17:51:17 +08:00
Junkui Zhang
787fee8a1a fix 2025-07-27 17:51:17 +08:00
Junkui Zhang
be7d56e11b fix 2025-07-27 17:51:17 +08:00
Junkui Zhang
fcb77979f3 fix build 2025-07-27 17:51:17 +08:00
Junkui Zhang
787c6382f9 remove unused 2025-07-27 17:51:17 +08:00
Junkui Zhang
74d953d024 checkpoint 2025-07-27 17:51:17 +08:00
Junkui Zhang
b5377c56f2 remove noise when device lost 2025-07-27 17:51:16 +08:00
Junkui Zhang
275d84d566 init handle_device_lost 2025-07-27 17:51:16 +08:00
Junkui Zhang
3978bba5a7 fix 2025-07-27 17:51:16 +08:00
Junkui Zhang
52c0fa5ce9 remove debug print 2025-07-27 17:51:16 +08:00
Junkui Zhang
d208f75f46 enable O3 optimization for fxc 2025-07-27 17:51:16 +08:00
Junkui Zhang
1b0a0aa58e add x86 support for nvidia 2025-07-27 17:51:16 +08:00
Junkui Zhang
5ff9114b18 add runtime shader 2025-07-27 17:51:16 +08:00
Junkui Zhang
d9c6d09545 checkpoint 2025-07-27 17:51:16 +08:00
Junkui Zhang
61981aabb5 allow to compile shader at building 2025-07-27 17:51:16 +08:00
Junkui Zhang
0b57c86e07 add amd gpu version support 2025-07-27 17:51:16 +08:00
Junkui Zhang
c7342a9df5 remove unused 2025-07-27 17:51:16 +08:00
Junkui Zhang
0e45ef7e43 better output for nvidia 2025-07-27 17:51:16 +08:00
Junkui Zhang
0c40bb9b5f impl intel driver version 2025-07-27 17:51:16 +08:00
Junkui Zhang
5058752f2d cleanup 2025-07-27 17:51:16 +08:00
Junkui Zhang
432d11f57b implement gpu driver version for nvidia 2025-07-27 17:51:16 +08:00
Junkui Zhang
32488e1e2d fix 2025-07-27 17:51:16 +08:00
Junkui Zhang
9acee42c38 show err if failed to create new window 2025-07-27 17:51:16 +08:00
Junkui Zhang
72c55b4653 add new feature enable-renderdoc 2025-07-27 17:51:16 +08:00
Junkui Zhang
fa1320d9aa remove unused 2025-07-27 17:51:16 +08:00
Junkui Zhang
eb310bcf7d wip 2025-07-27 17:51:16 +08:00
Junkui Zhang
8c1d9f75d1 refactor 2025-07-27 17:51:16 +08:00
Junkui Zhang
499b3b6b50 rename to DirectXResources 2025-07-27 17:51:16 +08:00
Junkui Zhang
c6e020f60f finetune transpanrency 2025-07-27 17:51:15 +08:00
Junkui Zhang
7ab2d0d800 add transparency 2025-07-27 17:51:15 +08:00
Junkui Zhang
c007121b41 remove unused 2025-07-27 17:51:15 +08:00
Junkui Zhang
22c9d133bd wip 2025-07-27 17:51:15 +08:00
Junkui Zhang
32758022df wip 2025-07-27 17:51:15 +08:00
Junkui Zhang
0d8600bf1e checkpoint msaa 2025-07-27 17:51:15 +08:00
Junkui Zhang
22cba07072 add msaa 2025-07-27 17:51:15 +08:00
Junkui Zhang
642d769502 update default buffer size 2025-07-27 17:51:15 +08:00
Junkui Zhang
bfdcc65801 reenable transparency 2025-07-27 17:51:15 +08:00
Junkui Zhang
54e2420405 introduce set_pipeline_state 2025-07-27 17:51:15 +08:00
Junkui Zhang
b012246d2b refactor 2025-07-27 17:51:15 +08:00
Junkui Zhang
667c19907a refactor 2025-07-27 17:51:15 +08:00
Junkui Zhang
5261c02d18 refactor 2025-07-27 17:51:15 +08:00
Junkui Zhang
204071e6bf remove unused 2025-07-27 17:51:15 +08:00
Junkui Zhang
5472c71f1a fix paths rendering 2025-07-27 17:51:15 +08:00
Junkui Zhang
723712e3cf Revert "Fix path rendering - draw all paths w/ one regular draw call"
This reverts commit 83d942611f.
2025-07-27 17:51:15 +08:00
Junkui Zhang
0c274370c3 wip 2025-07-27 17:51:15 +08:00
Max Brunsfeld
31fab3a37a Fix dxgi_factory type error in release mode 2025-07-27 17:51:15 +08:00
Max Brunsfeld
4f416d3818 Fix path rendering - draw all paths w/ one regular draw call 2025-07-27 17:51:15 +08:00
Junkui Zhang
ffef9fd25a bringback our colorful avatar 2025-07-27 17:51:15 +08:00
Junkui Zhang
2a6b83f190 remove debug print 2025-07-27 17:51:15 +08:00
Junkui Zhang
1b12dd39cc fix all 2025-07-27 17:51:14 +08:00
Max Brunsfeld
9162583bac Add emojis to text example 2025-07-27 17:51:14 +08:00
Max Brunsfeld
8075998c09 Translate rasterized glyphs from texture to bitmap 2025-07-27 17:51:14 +08:00
Kate
3b6105b713 more fixes and debugging 2025-07-27 17:51:14 +08:00
Kate
2b53a2cb12 make it not crash 2025-07-27 17:51:14 +08:00
Kate
96d847b6d1 initial removal attempt 2025-07-27 17:51:14 +08:00
Junkui Zhang
7fde34f85e temporarily disable transparancy 2025-07-27 17:51:14 +08:00
Junkui Zhang
401e0e6f41 wip 2025-07-27 17:51:14 +08:00
Junkui Zhang
201c274c4b wip 2025-07-27 17:51:14 +08:00
Junkui Zhang
ecde968a0c wip 2025-07-27 17:51:14 +08:00
Junkui Zhang
4a78ce7cfd wip 2025-07-27 17:51:14 +08:00
Junkui Zhang
fda3d56d87 wip 2025-07-27 17:51:14 +08:00
Junkui Zhang
9c3cfca835 apply #23576 2025-07-27 17:51:14 +08:00
Junkui Zhang
1fb689bad3 apply #19772 2025-07-27 17:51:14 +08:00
Junkui Zhang
238ccec5ee fix 2025-07-27 17:51:14 +08:00
Junkui Zhang
c8ae5a3b11 fix all 2025-07-27 17:51:14 +08:00
Junkui Zhang
dbe2ce2464 wip 2025-07-27 17:51:14 +08:00
Junkui Zhang
5287183667 apply #20812 2025-07-27 17:51:14 +08:00
Junkui Zhang
a48ae50e1a apply #15782 2025-07-27 17:51:14 +08:00
Junkui Zhang
ca3d55ee4d wip 2025-07-27 17:51:13 +08:00
Junkui Zhang
c0bad42968 wip 2025-07-27 17:51:13 +08:00
Junkui Zhang
7186f1322e init 2025-07-27 17:51:13 +08:00
Kate
21e14b5f9a Merge branch 'windows/dx11' into windows/remove-d2d 2025-07-22 16:03:55 +02:00
Max Brunsfeld
7d84014ad2 Update cargo lock 2025-07-21 13:58:59 -07:00
Max Brunsfeld
68780da673 Pass scale factor transform to glyph analysis when computing bounds
rather than simply multiplying every rect field by the scale factor.
This fixes clipping of glyphs and removes the need for magic numbers
expanding the bounds vertically.

Co-authored-by: Julia Ryan <juliaryan3.14@gmail.com>
2025-07-21 13:35:21 -07:00
Kate
1f55a0a358 cleanup code a bit 2025-07-21 17:17:42 +02:00
Kate
ba80e16339 color rasterization works now 2025-07-18 21:12:19 +02:00
Kate
11dc14ad4d gpu rasterization works now 2025-07-18 16:15:53 +02:00
Junkui Zhang
9f200ebf5a fix 2025-07-18 17:56:46 +08:00
Junkui Zhang
788865e892 remove debug print 2025-07-18 17:54:26 +08:00
Junkui Zhang
e87ee91d8e enable O3 optimization for fxc 2025-07-18 17:53:45 +08:00
Junkui Zhang
b0e48d01ce add x86 support for nvidia 2025-07-18 17:48:05 +08:00
Junkui Zhang
825ee6233b add runtime shader 2025-07-18 17:43:01 +08:00
Junkui Zhang
154705e729 checkpoint 2025-07-18 17:08:44 +08:00
Junkui Zhang
636a057373 allow to compile shader at building 2025-07-18 16:46:34 +08:00
Junkui Zhang
df1f62477c add amd gpu version support 2025-07-18 15:11:49 +08:00
Kate
6907064be6 prepare for gpu rasterization 2025-07-17 20:00:14 +02:00
Kate
c1eaf3317d Merge branch 'windows/dx11' into HEAD 2025-07-17 19:59:21 +02:00
Junkui Zhang
6477a9b056 remove unused 2025-07-17 21:35:42 +08:00
Junkui Zhang
84f75fe683 better output for nvidia 2025-07-17 21:31:00 +08:00
Junkui Zhang
7627097875 impl intel driver version 2025-07-17 21:10:33 +08:00
Junkui Zhang
78824390d0 cleanup 2025-07-17 19:48:42 +08:00
Junkui Zhang
4d936845f3 implement gpu driver version for nvidia 2025-07-17 19:30:49 +08:00
Junkui Zhang
76fb80eaeb fix 2025-07-17 17:12:27 +08:00
Junkui Zhang
29b5acf27b show err if failed to create new window 2025-07-17 17:04:36 +08:00
Junkui Zhang
e560c6813f add new feature enable-renderdoc 2025-07-17 16:57:23 +08:00
Junkui Zhang
a57cbe4636 remove unused 2025-07-17 16:34:34 +08:00
Junkui Zhang
7cf10d110c wip 2025-07-17 16:34:09 +08:00
Junkui Zhang
1888f21a14 refactor 2025-07-17 15:52:51 +08:00
Junkui Zhang
63727f99da rename to DirectXResources 2025-07-17 14:53:28 +08:00
Junkui Zhang
602bd189f6 finetune transpanrency 2025-07-17 14:36:41 +08:00
Junkui Zhang
b8314e74db add transparency 2025-07-17 11:29:44 +08:00
Kate
a486bb28f6 initial color emoji implementation, currently only monochrome, still
figuring out why it doesn't render even though it rasterizes to the
bitmap correctly
2025-07-16 23:22:08 +02:00
Junkui Zhang
b1b5a383e0 remove unused 2025-07-17 00:05:46 +08:00
Junkui Zhang
b0fe5fd56f wip 2025-07-16 23:55:32 +08:00
Junkui Zhang
398d492f85 wip 2025-07-16 22:58:51 +08:00
Junkui Zhang
55edee58fb checkpoint msaa 2025-07-16 22:20:38 +08:00
Junkui Zhang
da3736bd5f add msaa 2025-07-16 21:26:11 +08:00
Junkui Zhang
4b2ff5e251 update default buffer size 2025-07-16 19:56:56 +08:00
Junkui Zhang
46fc76fdf8 reenable transparency 2025-07-16 17:14:24 +08:00
Junkui Zhang
ffbb47452d introduce set_pipeline_state 2025-07-16 15:48:55 +08:00
Junkui Zhang
5ed8b13e4a refactor 2025-07-16 15:36:24 +08:00
Junkui Zhang
1baafae3f7 refactor 2025-07-16 14:14:31 +08:00
Junkui Zhang
2017ce3699 refactor 2025-07-16 14:04:49 +08:00
Junkui Zhang
f715acc92a remove unused 2025-07-16 11:37:51 +08:00
Junkui Zhang
291691ca0e fix paths rendering 2025-07-16 11:24:55 +08:00
Junkui Zhang
158732eb17 Revert "Fix path rendering - draw all paths w/ one regular draw call"
This reverts commit 83d942611f.
2025-07-16 11:08:30 +08:00
Junkui Zhang
cdbaff8375 wip 2025-07-16 09:43:08 +08:00
Max Brunsfeld
c014dbae8c Fix dxgi_factory type error in release mode 2025-07-15 16:48:19 -07:00
Max Brunsfeld
83d942611f Fix path rendering - draw all paths w/ one regular draw call 2025-07-15 16:15:50 -07:00
Junkui Zhang
f16f07b36f bringback our colorful avatar 2025-07-15 23:38:18 +08:00
Junkui Zhang
85cf9e405e remove debug print 2025-07-15 23:29:32 +08:00
Junkui Zhang
a1c00ed87f fix all 2025-07-15 23:28:25 +08:00
Max Brunsfeld
34d5926ebd Add emojis to text example 2025-07-15 13:56:30 +08:00
Max Brunsfeld
6fc8d7746f Translate rasterized glyphs from texture to bitmap 2025-07-15 13:56:30 +08:00
Kate
2fb31a9157 more fixes and debugging 2025-07-15 13:56:29 +08:00
Kate
a7e34ab0bc make it not crash 2025-07-15 13:56:29 +08:00
Kate
6928488aad initial removal attempt 2025-07-15 13:56:29 +08:00
Junkui Zhang
8514850ad4 temporarily disable transparancy 2025-07-15 12:35:33 +08:00
Max Brunsfeld
231c38aa41 Add emojis to text example 2025-07-14 17:57:52 -07:00
Max Brunsfeld
8d538fad0c Translate rasterized glyphs from texture to bitmap 2025-07-14 17:57:38 -07:00
Kate
f5aa88ca6a more fixes and debugging 2025-07-14 22:34:36 +02:00
Kate
b9eb18eb7f make it not crash 2025-07-14 21:44:20 +02:00
Kate
b130346ede initial removal attempt 2025-07-14 20:55:16 +02:00
Junkui Zhang
e8bd47f668 wip 2025-07-14 19:59:26 +08:00
Junkui Zhang
6a918b64bf wip 2025-07-14 18:35:52 +08:00
Junkui Zhang
c82edc38a9 wip 2025-07-14 17:55:50 +08:00
Junkui Zhang
622a42e3aa wip 2025-07-14 17:49:44 +08:00
Junkui Zhang
dcdd7404e4 wip 2025-07-14 16:47:45 +08:00
Junkui Zhang
52c181328c apply #23576 2025-07-14 14:50:21 +08:00
Junkui Zhang
2319cd8211 apply #19772 2025-07-14 14:24:47 +08:00
Junkui Zhang
d0a2257472 fix 2025-07-13 20:31:39 +08:00
Junkui Zhang
af2009710a fix all 2025-07-13 20:17:01 +08:00
Junkui Zhang
eec406bb36 wip 2025-07-13 17:13:53 +08:00
Junkui Zhang
83ea328be5 apply #20812 2025-07-13 16:34:16 +08:00
Junkui Zhang
f2c847a1b0 apply #15782 2025-07-13 13:28:52 +08:00
Junkui Zhang
5d03296dc2 wip 2025-07-13 13:15:24 +08:00
Junkui Zhang
b4771bc4f8 wip 2025-07-13 12:49:05 +08:00
Junkui Zhang
68192052fd init 2025-07-13 12:32:59 +08:00
231 changed files with 4037 additions and 10392 deletions

View File

@@ -19,7 +19,7 @@ runs:
shell: bash -euxo pipefail {0}
run: ./script/linux
- name: Check for broken links (in MD)
- name: Check for broken links
uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
with:
args: --no-progress --exclude '^http' './docs/src/**/*'
@@ -30,9 +30,3 @@ runs:
run: |
mkdir -p target/deploy
mdbook build ./docs --dest-dir=../target/deploy/docs/
- name: Check for broken links (in HTML)
uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
with:
args: --no-progress --exclude '^http' 'target/deploy/docs/'
fail: true

View File

@@ -771,8 +771,7 @@ jobs:
timeout-minutes: 120
name: Create a Windows installer
runs-on: [self-hosted, Windows, X64]
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
# if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
if: true && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
needs: [windows_tests]
env:
AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}
@@ -808,16 +807,16 @@ jobs:
name: ZedEditorUserSetup-x64-${{ github.event.pull_request.head.sha || github.sha }}.exe
path: ${{ env.SETUP_PATH }}
- name: Upload Artifacts to release
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
# Re-enable when we are ready to publish windows preview releases
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
with:
draft: true
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
files: ${{ env.SETUP_PATH }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# - name: Upload Artifacts to release
# uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
# # Re-enable when we are ready to publish windows preview releases
# if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
# with:
# draft: true
# prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
# files: ${{ env.SETUP_PATH }}
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
auto-release-preview:
name: Auto release preview

574
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,13 @@
[workspace]
resolver = "2"
members = [
"crates/acp_thread",
"crates/activity_indicator",
"crates/agent",
"crates/agent_servers",
"crates/agent_settings",
"crates/acp_thread",
"crates/agent_ui",
"crates/agent",
"crates/agent_settings",
"crates/ai_onboarding",
"crates/agent_servers",
"crates/anthropic",
"crates/askpass",
"crates/assets",
@@ -29,9 +29,6 @@ members = [
"crates/cli",
"crates/client",
"crates/clock",
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -51,8 +48,8 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/editor",
"crates/eval",
"crates/explorer_command_injector",
"crates/eval",
"crates/extension",
"crates/extension_api",
"crates/extension_cli",
@@ -73,6 +70,7 @@ members = [
"crates/gpui",
"crates/gpui_macros",
"crates/gpui_tokio",
"crates/html_to_markdown",
"crates/http_client",
"crates/http_client_tls",
@@ -101,6 +99,7 @@ members = [
"crates/markdown_preview",
"crates/media",
"crates/menu",
"crates/svg_preview",
"crates/migrator",
"crates/mistral",
"crates/multi_buffer",
@@ -141,7 +140,6 @@ members = [
"crates/semantic_version",
"crates/session",
"crates/settings",
"crates/settings_profile_selector",
"crates/settings_ui",
"crates/snippet",
"crates/snippet_provider",
@@ -154,7 +152,6 @@ members = [
"crates/sum_tree",
"crates/supermaven",
"crates/supermaven_api",
"crates/svg_preview",
"crates/tab_switcher",
"crates/task",
"crates/tasks_ui",
@@ -189,7 +186,6 @@ members = [
"crates/zed",
"crates/zed_actions",
"crates/zeta",
"crates/zeta_cli",
"crates/zlog",
"crates/zlog_settings",
@@ -255,9 +251,6 @@ channel = { path = "crates/channel" }
cli = { path = "crates/cli" }
client = { path = "crates/client" }
clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
collab = { path = "crates/collab" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
@@ -344,7 +337,6 @@ picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
plugin_macros = { path = "crates/plugin_macros" }
prettier = { path = "crates/prettier" }
settings_profile_selector = { path = "crates/settings_profile_selector" }
project = { path = "crates/project" }
project_panel = { path = "crates/project_panel" }
project_symbols = { path = "crates/project_symbols" }
@@ -421,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
agent-client-protocol = {path="../agent-client-protocol"}
agent-client-protocol = "0.0.11"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@@ -653,6 +645,7 @@ which = "6.0.0"
windows-core = "0.61"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "= 0.8.6"
zstd = "0.11"
[workspace.dependencies.async-stripe]
@@ -679,6 +672,8 @@ features = [
"UI_ViewManagement",
"Wdk_System_SystemServices",
"Win32_Globalization",
"Win32_Graphics_Direct2D",
"Win32_Graphics_Direct2D_Common",
"Win32_Graphics_Direct3D",
"Win32_Graphics_Direct3D11",
"Win32_Graphics_Direct3D_Fxc",
@@ -689,6 +684,7 @@ features = [
"Win32_Graphics_Dxgi_Common",
"Win32_Graphics_Gdi",
"Win32_Graphics_Imaging",
"Win32_Graphics_Imaging_D2D",
"Win32_Networking_WinSock",
"Win32_Security",
"Win32_Security_Credentials",

View File

@@ -1,6 +1,5 @@
# Zed
[![Zed](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/zed-industries/zed/main/assets/badge/v0.json)](https://zed.dev)
[![CI](https://github.com/zed-industries/zed/actions/workflows/ci.yml/badge.svg)](https://github.com/zed-industries/zed/actions/workflows/ci.yml)
Welcome to Zed, a high-performance, multiplayer code editor from the creators of [Atom](https://github.com/atom/atom) and [Tree-sitter](https://github.com/tree-sitter/tree-sitter).

View File

@@ -1,8 +0,0 @@
{
"label": "",
"message": "Zed",
"logoSvg": "<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 96 96\"><rect width=\"96\" height=\"96\" fill=\"#000\"/><path fill-rule=\"evenodd\" clip-rule=\"evenodd\" d=\"M9 6C7.34315 6 6 7.34315 6 9V75H0V9C0 4.02944 4.02944 0 9 0H89.3787C93.3878 0 95.3955 4.84715 92.5607 7.68198L43.0551 57.1875H57V51H63V58.6875C63 61.1728 60.9853 63.1875 58.5 63.1875H37.0551L26.7426 73.5H73.5V36H79.5V73.5C79.5 76.8137 76.8137 79.5 73.5 79.5H20.7426L10.2426 90H87C88.6569 90 90 88.6569 90 87V21H96V87C96 91.9706 91.9706 96 87 96H6.62132C2.61224 96 0.604504 91.1529 3.43934 88.318L52.7574 39H39V45H33V37.5C33 35.0147 35.0147 33 37.5 33H58.7574L69.2574 22.5H22.5V60H16.5V22.5C16.5 19.1863 19.1863 16.5 22.5 16.5H75.2574L85.7574 6H9Z\" fill=\"#fff\"/></svg>",
"logoWidth": 16,
"labelColor": "black",
"color": "white"
}

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 6.3 KiB

View File

@@ -1,9 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path opacity="0.6" d="M3.5 11V5.5L8.5 8L3.5 11Z" fill="black"/>
<path opacity="0.4" d="M8.5 14L3.5 11L8.5 8V14Z" fill="black"/>
<path opacity="0.6" d="M8.5 5.5H3.5L8.5 2.5L8.5 5.5Z" fill="black"/>
<path opacity="0.8" d="M8.5 5.5V2.5L13.5 5.5H8.5Z" fill="black"/>
<path opacity="0.2" d="M13.5 11L8.5 14L11 9.5L13.5 11Z" fill="black"/>
<path opacity="0.5" d="M13.5 11L11 9.5L13.5 5V11Z" fill="black"/>
<path d="M3.5 11V5L8.5 2.11325L13.5 5V11L8.5 13.8868L3.5 11Z" stroke="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 583 B

View File

@@ -1,10 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_2716_663)">
<path d="M8.47552 2.45453C11.5167 2.45457 13.9814 4.94501 13.9814 8.01623C13.9814 11.0875 11.5167 13.578 8.47552 13.5781C5.43427 13.5781 2.96948 11.0875 2.96948 8.01623C2.9695 4.94498 5.43429 2.45453 8.47552 2.45453ZM10.8795 4.70348C10.7605 4.16887 10.1328 3.85468 9.53627 3.96342C8.97622 4.06552 7.62871 4.45681 7.62057 4.45916C9.29414 4.44469 9.57429 4.4726 9.69939 4.64751C9.77324 4.7508 9.66576 4.89248 9.21944 4.96538C8.73515 5.04447 7.73014 5.13958 7.72343 5.14022C6.75441 5.19776 6.07177 5.20168 5.86705 5.63512C5.73334 5.91827 6.00968 6.16857 6.13082 6.32527C6.64271 6.89455 7.38215 7.20158 7.85809 7.42767C8.03716 7.51274 8.56257 7.67345 8.56257 7.67345C7.01855 7.58853 5.90474 8.06267 5.2514 8.60855C4.51246 9.29204 4.83937 10.1067 6.35327 10.6084C7.24742 10.9047 7.69094 11.0439 9.02473 10.9238C9.81031 10.8815 9.9342 10.9068 9.94203 10.9712C9.95275 11.062 9.06932 11.2874 8.82812 11.357C8.21455 11.534 6.60645 11.8913 6.59758 11.8932C6.60115 11.8935 7.06249 11.9257 7.65531 11.8735C7.89632 11.8522 8.81142 11.7624 9.49557 11.6123C9.49557 11.6123 10.3297 11.4338 10.7759 11.2693C11.2429 11.0973 11.497 10.9512 11.6113 10.7443C11.6063 10.7019 11.6465 10.5516 11.4313 10.4613C10.8807 10.2304 10.2423 10.2721 8.9789 10.2453C7.57789 10.1972 7.11184 9.9626 6.86356 9.77373C6.62548 9.58212 6.74518 9.05204 7.76528 8.5851C8.27917 8.33646 10.2935 7.87759 10.2935 7.87759C9.61511 7.54227 8.35014 6.95284 8.09005 6.82552C7.86199 6.71388 7.49701 6.54572 7.4179 6.34233C7.32824 6.14709 7.6297 5.97888 7.79813 5.9307C8.34057 5.77424 9.10635 5.67701 9.8033 5.66609C10.1536 5.66061 10.2105 5.63806 10.2105 5.63806C10.6939 5.55787 11.0121 5.22722 10.8795 4.70348Z" fill="black"/>
</g>
<defs>
<clipPath id="clip0_2716_663">
<rect width="12" height="12" fill="white" transform="translate(2.5 2)"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 1.9 KiB

View File

@@ -1,3 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3.6725 13.9985C3.36161 13.9982 3.06354 13.8746 2.84371 13.6548C2.62388 13.435 2.50026 13.1369 2.5 12.826V7.494C2.5 6.8325 2.7675 6.185 3.2365 5.7165L6.219 2.736C6.45192 2.50247 6.72867 2.31724 7.03335 2.19094C7.33804 2.06464 7.66467 1.99975 7.9945 2H13.3275C13.6384 2.00027 13.9365 2.12388 14.1563 2.34371C14.3761 2.56354 14.4997 2.86162 14.5 3.1725V8.5045C14.4983 9.17074 14.2336 9.80936 13.7635 10.2815L10.781 13.264C10.5477 13.4976 10.2706 13.6829 9.96561 13.8092C9.66059 13.9355 9.33364 14.0003 9.0035 14V13.9985H3.6725ZM8.157 10.5715H5.243V11.257H8.157V10.5715ZM4.4815 5.257H11.243V12.0165L13.3715 9.888C13.7373 9.52036 13.9433 9.02316 13.9445 8.5045V3.1725C13.9445 2.8335 13.6685 2.5555 13.3275 2.5555H7.9945C7.73753 2.55499 7.483 2.6053 7.24556 2.70356C7.00813 2.80181 6.79246 2.94606 6.611 3.128L4.4815 5.257ZM4.3855 5.353L3.628 6.11C3.26258 6.47809 3.0569 6.97533 3.0555 7.494V12.826C3.0555 13.165 3.3315 13.443 3.6725 13.443H9.0055C9.26249 13.4434 9.51701 13.3929 9.75445 13.2946C9.99188 13.1963 10.2075 13.052 10.389 12.87L11.145 12.1145H4.3855V5.353Z" fill="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -1,5 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M13.0945 8.01611C13.0945 7.87619 12.9911 7.79551 12.8642 7.8356L4.13456 10.6038C4.00742 10.6441 3.90427 10.7904 3.90427 10.9301V13.7593C3.90427 13.8992 4.00742 13.9801 4.13456 13.9398L12.8642 11.1719C12.9911 11.1315 13.0945 10.9852 13.0945 10.8453V8.01611Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.90427 7.92597C3.90427 8.06588 4.00742 8.21218 4.13456 8.25252L12.8655 11.0209C12.9926 11.0613 13.0958 10.9803 13.0958 10.8407V8.01124C13.0958 7.87158 12.9926 7.72529 12.8655 7.68494L4.13456 4.91652C4.00742 4.87618 3.90427 4.95686 3.90427 5.09677V7.92597Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M13.0945 2.20248C13.0945 2.06256 12.9911 1.98163 12.8642 2.02197L4.13456 4.78988C4.00742 4.83022 3.90427 4.97652 3.90427 5.11644V7.94563C3.90427 8.08554 4.00742 8.16622 4.13456 8.12614L12.8642 5.35797C12.9911 5.31763 13.0945 5.17133 13.0945 5.03167V2.20248Z" fill="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.0 KiB

View File

@@ -1,3 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.0094 13.9181C11.1984 13.9917 11.4139 13.987 11.6047 13.8952L14.0753 12.7064C14.3349 12.5814 14.5 12.3187 14.5 12.0305V3.9696C14.5 3.68136 14.3349 3.41862 14.0753 3.2937L11.6047 2.10485C11.3543 1.98438 11.0614 2.01389 10.8416 2.17363C10.8102 2.19645 10.7803 2.22193 10.7523 2.25001L6.02261 6.56498L3.96246 5.00115C3.77068 4.85558 3.50244 4.86751 3.32432 5.02953L2.66356 5.63059C2.44569 5.82877 2.44544 6.17152 2.66302 6.37004L4.44965 8.00001L2.66302 9.62998C2.44544 9.82849 2.44569 10.1713 2.66356 10.3694L3.32432 10.9705C3.50244 11.1325 3.77068 11.1444 3.96246 10.9989L6.02261 9.43504L10.7523 13.75C10.8271 13.8249 10.915 13.8812 11.0094 13.9181ZM11.5018 5.27587L7.91309 8.00001L11.5018 10.7241V5.27587Z" fill="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 876 B

View File

@@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M13.0001 8.62505C13.0001 11.75 10.8126 13.3125 8.21266 14.2187C8.07651 14.2648 7.92862 14.2626 7.79392 14.2125C5.18771 13.3125 3.00024 11.75 3.00024 8.62505V4.25012C3.00024 4.08436 3.06609 3.92539 3.1833 3.80818C3.30051 3.69098 3.45948 3.62513 3.62523 3.62513C4.87521 3.62513 6.43769 2.87514 7.52517 1.92516C7.65758 1.81203 7.82601 1.74988 8.00016 1.74988C8.17431 1.74988 8.34275 1.81203 8.47515 1.92516C9.56889 2.88139 11.1251 3.62513 12.3751 3.62513C12.5408 3.62513 12.6998 3.69098 12.817 3.80818C12.9342 3.92539 13.0001 4.08436 13.0001 4.25012V8.62505Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M6 8.00002L7.33333 9.33335L10 6.66669" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 883 B

View File

@@ -232,7 +232,7 @@
"ctrl-n": "agent::NewThread",
"ctrl-alt-n": "agent::NewTextThread",
"ctrl-shift-h": "agent::OpenHistory",
"ctrl-alt-c": "agent::OpenSettings",
"ctrl-alt-c": "agent::OpenConfiguration",
"ctrl-alt-p": "agent::OpenRulesLibrary",
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "agent::ToggleModelSelector",
@@ -495,7 +495,7 @@
"shift-f12": "editor::GoToImplementation",
"alt-ctrl-f12": "editor::GoToTypeDefinitionSplit",
"alt-shift-f12": "editor::FindAllReferences",
"ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains
"ctrl-m": "editor::MoveToEnclosingBracket",
"ctrl-|": "editor::MoveToEnclosingBracket",
"ctrl-{": "editor::Fold",
"ctrl-}": "editor::UnfoldLines",
@@ -598,7 +598,6 @@
"ctrl-shift-t": "pane::ReopenClosedItem",
"ctrl-k ctrl-s": "zed::OpenKeymapEditor",
"ctrl-k ctrl-t": "theme_selector::Toggle",
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
"ctrl-t": "project_symbols::Toggle",
"ctrl-p": "file_finder::Toggle",
"ctrl-tab": "tab_switcher::Toggle",
@@ -1168,14 +1167,5 @@
"up": "menu::SelectPrevious",
"down": "menu::SelectNext"
}
},
{
"context": "Onboarding",
"use_key_equivalents": true,
"bindings": {
"ctrl-1": "onboarding::ActivateBasicsPage",
"ctrl-2": "onboarding::ActivateEditingPage",
"ctrl-3": "onboarding::ActivateAISetupPage"
}
}
]

View File

@@ -272,7 +272,7 @@
"cmd-n": "agent::NewThread",
"cmd-alt-n": "agent::NewTextThread",
"cmd-shift-h": "agent::OpenHistory",
"cmd-alt-c": "agent::OpenSettings",
"cmd-alt-c": "agent::OpenConfiguration",
"cmd-alt-p": "agent::OpenRulesLibrary",
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "agent::ToggleModelSelector",
@@ -549,7 +549,7 @@
"alt-cmd-f12": "editor::GoToTypeDefinitionSplit",
"alt-shift-f12": "editor::FindAllReferences",
"cmd-|": "editor::MoveToEnclosingBracket",
"ctrl-m": "editor::MoveToEnclosingBracket", // From Jetbrains
"ctrl-m": "editor::MoveToEnclosingBracket",
"alt-cmd-[": "editor::Fold",
"alt-cmd-]": "editor::UnfoldLines",
"cmd-k cmd-l": "editor::ToggleFold",
@@ -665,7 +665,6 @@
"cmd-shift-t": "pane::ReopenClosedItem",
"cmd-k cmd-s": "zed::OpenKeymapEditor",
"cmd-k cmd-t": "theme_selector::Toggle",
"ctrl-alt-cmd-p": "settings_profile_selector::Toggle",
"cmd-t": "project_symbols::Toggle",
"cmd-p": "file_finder::Toggle",
"ctrl-tab": "tab_switcher::Toggle",
@@ -1270,14 +1269,5 @@
"up": "menu::SelectPrevious",
"down": "menu::SelectNext"
}
},
{
"context": "Onboarding",
"use_key_equivalents": true,
"bindings": {
"cmd-1": "onboarding::ActivateBasicsPage",
"cmd-2": "onboarding::ActivateEditingPage",
"cmd-3": "onboarding::ActivateAISetupPage"
}
}
]

View File

@@ -8,7 +8,7 @@
"ctrl-shift-i": "agent::ToggleFocus",
"ctrl-l": "agent::ToggleFocus",
"ctrl-shift-l": "agent::ToggleFocus",
"ctrl-shift-j": "agent::OpenSettings"
"ctrl-shift-j": "agent::OpenConfiguration"
}
},
{

View File

@@ -95,7 +95,7 @@
"ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }],
"alt-shift-f10": "task::Spawn",
"ctrl-e": "file_finder::Toggle",
// "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
"ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
"ctrl-shift-n": "file_finder::Toggle",
"ctrl-shift-a": "command_palette::Toggle",
"shift shift": "command_palette::Toggle",

View File

@@ -8,7 +8,7 @@
"cmd-shift-i": "agent::ToggleFocus",
"cmd-l": "agent::ToggleFocus",
"cmd-shift-l": "agent::ToggleFocus",
"cmd-shift-j": "agent::OpenSettings"
"cmd-shift-j": "agent::OpenConfiguration"
}
},
{

View File

@@ -97,7 +97,7 @@
"cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }],
"ctrl-alt-r": "task::Spawn",
"cmd-e": "file_finder::Toggle",
// "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
"cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
"cmd-shift-o": "file_finder::Toggle",
"cmd-shift-a": "command_palette::Toggle",
"shift shift": "command_palette::Toggle",

View File

@@ -1877,25 +1877,5 @@
"save_breakpoints": true,
"dock": "bottom",
"button": true
},
// Configures any number of settings profiles that are temporarily applied on
// top of your existing user settings when selected from
// `settings profile selector: toggle`.
// Examples:
// "profiles": {
// "Presenting": {
// "agent_font_size": 20.0,
// "buffer_font_size": 20.0,
// "theme": "One Light",
// "ui_font_size": 20.0
// },
// "Python (ty)": {
// "languages": {
// "Python": {
// "language_servers": ["ty"]
// }
// }
// }
// }
"profiles": []
}
}

View File

@@ -391,7 +391,7 @@ impl ToolCallContent {
cx: &mut App,
) -> Self {
match content {
acp::ToolCallContent::Content { content } => Self::ContentBlock {
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
content: ContentBlock::new(content, &language_registry, cx),
},
acp::ToolCallContent::Diff { diff } => Self::Diff {
@@ -580,9 +580,6 @@ pub struct AcpThread {
pub enum AcpThreadEvent {
NewEntry,
EntryUpdated(usize),
ToolAuthorizationRequired,
Stopped,
Error,
}
impl EventEmitter<AcpThreadEvent> for AcpThread {}
@@ -619,7 +616,6 @@ impl Error for LoadError {}
impl AcpThread {
pub fn new(
title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
session_id: acp::SessionId,
@@ -632,7 +628,7 @@ impl AcpThread {
shared_buffers: Default::default(),
entries: Default::default(),
plan: Default::default(),
title: title.into(),
title: connection.name().into(),
project,
send_task: None,
connection,
@@ -680,32 +676,20 @@ impl AcpThread {
false
}
pub fn used_tools_since_last_user_message(&self) -> bool {
for entry in self.entries.iter().rev() {
match entry {
AgentThreadEntry::UserMessage(..) => return false,
AgentThreadEntry::AssistantMessage(..) => continue,
AgentThreadEntry::ToolCall(..) => return true,
}
}
false
}
pub fn handle_session_update(
&mut self,
update: acp::SessionUpdate,
cx: &mut Context<Self>,
) -> Result<()> {
match update {
acp::SessionUpdate::UserMessageChunk { content } => {
self.push_user_content_block(content, cx);
acp::SessionUpdate::UserMessage(content_block) => {
self.push_user_content_block(content_block, cx);
}
acp::SessionUpdate::AgentMessageChunk { content } => {
self.push_assistant_content_block(content, false, cx);
acp::SessionUpdate::AgentMessageChunk(content_block) => {
self.push_assistant_content_block(content_block, false, cx);
}
acp::SessionUpdate::AgentThoughtChunk { content } => {
self.push_assistant_content_block(content, true, cx);
acp::SessionUpdate::AgentThoughtChunk(content_block) => {
self.push_assistant_content_block(content_block, true, cx);
}
acp::SessionUpdate::ToolCall(tool_call) => {
self.upsert_tool_call(tool_call, cx);
@@ -895,7 +879,6 @@ impl AcpThread {
};
self.upsert_tool_call_inner(tool_call, status, cx);
cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
rx
}
@@ -974,6 +957,10 @@ impl AcpThread {
cx.notify();
}
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
self.connection.authenticate(cx)
}
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
&mut self,
@@ -1015,7 +1002,7 @@ impl AcpThread {
let result = this
.update(cx, |this, cx| {
this.connection.prompt(
acp::PromptRequest {
acp::PromptArguments {
prompt: message,
session_id: this.session_id.clone(),
},
@@ -1031,18 +1018,12 @@ impl AcpThread {
.log_err();
}));
cx.spawn(async move |this, cx| match rx.await {
Ok(Err(e)) => {
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
.log_err();
Err(e)?
async move {
match rx.await {
Ok(Err(e)) => Err(e)?,
_ => Ok(()),
}
_ => {
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
.log_err();
Ok(())
}
})
}
.boxed()
}
@@ -1616,16 +1597,9 @@ mod tests {
name: "test",
connection,
child_status: io_task,
current_thread: thread_rc,
auth_methods: [acp::AuthMethod {
id: acp::AuthMethodId("acp-old-no-id".into()),
label: "Log in".into(),
description: None,
}],
};
AcpThread::new(
"test",
Rc::new(connection),
project,
acp::SessionId("test".into()),

View File

@@ -1,6 +1,6 @@
use std::{error::Error, fmt, path::Path, rc::Rc};
use std::{path::Path, rc::Rc};
use agent_client_protocol::{self as acp};
use agent_client_protocol as acp;
use anyhow::Result;
use gpui::{AsyncApp, Entity, Task};
use project::Project;
@@ -9,6 +9,8 @@ use ui::App;
use crate::AcpThread;
pub trait AgentConnection {
fn name(&self) -> &'static str;
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
@@ -16,21 +18,9 @@ pub trait AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
fn auth_methods(&self) -> &[acp::AuthMethod];
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
}
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
}
}

View File

@@ -5,11 +5,10 @@ use anyhow::{Context as _, Result};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use std::{cell::RefCell, path::Path, rc::Rc};
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
use crate::{AcpThread, AgentConnection, AuthRequired};
use crate::{AcpThread, AgentConnection};
#[derive(Clone)]
pub struct OldAcpClientDelegate {
@@ -47,7 +46,7 @@ impl acp_old::Client for OldAcpClientDelegate {
thread.push_assistant_content_block(thought.into(), true, cx)
}
})
.log_err();
.ok();
})?;
Ok(())
@@ -351,15 +350,27 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
}
}
#[derive(Debug)]
pub struct Unauthenticated;
impl Error for Unauthenticated {}
impl fmt::Display for Unauthenticated {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Unauthenticated")
}
}
pub struct OldAcpAgentConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
pub child_status: Task<Result<()>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
pub auth_methods: [acp::AuthMethod; 1],
}
impl AgentConnection for OldAcpAgentConnection {
fn name(&self) -> &'static str {
self.name
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
@@ -372,31 +383,25 @@ impl AgentConnection for OldAcpAgentConnection {
}
.into_any(),
);
let current_thread = self.current_thread.clone();
cx.spawn(async move |cx| {
let result = task.await?;
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
anyhow::bail!(AuthRequired)
anyhow::bail!(Unauthenticated)
}
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
AcpThread::new("Gemini", self.clone(), project, session_id, cx)
AcpThread::new(self.clone(), project, session_id, cx)
});
current_thread.replace(thread.downgrade());
thread
})
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
let task = self
.connection
.request_any(acp_old::AuthenticateParams.into_any());
@@ -406,7 +411,7 @@ impl AgentConnection for OldAcpAgentConnection {
})
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
let chunks = params
.prompt
.into_iter()

View File

@@ -25,7 +25,6 @@ assistant_context.workspace = true
assistant_tool.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
context_server.workspace = true
@@ -36,9 +35,9 @@ futures.workspace = true
git.workspace = true
gpui.workspace = true
heed.workspace = true
http_client.workspace = true
icons.workspace = true
indoc.workspace = true
http_client.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
@@ -47,6 +46,7 @@ paths.workspace = true
postage.workspace = true
project.workspace = true
prompt_store.workspace = true
proto.workspace = true
ref-cast.workspace = true
rope.workspace = true
schemars.workspace = true
@@ -63,6 +63,7 @@ time.workspace = true
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
zstd.workspace = true
[dev-dependencies]

View File

@@ -13,7 +13,6 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
use futures::{FutureExt, StreamExt as _, future::Shared};
@@ -37,6 +36,7 @@ use project::{
git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
};
use prompt_store::{ModelContext, PromptBuilder};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
@@ -49,6 +49,7 @@ use std::{
use thiserror::Error;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
const MAX_RETRY_ATTEMPTS: u8 = 4;
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
@@ -1680,7 +1681,7 @@ impl Thread {
let completion_mode = request
.mode
.unwrap_or(cloud_llm_client::CompletionMode::Normal);
.unwrap_or(zed_llm_client::CompletionMode::Normal);
self.last_received_chunk_at = Some(Instant::now());
@@ -3254,10 +3255,8 @@ impl Thread {
}
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
self.project
.read(cx)
.user_store()
.update(cx, |user_store, cx| {
self.project.update(cx, |project, cx| {
project.user_store().update(cx, |user_store, cx| {
user_store.update_model_request_usage(
ModelRequestUsage(RequestUsage {
amount: amount as i32,
@@ -3265,7 +3264,8 @@ impl Thread {
}),
cx,
)
});
})
});
}
pub fn deny_tool_use(

View File

@@ -1,245 +0,0 @@
use agent_client_protocol::{self as acp, Agent as _};
use collections::HashMap;
use futures::channel::oneshot;
use project::Project;
use std::cell::RefCell;
use std::path::Path;
use std::rc::Rc;
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::AgentServerCommand;
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
pub struct AcpConnection {
server_name: &'static str,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
_io_task: Task<Result<()>>,
}
pub struct AcpSession {
thread: WeakEntity<AcpThread>,
}
impl AcpConnection {
pub async fn stdio(
server_name: &'static str,
command: AgentServerCommand,
root_dir: &Path,
cx: &mut AsyncApp,
) -> Result<Self> {
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter().map(|arg| arg.as_str()))
.envs(command.env.iter().flatten())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let stdout = child.stdout.take().expect("Failed to take stdout");
let stdin = child.stdin.take().expect("Failed to take stdin");
let sessions = Rc::new(RefCell::new(HashMap::default()));
let client = ClientDelegate {
sessions: sessions.clone(),
cx: cx.clone(),
};
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
let foreground_executor = cx.foreground_executor().clone();
move |fut| {
foreground_executor.spawn(fut).detach();
}
});
let io_task = cx.background_spawn(io_task);
let response = connection
.initialize(acp::InitializeRequest {
protocol_version: acp::VERSION,
client_capabilities: acp::ClientCapabilities {
fs: acp::FileSystemCapability {
read_text_file: true,
write_text_file: true,
},
},
})
.await?;
// todo! check version
Ok(Self {
auth_methods: response.auth_methods,
connection: connection.into(),
server_name,
sessions,
_io_task: io_task,
})
}
}
impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let response = conn
.new_session(acp::NewSessionRequest {
// todo! Zed MCP server?
mcp_servers: vec![],
cwd,
})
.await?;
let Some(session_id) = response.session_id else {
anyhow::bail!(AuthRequired);
};
let thread = cx.new(|cx| {
AcpThread::new(
self.server_name,
self.clone(),
project,
session_id.clone(),
cx,
)
})?;
let session = AcpSession {
thread: thread.downgrade(),
};
sessions.borrow_mut().insert(session_id, session);
Ok(thread)
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&self.auth_methods
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
cx.foreground_executor().spawn(async move {
let result = conn
.authenticate(acp::AuthenticateRequest {
method_id: method_id.clone(),
})
.await?;
Ok(result)
})
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let conn = self.connection.clone();
cx.foreground_executor()
.spawn(async move { Ok(conn.prompt(params).await?) })
}
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
let conn = self.connection.clone();
let params = acp::CancelledNotification {
session_id: session_id.clone(),
};
cx.foreground_executor()
.spawn(async move { conn.cancelled(params).await })
.detach();
}
}
struct ClientDelegate {
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
cx: AsyncApp,
}
impl acp::Client for ClientDelegate {
async fn request_permission(
&self,
arguments: acp::RequestPermissionRequest,
) -> Result<acp::RequestPermissionResponse, acp::Error> {
let cx = &mut self.cx.clone();
let result = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
})?
.await;
let outcome = match result {
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
};
Ok(acp::RequestPermissionResponse { outcome })
}
async fn write_text_file(
&self,
arguments: acp::WriteTextFileRequest,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
self.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.write_text_file(arguments.path, arguments.content, cx)
})?
.await?;
Ok(())
}
async fn read_text_file(
&self,
arguments: acp::ReadTextFileRequest,
) -> Result<acp::ReadTextFileResponse, acp::Error> {
let cx = &mut self.cx.clone();
let content = self
.sessions
.borrow()
.get(&arguments.session_id)
.context("Failed to get session")?
.thread
.update(cx, |thread, cx| {
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
})?
.await?;
Ok(acp::ReadTextFileResponse { content })
}
async fn session_notification(
&self,
notification: acp::SessionNotification,
) -> Result<(), acp::Error> {
let cx = &mut self.cx.clone();
let sessions = self.sessions.borrow();
let session = sessions
.get(&notification.session_id)
.context("Failed to get session")?;
session.thread.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})??;
Ok(())
}
}

View File

@@ -1,12 +1,14 @@
mod acp_connection;
mod claude;
mod codex;
mod gemini;
mod mcp_server;
mod settings;
#[cfg(test)]
mod e2e_tests;
pub use claude::*;
pub use codex::*;
pub use gemini::*;
pub use settings::*;
@@ -36,6 +38,7 @@ pub trait AgentServer: Send {
fn connect(
&self,
// these will go away when old_acp is fully removed
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,

View File

@@ -70,6 +70,10 @@ struct ClaudeAgentConnection {
}
impl AgentConnection for ClaudeAgentConnection {
fn name(&self) -> &'static str {
ClaudeCode.name()
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
@@ -164,9 +168,8 @@ impl AgentConnection for ClaudeAgentConnection {
}
});
let thread = cx.new(|cx| {
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
})?;
let thread =
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
thread_tx.send(thread.downgrade())?;
@@ -183,15 +186,11 @@ impl AgentConnection for ClaudeAgentConnection {
})
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[]
}
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported")))
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
let sessions = self.sessions.borrow();
let Some(session) = sessions.get(&params.session_id) else {
return Task::ready(Err(anyhow!(

View File

@@ -0,0 +1,317 @@
use agent_client_protocol as acp;
use anyhow::anyhow;
use collections::HashMap;
use context_server::listener::McpServerTool;
use context_server::types::requests;
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use project::Project;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::cell::RefCell;
use std::rc::Rc;
use std::{path::Path, sync::Arc};
use util::ResultExt;
use anyhow::{Context, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use crate::mcp_server::ZedMcpServer;
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
use acp_thread::{AcpThread, AgentConnection};
#[derive(Clone)]
pub struct Codex;
impl AgentServer for Codex {
fn name(&self) -> &'static str {
"Codex"
}
fn empty_state_headline(&self) -> &'static str {
"Welcome to Codex"
}
fn empty_state_message(&self) -> &'static str {
"What can I help with?"
}
fn logo(&self) -> ui::IconName {
ui::IconName::AiOpenAi
}
fn connect(
&self,
_root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let project = project.clone();
cx.spawn(async move |cx| {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).codex.clone()
})?;
let Some(command) =
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
else {
anyhow::bail!("Failed to find codex binary");
};
let client: Arc<ContextServer> = ContextServer::stdio(
ContextServerId("codex-mcp-server".into()),
ContextServerCommand {
path: command.path,
args: command.args,
env: command.env,
},
)
.into();
ContextServer::start(client.clone(), cx).await?;
let (notification_tx, mut notification_rx) = mpsc::unbounded();
client
.client()
.context("Failed to subscribe")?
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
move |notification, _cx| {
let notification_tx = notification_tx.clone();
log::trace!(
"ACP Notification: {}",
serde_json::to_string_pretty(&notification).unwrap()
);
if let Some(notification) =
serde_json::from_value::<acp::SessionNotification>(notification)
.log_err()
{
notification_tx.unbounded_send(notification).ok();
}
}
});
let sessions = Rc::new(RefCell::new(HashMap::default()));
let notification_handler_task = cx.spawn({
let sessions = sessions.clone();
async move |cx| {
while let Some(notification) = notification_rx.next().await {
CodexConnection::handle_session_notification(
notification,
sessions.clone(),
cx,
)
}
}
});
let connection = CodexConnection {
client,
sessions,
_notification_handler_task: notification_handler_task,
};
Ok(Rc::new(connection) as _)
})
}
}
struct CodexConnection {
client: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
_notification_handler_task: Task<()>,
}
struct CodexSession {
thread: WeakEntity<AcpThread>,
cancel_tx: Option<oneshot::Sender<()>>,
_mcp_server: ZedMcpServer,
}
impl AgentConnection for CodexConnection {
fn name(&self) -> &'static str {
"Codex"
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let client = self.client.client();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
let client = client.context("MCP server is not initialized yet")?;
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
let response = client
.request::<requests::CallTool>(context_server::types::CallToolParams {
name: acp::NEW_SESSION_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
mcp_servers: [(
mcp_server::SERVER_NAME.to_string(),
mcp_server.server_config()?,
)]
.into(),
client_tools: acp::ClientTools {
request_permission: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
}),
read_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
}),
write_text_file: Some(acp::McpToolId {
mcp_server: mcp_server::SERVER_NAME.into(),
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
}),
},
cwd,
})?),
meta: None,
})
.await?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
let result = serde_json::from_value::<acp::NewSessionOutput>(
response.structured_content.context("Empty response")?,
)?;
let thread =
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
thread_tx.send(thread.downgrade())?;
let session = CodexSession {
thread: thread.downgrade(),
cancel_tx: None,
_mcp_server: mcp_server,
};
sessions.borrow_mut().insert(result.session_id, session);
Ok(thread)
})
}
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported")))
}
fn prompt(
&self,
params: agent_client_protocol::PromptArguments,
cx: &mut App,
) -> Task<Result<()>> {
let client = self.client.client();
let sessions = self.sessions.clone();
cx.foreground_executor().spawn(async move {
let client = client.context("MCP server is not initialized yet")?;
let (new_cancel_tx, cancel_rx) = oneshot::channel();
{
let mut sessions = sessions.borrow_mut();
let session = sessions
.get_mut(&params.session_id)
.context("Session not found")?;
session.cancel_tx.replace(new_cancel_tx);
}
let result = client
.request_with::<requests::CallTool>(
context_server::types::CallToolParams {
name: acp::PROMPT_TOOL_NAME.into(),
arguments: Some(serde_json::to_value(params)?),
meta: None,
},
Some(cancel_rx),
None,
)
.await;
if let Err(err) = &result
&& err.is::<context_server::client::RequestCanceled>()
{
return Ok(());
}
let response = result?;
if response.is_error.unwrap_or_default() {
return Err(anyhow!(response.text_contents()));
}
Ok(())
})
}
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
let mut sessions = self.sessions.borrow_mut();
if let Some(cancel_tx) = sessions
.get_mut(session_id)
.and_then(|session| session.cancel_tx.take())
{
cancel_tx.send(()).ok();
}
}
}
impl CodexConnection {
pub fn handle_session_notification(
notification: acp::SessionNotification,
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
cx: &mut AsyncApp,
) {
let threads = threads.borrow();
let Some(thread) = threads
.get(&notification.session_id)
.and_then(|session| session.thread.upgrade())
else {
log::error!(
"Thread not found for session ID: {}",
notification.session_id
);
return;
};
thread
.update(cx, |thread, cx| {
thread.handle_session_update(notification.update, cx)
})
.log_err();
}
}
impl Drop for CodexConnection {
fn drop(&mut self) {
self.client.stop().log_err();
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::AgentServerCommand;
use std::path::Path;
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
pub fn local_command() -> AgentServerCommand {
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../../../codex/codex-rs/target/debug/codex");
AgentServerCommand {
path: cli_path,
args: vec!["mcp".into()],
env: None,
}
}
}

View File

@@ -12,6 +12,7 @@ use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::{Entity, TestAppContext};
use indoc::indoc;
use project::{FakeFs, Project};
use serde_json::json;
use settings::{Settings, SettingsStore};
use util::path;
@@ -26,11 +27,7 @@ pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppCont
.unwrap();
thread.read_with(cx, |thread, _| {
assert!(
thread.entries().len() >= 2,
"Expected at least 2 entries. Got: {:?}",
thread.entries()
);
assert_eq!(thread.entries().len(), 2);
assert!(matches!(
thread.entries()[0],
AgentThreadEntry::UserMessage(_)
@@ -111,19 +108,19 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes
}
pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
let _fs = init_test(cx).await;
let tempdir = tempfile::tempdir().unwrap();
let foo_path = tempdir.path().join("foo");
std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
let fs = init_test(cx).await;
fs.insert_tree(
path!("/private/tmp"),
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
)
.await;
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
thread
.update(cx, |thread, cx| {
thread.send_raw(
&format!("Read {} and tell me what you see.", foo_path.display()),
"Read the '/private/tmp/foo' file and tell me what you see.",
cx,
)
})
@@ -146,8 +143,6 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp
.any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
);
});
drop(tempdir);
}
pub async fn test_tool_call_with_confirmation(
@@ -160,7 +155,7 @@ pub async fn test_tool_call_with_confirmation(
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| {
thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
cx,
)
});
@@ -180,10 +175,10 @@ pub async fn test_tool_call_with_confirmation(
)
.await;
let tool_call_id = thread.read_with(cx, |thread, cx| {
let tool_call_id = thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall {
id,
label,
content,
status: ToolCallStatus::WaitingForConfirmation { .. },
..
}) = &thread
@@ -195,8 +190,7 @@ pub async fn test_tool_call_with_confirmation(
panic!();
};
let label = label.read(cx).source();
assert!(label.contains("touch"), "Got: {}", label);
assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
id.clone()
});
@@ -248,7 +242,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| {
thread.send_raw(
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
cx,
)
});
@@ -268,10 +262,10 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
)
.await;
thread.read_with(cx, |thread, cx| {
thread.read_with(cx, |thread, _cx| {
let AgentThreadEntry::ToolCall(ToolCall {
id,
label,
content,
status: ToolCallStatus::WaitingForConfirmation { .. },
..
}) = &thread.entries()[first_tool_call_ix]
@@ -279,8 +273,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
panic!("{:?}", thread.entries()[1]);
};
let label = label.read(cx).source();
assert!(label.contains("touch"), "Got: {}", label);
assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
id.clone()
});
@@ -375,6 +368,9 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
gemini: Some(AgentServerSettings {
command: crate::gemini::tests::local_command(),
}),
codex: Some(AgentServerSettings {
command: crate::codex::tests::local_command(),
}),
},
cx,
);

View File

@@ -1,10 +1,14 @@
use anyhow::anyhow;
use std::cell::RefCell;
use std::path::Path;
use std::rc::Rc;
use util::ResultExt as _;
use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection};
use acp_thread::AgentConnection;
use anyhow::Result;
use gpui::{Entity, Task};
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
use agentic_coding_protocol as acp_old;
use anyhow::{Context as _, Result};
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use settings::SettingsStore;
use ui::App;
@@ -39,27 +43,145 @@ impl AgentServer for Gemini {
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
let project = project.clone();
let server_name = self.name();
let root_dir = root_dir.to_path_buf();
let project = project.clone();
let this = self.clone();
let name = self.name();
cx.spawn(async move |cx| {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
let command = this.command(&project, cx).await?;
let Some(command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
else {
anyhow::bail!("Failed to find gemini binary");
};
// todo! check supported version
let mut child = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.current_dir(root_dir)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit())
.kill_on_drop(true)
.spawn()?;
let conn = AcpConnection::stdio(server_name, command, &root_dir, cx).await?;
Ok(Rc::new(conn) as _)
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let foreground_executor = cx.foreground_executor().clone();
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
stdin,
stdout,
move |fut| foreground_executor.spawn(fut).detach(),
);
let io_task = cx.background_spawn(async move {
io_fut.await.log_err();
});
let child_status = cx.background_spawn(async move {
let result = match child.status().await {
Err(e) => Err(anyhow!(e)),
Ok(result) if result.success() => Ok(()),
Ok(result) => {
if let Some(AgentServerVersion::Unsupported {
error_message,
upgrade_message,
upgrade_command,
}) = this.version(&command).await.log_err()
{
Err(anyhow!(LoadError::Unsupported {
error_message,
upgrade_message,
upgrade_command
}))
} else {
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
}
}
};
drop(io_task);
result
});
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
name,
connection,
child_status,
});
Ok(connection)
})
}
}
impl Gemini {
async fn command(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<AgentServerCommand> {
let settings = cx.read_global(|settings: &SettingsStore, _| {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
if let Some(command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
{
return Ok(command);
};
let (fs, node_runtime) = project.update(cx, |project, _| {
(project.fs().clone(), project.node_runtime().cloned())
})?;
let node_runtime = node_runtime.context("gemini not found on path")?;
let directory = ::paths::agent_servers_dir().join("gemini");
fs.create_dir(&directory).await?;
node_runtime
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
.await?;
let path = directory.join("node_modules/.bin/gemini");
Ok(AgentServerCommand {
path,
args: vec![ACP_ARG.into()],
env: None,
})
}
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
let version_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--version")
.kill_on_drop(true)
.output();
let help_fut = util::command::new_smol_command(&command.path)
.args(command.args.iter())
.arg("--help")
.kill_on_drop(true)
.output();
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
let current_version = String::from_utf8(version_output?.stdout)?;
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
if supported {
Ok(AgentServerVersion::Supported)
} else {
Ok(AgentServerVersion::Unsupported {
error_message: format!(
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
current_version
).into(),
upgrade_message: "Upgrade Gemini to Latest".into(),
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
})
}
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
@@ -76,7 +198,7 @@ pub(crate) mod tests {
AgentServerCommand {
path: "node".into(),
args: vec![cli_path],
args: vec![cli_path, ACP_ARG.into()],
env: None,
}
}

View File

@@ -0,0 +1,207 @@
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use anyhow::Result;
use context_server::listener::{McpServerTool, ToolResponse};
use context_server::types::{
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
ToolsCapabilities, requests,
};
use futures::channel::oneshot;
use gpui::{App, AsyncApp, Task, WeakEntity};
use indoc::indoc;
pub struct ZedMcpServer {
server: context_server::listener::McpServer,
}
pub const SERVER_NAME: &str = "zed";
impl ZedMcpServer {
pub async fn new(
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
cx: &AsyncApp,
) -> Result<Self> {
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
mcp_server.add_tool(RequestPermissionTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(ReadTextFileTool {
thread_rx: thread_rx.clone(),
});
mcp_server.add_tool(WriteTextFileTool {
thread_rx: thread_rx.clone(),
});
Ok(Self { server: mcp_server })
}
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
#[cfg(not(test))]
let zed_path = anyhow::Context::context(
std::env::current_exe(),
"finding current executable path for use in mcp_server",
)?;
#[cfg(test)]
let zed_path = crate::e2e_tests::get_zed_path();
Ok(acp::McpServerConfig {
command: zed_path,
args: vec![
"--nc".into(),
self.server.socket_path().display().to_string(),
],
env: None,
})
}
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
cx.foreground_executor().spawn(async move {
Ok(InitializeResponse {
protocol_version: ProtocolVersion("2025-06-18".into()),
capabilities: ServerCapabilities {
experimental: None,
logging: None,
completions: None,
prompts: None,
resources: None,
tools: Some(ToolsCapabilities {
list_changed: Some(false),
}),
},
server_info: Implementation {
name: SERVER_NAME.into(),
version: "0.1.0".into(),
},
meta: None,
})
})
}
}
// Tools
#[derive(Clone)]
pub struct RequestPermissionTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for RequestPermissionTool {
type Input = acp::RequestPermissionArguments;
type Output = acp::RequestPermissionOutput;
const NAME: &'static str = "Confirmation";
fn description(&self) -> &'static str {
indoc! {"
Request permission for tool calls.
This tool is meant to be called programmatically by the agent loop, not the LLM.
"}
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let result = thread
.update(cx, |thread, cx| {
thread.request_tool_call_permission(input.tool_call, input.options, cx)
})?
.await;
let outcome = match result {
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
};
Ok(ToolResponse {
content: vec![],
structured_content: acp::RequestPermissionOutput { outcome },
})
}
}
#[derive(Clone)]
pub struct ReadTextFileTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for ReadTextFileTool {
type Input = acp::ReadTextFileArguments;
type Output = acp::ReadTextFileOutput;
const NAME: &'static str = "Read";
fn description(&self) -> &'static str {
"Reads the content of the given file in the project including unsaved changes."
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
let content = thread
.update(cx, |thread, cx| {
thread.read_text_file(input.path, input.line, input.limit, false, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: acp::ReadTextFileOutput { content },
})
}
}
#[derive(Clone)]
pub struct WriteTextFileTool {
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
}
impl McpServerTool for WriteTextFileTool {
type Input = acp::WriteTextFileArguments;
type Output = ();
const NAME: &'static str = "Write";
fn description(&self) -> &'static str {
"Write to a file replacing its contents"
}
async fn run(
&self,
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
};
thread
.update(cx, |thread, cx| {
thread.write_text_file(input.path, input.content, cx)
})?
.await?;
Ok(ToolResponse {
content: vec![],
structured_content: (),
})
}
}

View File

@@ -13,6 +13,7 @@ pub fn init(cx: &mut App) {
pub struct AllAgentServersSettings {
pub gemini: Option<AgentServerSettings>,
pub claude: Option<AgentServerSettings>,
pub codex: Option<AgentServerSettings>,
}
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
@@ -29,13 +30,21 @@ impl settings::Settings for AllAgentServersSettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let mut settings = AllAgentServersSettings::default();
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
for AllAgentServersSettings {
gemini,
claude,
codex,
} in sources.defaults_and_customizations()
{
if gemini.is_some() {
settings.gemini = gemini.clone();
}
if claude.is_some() {
settings.claude = claude.clone();
}
if codex.is_some() {
settings.codex = codex.clone();
}
}
Ok(settings)

View File

@@ -13,7 +13,6 @@ path = "src/agent_settings.rs"
[dependencies]
anyhow.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
gpui.workspace = true
language_model.workspace = true
@@ -21,6 +20,7 @@ schemars.workspace = true
serde.workspace = true
settings.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
fs.workspace = true

View File

@@ -321,11 +321,11 @@ pub enum CompletionMode {
Burn,
}
impl From<CompletionMode> for cloud_llm_client::CompletionMode {
impl From<CompletionMode> for zed_llm_client::CompletionMode {
fn from(value: CompletionMode) -> Self {
match value {
CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal,
CompletionMode::Burn => cloud_llm_client::CompletionMode::Max,
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
}
}
}

View File

@@ -31,7 +31,6 @@ audio.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
@@ -47,9 +46,9 @@ futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
indoc.workspace = true
http_client.workspace = true
indexed_docs.workspace = true
indoc.workspace = true
inventory.workspace = true
itertools.workspace = true
jsonschema.workspace = true
@@ -98,6 +97,7 @@ watch.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
assistant_tools.workspace = true

View File

@@ -1,7 +1,5 @@
use acp_thread::{AgentConnection, Plan};
use agent_servers::AgentServer;
use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
use audio::{Audio, Sound};
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::path::Path;
@@ -20,10 +18,10 @@ use editor::{
use file_icons::FileIcons;
use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
list, percentage, point, prelude::*, pulsating_between,
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement,
Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*,
pulsating_between,
};
use language::language_settings::SoftWrap;
use language::{Buffer, Language};
@@ -47,10 +45,7 @@ use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSe
use crate::acp::message_history::MessageHistory;
use crate::agent_diff::AgentDiff;
use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES};
use crate::ui::{AgentNotification, AgentNotificationEvent};
use crate::{
AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll,
};
use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll};
const RESPONSE_PADDING_X: Pixels = px(19.);
@@ -64,8 +59,6 @@ pub struct AcpThreadView {
message_set_from_history: bool,
_message_editor_subscription: Subscription,
mention_set: Arc<Mutex<MentionSet>>,
notifications: Vec<WindowHandle<AgentNotification>>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
last_error: Option<Entity<Markdown>>,
list_state: ListState,
auth_task: Option<Task<()>>,
@@ -181,8 +174,6 @@ impl AcpThreadView {
message_set_from_history: false,
_message_editor_subscription: message_editor_subscription,
mention_set,
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
diff_editors: Default::default(),
list_state: list_state,
last_error: None,
@@ -232,8 +223,7 @@ impl AcpThreadView {
{
Err(e) => {
let mut cx = cx.clone();
// todo! remove duplication
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify();
@@ -391,9 +381,7 @@ impl AcpThreadView {
return;
}
let Some(thread) = self.thread() else {
return;
};
let Some(thread) = self.thread() else { return };
let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx));
cx.spawn(async move |this, cx| {
@@ -576,30 +564,6 @@ impl AcpThreadView {
self.sync_thread_entry_view(index, window, cx);
self.list_state.splice(index..index + 1, 1);
}
AcpThreadEvent::ToolAuthorizationRequired => {
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
}
AcpThreadEvent::Stopped => {
let used_tools = thread.read(cx).used_tools_since_last_user_message();
self.notify_with_sound(
if used_tools {
"Finished running tools"
} else {
"New message"
},
IconName::ZedAssistant,
window,
cx,
);
}
AcpThreadEvent::Error => {
self.notify_with_sound(
"Agent stopped due to an error",
IconName::Warning,
window,
cx,
);
}
}
cx.notify();
}
@@ -676,18 +640,13 @@ impl AcpThreadView {
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
}
fn authenticate(
&mut self,
method: acp::AuthMethodId,
window: &mut Window,
cx: &mut Context<Self>,
) {
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
return;
};
self.last_error.take();
let authenticate = connection.authenticate(method, cx);
let authenticate = connection.authenticate(cx);
self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone();
let agent = self.agent.clone();
@@ -2201,154 +2160,6 @@ impl AcpThreadView {
self.list_state.scroll_to(ListOffset::default());
cx.notify();
}
fn notify_with_sound(
&mut self,
caption: impl Into<SharedString>,
icon: IconName,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.play_notification_sound(window, cx);
self.show_notification(caption, icon, window, cx);
}
fn play_notification_sound(&self, window: &Window, cx: &mut App) {
let settings = AgentSettings::get_global(cx);
if settings.play_sound_when_agent_done && !window.is_window_active() {
Audio::play_sound(Sound::AgentDone, cx);
}
}
fn show_notification(
&mut self,
caption: impl Into<SharedString>,
icon: IconName,
window: &mut Window,
cx: &mut Context<Self>,
) {
if window.is_window_active() || !self.notifications.is_empty() {
return;
}
let title = self.title(cx);
match AgentSettings::get_global(cx).notify_when_agent_waiting {
NotifyWhenAgentWaiting::PrimaryScreen => {
if let Some(primary) = cx.primary_display() {
self.pop_up(icon, caption.into(), title, window, primary, cx);
}
}
NotifyWhenAgentWaiting::AllScreens => {
let caption = caption.into();
for screen in cx.displays() {
self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx);
}
}
NotifyWhenAgentWaiting::Never => {
// Don't show anything
}
}
}
fn pop_up(
&mut self,
icon: IconName,
caption: SharedString,
title: SharedString,
window: &mut Window,
screen: Rc<dyn PlatformDisplay>,
cx: &mut Context<Self>,
) {
let options = AgentNotification::window_options(screen, cx);
let project_name = self.workspace.upgrade().and_then(|workspace| {
workspace
.read(cx)
.project()
.read(cx)
.visible_worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).root_name().to_string())
});
if let Some(screen_window) = cx
.open_window(options, |_, cx| {
cx.new(|_| {
AgentNotification::new(title.clone(), caption.clone(), icon, project_name)
})
})
.log_err()
{
if let Some(pop_up) = screen_window.entity(cx).log_err() {
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push(cx.subscribe_in(&pop_up, window, {
|this, _, event, window, cx| match event {
AgentNotificationEvent::Accepted => {
let handle = window.window_handle();
cx.activate(true);
let workspace_handle = this.workspace.clone();
// If there are multiple Zed windows, activate the correct one.
cx.defer(move |cx| {
handle
.update(cx, |_view, window, _cx| {
window.activate_window();
if let Some(workspace) = workspace_handle.upgrade() {
workspace.update(_cx, |workspace, cx| {
workspace.focus_panel::<AgentPanel>(window, cx);
});
}
})
.log_err();
});
this.dismiss_notifications(cx);
}
AgentNotificationEvent::Dismissed => {
this.dismiss_notifications(cx);
}
}
}));
self.notifications.push(screen_window);
// If the user manually refocuses the original window, dismiss the popup.
self.notification_subscriptions
.entry(screen_window)
.or_insert_with(Vec::new)
.push({
let pop_up_weak = pop_up.downgrade();
cx.observe_window_activation(window, move |_, window, cx| {
if window.is_window_active() {
if let Some(pop_up) = pop_up_weak.upgrade() {
pop_up.update(cx, |_, cx| {
cx.emit(AgentNotificationEvent::Dismissed);
});
}
}
})
});
}
}
}
fn dismiss_notifications(&mut self, cx: &mut Context<Self>) {
for window in self.notifications.drain(..) {
window
.update(cx, |_, window, _| {
window.remove_window();
})
.ok();
self.notification_subscriptions.remove(&window);
}
}
}
impl Focusable for AcpThreadView {
@@ -2386,26 +2197,22 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::next_history_message))
.on_action(cx.listener(Self::open_agent_diff))
.child(match &self.thread_state {
ThreadState::Unauthenticated { connection } => v_flex()
.p_2()
.flex_1()
.items_center()
.justify_center()
.child(self.render_pending_auth_state())
.child(h_flex().mt_1p5().justify_center().children(
connection.auth_methods().into_iter().map(|method| {
Button::new(
SharedString::from(method.id.0.clone()),
method.label.clone(),
)
.on_click({
let method_id = method.id.clone();
cx.listener(move |this, _, window, cx| {
this.authenticate(method_id.clone(), window, cx)
})
})
}),
)),
ThreadState::Unauthenticated { .. } => {
v_flex()
.p_2()
.flex_1()
.items_center()
.justify_center()
.child(self.render_pending_auth_state())
.child(
h_flex().mt_1p5().justify_center().child(
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
.on_click(cx.listener(|this, _, window, cx| {
this.authenticate(window, cx)
})),
),
)
}
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
ThreadState::LoadError(e) => v_flex()
.p_2()
@@ -2634,341 +2441,3 @@ fn plan_label_markdown_style(
..default_md_style
}
}
#[cfg(test)]
mod tests {
use agent_client_protocol::SessionId;
use editor::EditorSettings;
use fs::FakeFs;
use futures::future::try_join_all;
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
use rand::Rng;
use settings::SettingsStore;
use super::*;
#[gpui::test]
async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
init_test(cx);
let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await;
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello", window, cx);
});
cx.deactivate_window();
thread_view.update_in(cx, |thread_view, window, cx| {
thread_view.chat(&Chat, window, cx);
});
cx.run_until_parked();
assert!(
cx.windows()
.iter()
.any(|window| window.downcast::<AgentNotification>().is_some())
);
}
#[gpui::test]
async fn test_notification_for_error(cx: &mut TestAppContext) {
init_test(cx);
let (thread_view, cx) =
setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await;
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello", window, cx);
});
cx.deactivate_window();
thread_view.update_in(cx, |thread_view, window, cx| {
thread_view.chat(&Chat, window, cx);
});
cx.run_until_parked();
assert!(
cx.windows()
.iter()
.any(|window| window.downcast::<AgentNotification>().is_some())
);
}
#[gpui::test]
async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
init_test(cx);
let tool_call_id = acp::ToolCallId("1".into());
let tool_call = acp::ToolCall {
id: tool_call_id.clone(),
label: "Label".into(),
kind: acp::ToolKind::Edit,
status: acp::ToolCallStatus::Pending,
content: vec!["hi".into()],
locations: vec![],
raw_input: None,
};
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
.with_permission_requests(HashMap::from_iter([(
tool_call_id,
vec![acp::PermissionOption {
id: acp::PermissionOptionId("1".into()),
label: "Allow".into(),
kind: acp::PermissionOptionKind::AllowOnce,
}],
)]));
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
message_editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello", window, cx);
});
cx.deactivate_window();
thread_view.update_in(cx, |thread_view, window, cx| {
thread_view.chat(&Chat, window, cx);
});
cx.run_until_parked();
assert!(
cx.windows()
.iter()
.any(|window| window.downcast::<AgentNotification>().is_some())
);
}
async fn setup_thread_view(
agent: impl AgentServer + 'static,
cx: &mut TestAppContext,
) -> (Entity<AcpThreadView>, &mut VisualTestContext) {
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_view = cx.update(|window, cx| {
cx.new(|cx| {
AcpThreadView::new(
Rc::new(agent),
workspace.downgrade(),
project,
Rc::new(RefCell::new(MessageHistory::default())),
1,
None,
window,
cx,
)
})
});
cx.run_until_parked();
(thread_view, cx)
}
struct StubAgentServer<C> {
connection: C,
}
impl<C> StubAgentServer<C> {
fn new(connection: C) -> Self {
Self { connection }
}
}
impl StubAgentServer<StubAgentConnection> {
fn default() -> Self {
Self::new(StubAgentConnection::default())
}
}
impl<C> AgentServer for StubAgentServer<C>
where
C: 'static + AgentConnection + Send + Clone,
{
fn logo(&self) -> ui::IconName {
unimplemented!()
}
fn name(&self) -> &'static str {
unimplemented!()
}
fn empty_state_headline(&self) -> &'static str {
unimplemented!()
}
fn empty_state_message(&self) -> &'static str {
unimplemented!()
}
fn connect(
&self,
_root_dir: &Path,
_project: &Entity<Project>,
_cx: &mut App,
) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
Task::ready(Ok(Rc::new(self.connection.clone())))
}
}
#[derive(Clone, Default)]
struct StubAgentConnection {
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
updates: Vec<acp::SessionUpdate>,
}
impl StubAgentConnection {
fn new(updates: Vec<acp::SessionUpdate>) -> Self {
Self {
updates,
permission_requests: HashMap::default(),
sessions: Arc::default(),
}
}
fn with_permission_requests(
mut self,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
) -> Self {
self.permission_requests = permission_requests;
self
}
}
impl AgentConnection for StubAgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::AsyncApp,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = SessionId(
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(7)
.map(char::from)
.collect::<String>()
.into(),
);
let thread = cx
.new(|cx| {
AcpThread::new("New Thread", self.clone(), project, session_id.clone(), cx)
})
.unwrap();
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
todo!()
}
fn authenticate(
&self,
_method: acp::AuthMethodId,
_cx: &mut App,
) -> Task<gpui::Result<()>> {
unimplemented!()
}
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap();
let mut tasks = vec![];
for update in &self.updates {
let thread = thread.clone();
let update = update.clone();
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
&& let Some(options) = self.permission_requests.get(&tool_call.id)
{
Some((tool_call.clone(), options.clone()))
} else {
None
};
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_permission(
tool_call.clone(),
options.clone(),
cx,
)
})?;
permission.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})?;
anyhow::Ok(())
});
tasks.push(task);
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(())
})
}
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!()
}
}
#[derive(Clone)]
struct SaboteurAgentConnection;
impl AgentConnection for SaboteurAgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::AsyncApp,
) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx
.new(|cx| AcpThread::new("New Thread", self, project, SessionId("test".into()), cx))
.unwrap()))
}
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
todo!()
}
fn authenticate(
&self,
_method: acp::AuthMethodId,
_cx: &mut App,
) -> Task<gpui::Result<()>> {
unimplemented!()
}
fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
Task::ready(Err(anyhow::anyhow!("Error prompting")))
}
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!()
}
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
AgentSettings::register(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
release_channel::init(SemanticVersion::default(), cx);
EditorSettings::register(cx);
});
}
}

View File

@@ -14,7 +14,6 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
use anyhow::Context as _;
use assistant_tool::ToolUseStatus;
use audio::{Audio, Sound};
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::scroll::Autoscroll;
@@ -53,6 +52,7 @@ use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
use zed_llm_client::CompletionIntent;
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1;

View File

@@ -7,7 +7,6 @@ use std::{sync::Arc, time::Duration};
use agent_settings::AgentSettings;
use assistant_tool::{ToolSource, ToolWorkingSet};
use cloud_llm_client::Plan;
use collections::HashMap;
use context_server::ContextServerId;
use extension::ExtensionManifest;
@@ -26,6 +25,7 @@ use project::{
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
project_settings::{ContextServerSettings, ProjectSettings},
};
use proto::Plan;
use settings::{Settings, update_settings_file};
use ui::{
Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
@@ -180,7 +180,7 @@ impl AgentConfiguration {
let current_plan = if is_zed_provider {
self.workspace
.upgrade()
.and_then(|workspace| workspace.read(cx).user_store().read(cx).plan())
.and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan())
} else {
None
};
@@ -406,9 +406,7 @@ impl AgentConfiguration {
SwitchField::new(
"always-allow-tool-actions-switch",
"Allow running commands without asking for confirmation",
Some(
"The agent can perform potentially destructive actions without asking for your confirmation.".into(),
),
"The agent can perform potentially destructive actions without asking for your confirmation.",
always_allow_tool_actions,
move |state, _window, cx| {
let allow = state == &ToggleState::Selected;
@@ -426,7 +424,7 @@ impl AgentConfiguration {
SwitchField::new(
"single-file-review",
"Enable single-file agent reviews",
Some("Agent edits are also displayed in single-file editors for review.".into()),
"Agent edits are also displayed in single-file editors for review.",
single_file_review,
move |state, _window, cx| {
let allow = state == &ToggleState::Selected;
@@ -444,9 +442,7 @@ impl AgentConfiguration {
SwitchField::new(
"sound-notification",
"Play sound when finished generating",
Some(
"Hear a notification sound when the agent is done generating changes or needs your input.".into(),
),
"Hear a notification sound when the agent is done generating changes or needs your input.",
play_sound_when_agent_done,
move |state, _window, cx| {
let allow = state == &ToggleState::Selected;
@@ -464,9 +460,7 @@ impl AgentConfiguration {
SwitchField::new(
"modifier-send",
"Use modifier to submit a message",
Some(
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(),
),
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.",
use_modifier_to_send,
move |state, _window, cx| {
let allow = state == &ToggleState::Selected;
@@ -508,7 +502,7 @@ impl AgentConfiguration {
.blend(cx.theme().colors().text_accent.opacity(0.2));
let (plan_name, label_color, bg_color) = match plan {
Plan::ZedFree => ("Free", Color::Default, free_chip_bg),
Plan::Free => ("Free", Color::Default, free_chip_bg),
Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg),
Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg),
};

View File

@@ -1521,9 +1521,6 @@ impl AgentDiff {
self.update_reviewing_editors(workspace, window, cx);
}
}
AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Error => {}
}
}

View File

@@ -44,7 +44,6 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet;
use client::{DisableAiSettings, UserStore, zed_urls};
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use feature_flags::{self, FeatureFlagAppExt};
use fs::Fs;
@@ -60,6 +59,7 @@ use language_model::{
};
use project::{Project, ProjectPath, Worktree};
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
use proto::Plan;
use rules_library::{RulesLibrary, open_rules_library};
use search::{BufferSearchBar, buffer_search};
use settings::{Settings, update_settings_file};
@@ -77,9 +77,10 @@ use workspace::{
};
use zed_actions::{
DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize,
agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector},
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector},
assistant::{OpenRulesLibrary, ToggleFocus},
};
use zed_llm_client::{CompletionIntent, UsageLimit};
const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -104,7 +105,7 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.open_history(window, cx));
}
})
.register_action(|workspace, _: &OpenSettings, window, cx| {
.register_action(|workspace, _: &OpenConfiguration, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
@@ -578,6 +579,7 @@ impl AgentPanel {
MessageEditor::new(
fs.clone(),
workspace.clone(),
user_store.clone(),
message_editor_context_store.clone(),
prompt_store.clone(),
thread_store.downgrade(),
@@ -846,6 +848,7 @@ impl AgentPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
self.user_store.clone(),
context_store.clone(),
self.prompt_store.clone(),
self.thread_store.downgrade(),
@@ -1119,6 +1122,7 @@ impl AgentPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
self.user_store.clone(),
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
@@ -1987,6 +1991,20 @@ impl AgentPanel {
);
}),
)
.item(
ContextMenuEntry::new("New Codex Thread")
.icon(IconName::AiOpenAi)
.icon_color(Color::Muted)
.handler(move |window, cx| {
window.dispatch_action(
NewExternalAgentThread {
agent: Some(crate::ExternalAgent::Codex),
}
.boxed_clone(),
cx,
);
}),
)
});
menu
}))
@@ -2070,7 +2088,7 @@ impl AgentPanel {
menu = menu
.action("Rules…", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenSettings))
.action("Settings", Box::new(OpenConfiguration))
.action(zoom_in_label, Box::new(ToggleZoom));
menu
}))
@@ -2275,10 +2293,10 @@ impl AgentPanel {
| ActiveView::Configuration => return false,
}
let plan = self.user_store.read(cx).plan();
let plan = self.user_store.read(cx).current_plan();
let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some();
matches!(plan, Some(Plan::ZedFree)) && has_previous_trial
matches!(plan, Some(Plan::Free)) && has_previous_trial
}
fn should_render_onboarding(&self, cx: &mut Context<Self>) -> bool {
@@ -2464,14 +2482,14 @@ impl AgentPanel {
.icon_color(Color::Muted)
.full_width()
.key_binding(KeyBinding::for_action_in(
&OpenSettings,
&OpenConfiguration,
&focus_handle,
window,
cx,
))
.on_click(|_event, window, cx| {
window.dispatch_action(
OpenSettings.boxed_clone(),
OpenConfiguration.boxed_clone(),
cx,
)
}),
@@ -2648,6 +2666,25 @@ impl AgentPanel {
)
},
),
)
.child(
NewThreadButton::new(
"new-codex-thread-btn",
"New Codex Thread",
IconName::AiOpenAi,
)
.on_click(
|window, cx| {
window.dispatch_action(
Box::new(NewExternalAgentThread {
agent: Some(
crate::ExternalAgent::Codex,
),
}),
cx,
)
},
),
),
)
}),
@@ -2676,11 +2713,16 @@ impl AgentPanel {
.style(ButtonStyle::Tinted(ui::TintColor::Warning))
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.))),
KeyBinding::for_action_in(
&OpenConfiguration,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(|_event, window, cx| {
window.dispatch_action(OpenSettings.boxed_clone(), cx)
window.dispatch_action(OpenConfiguration.boxed_clone(), cx)
}),
),
ConfigurationError::ProviderPendingTermsAcceptance(provider) => {
@@ -2874,7 +2916,7 @@ impl AgentPanel {
) -> AnyElement {
let error_message = match plan {
Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.",
Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.",
};
let icon = Icon::new(IconName::XCircle)
@@ -3184,7 +3226,7 @@ impl Render for AgentPanel {
.on_action(cx.listener(|this, _: &OpenHistory, window, cx| {
this.open_history(window, cx);
}))
.on_action(cx.listener(|this, _: &OpenSettings, window, cx| {
.on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| {
this.open_configuration(window, cx);
}))
.on_action(cx.listener(Self::open_active_thread_as_markdown))

View File

@@ -150,6 +150,7 @@ enum ExternalAgent {
#[default]
Gemini,
ClaudeCode,
Codex,
}
impl ExternalAgent {
@@ -157,6 +158,7 @@ impl ExternalAgent {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
}
}
}
@@ -263,8 +265,8 @@ fn update_command_palette_filter(cx: &mut App) {
filter.hide_namespace("agent");
filter.hide_namespace("assistant");
filter.hide_namespace("copilot");
filter.hide_namespace("supermaven");
filter.hide_namespace("zed_predict_onboarding");
filter.hide_namespace("edit_prediction");
use editor::actions::{

View File

@@ -6,7 +6,6 @@ use agent::{
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use futures::{
@@ -36,6 +35,7 @@ use std::{
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use zed_llm_client::CompletionIntent;
pub struct BufferCodegen {
alternatives: Vec<Entity<CodegenAlternative>>,

View File

@@ -1,10 +1,10 @@
#![allow(unused, dead_code)]
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{Plan, UsageLimit};
use gpui::Global;
use std::ops::{Deref, DerefMut};
use ui::prelude::*;
use zed_llm_client::{Plan, UsageLimit};
/// Debug only: Used for testing various account states
///

View File

@@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _};
use ui::prelude::*;
use util::{RangeExt, ResultExt, maybe};
use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId};
use zed_actions::agent::OpenSettings;
use zed_actions::agent::OpenConfiguration;
pub fn init(
fs: Arc<dyn Fs>,
@@ -345,7 +345,7 @@ impl InlineAssistant {
if let Some(answer) = answer {
if answer == 0 {
cx.update(|window, cx| {
window.dispatch_action(Box::new(OpenSettings), cx)
window.dispatch_action(Box::new(OpenConfiguration), cx)
})
.ok();
}

View File

@@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.icon_position(IconPosition::Start)
.on_click(|_, window, cx| {
window.dispatch_action(
zed_actions::agent::OpenSettings.boxed_clone(),
zed_actions::agent::OpenConfiguration.boxed_clone(),
cx,
);
}),

View File

@@ -17,7 +17,7 @@ use agent::{
use agent_settings::{AgentSettings, CompletionMode};
use ai_onboarding::ApiKeysWithProviders;
use buffer_diff::BufferDiff;
use cloud_llm_client::CompletionIntent;
use client::UserStore;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::display_map::CreaseId;
@@ -42,6 +42,7 @@ use language_model::{
use multi_buffer;
use project::Project;
use prompt_store::PromptStore;
use proto::Plan;
use settings::Settings;
use std::time::Duration;
use theme::ThemeSettings;
@@ -52,6 +53,7 @@ use util::ResultExt as _;
use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::Chat;
use zed_actions::agent::ToggleModelSelector;
use zed_llm_client::CompletionIntent;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -77,6 +79,7 @@ pub struct MessageEditor {
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
user_store: Entity<UserStore>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
history_store: Option<WeakEntity<HistoryStore>>,
@@ -156,6 +159,7 @@ impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
context_store: Entity<ContextStore>,
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
@@ -227,6 +231,7 @@ impl MessageEditor {
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
user_store,
thread,
incompatible_tools_state: incompatible_tools.clone(),
workspace,
@@ -1282,12 +1287,24 @@ impl MessageEditor {
return None;
}
let user_store = self.project.read(cx).user_store().read(cx);
if user_store.is_usage_based_billing_enabled() {
let user_store = self.user_store.read(cx);
let ubb_enable = user_store
.usage_based_billing_enabled()
.map_or(false, |enabled| enabled);
if ubb_enable {
return None;
}
let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree);
let plan = user_store
.current_plan()
.map(|plan| match plan {
Plan::Free => zed_llm_client::Plan::ZedFree,
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
})
.unwrap_or(zed_llm_client::Plan::ZedFree);
let usage = user_store.model_request_usage()?;
@@ -1752,6 +1769,7 @@ impl AgentPreview for MessageEditor {
) -> Option<AnyElement> {
if let Some(workspace) = workspace.upgrade() {
let fs = workspace.read(cx).app_state().fs.clone();
let user_store = workspace.read(cx).app_state().user_store.clone();
let project = workspace.read(cx).project().clone();
let weak_project = project.downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
@@ -1764,6 +1782,7 @@ impl AgentPreview for MessageEditor {
MessageEditor::new(
fs,
workspace.downgrade(),
user_store,
context_store,
None,
thread_store.downgrade(),

View File

@@ -10,7 +10,6 @@ use agent::{
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, VecDeque};
use editor::{MultiBuffer, actions::SelectAll};
use fs::Fs;
@@ -28,6 +27,7 @@ use terminal_view::TerminalView;
use ui::prelude::*;
use util::ResultExt;
use workspace::{Toast, Workspace, notifications::NotificationId};
use zed_llm_client::CompletionIntent;
pub fn init(
fs: Arc<dyn Fs>,

View File

@@ -1,8 +1,8 @@
use client::{ModelRequestUsage, RequestUsage, zed_urls};
use cloud_llm_client::{Plan, UsageLimit};
use component::{empty_example, example_group_with_title, single_example};
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use ui::{Callout, prelude::*};
use zed_llm_client::{Plan, UsageLimit};
#[derive(IntoElement, RegisterComponent)]
pub struct UsageCallout {

View File

@@ -16,10 +16,10 @@ default = []
[dependencies]
client.workspace = true
cloud_llm_client.workspace = true
component.workspace = true
gpui.workspace = true
language_model.workspace = true
proto.workspace = true
serde.workspace = true
smallvec.workspace = true
telemetry.workspace = true

View File

@@ -136,7 +136,10 @@ impl RenderOnce for ApiKeysWithoutProviders {
.full_width()
.style(ButtonStyle::Outlined)
.on_click(move |_, window, cx| {
window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx);
window.dispatch_action(
zed_actions::agent::OpenConfiguration.boxed_clone(),
cx,
);
}),
)
}

View File

@@ -1,7 +1,6 @@
use std::sync::Arc;
use client::{Client, UserStore};
use cloud_llm_client::Plan;
use gpui::{Entity, IntoElement, ParentElement};
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use ui::prelude::*;
@@ -57,8 +56,15 @@ impl AgentPanelOnboarding {
impl Render for AgentPanelOnboarding {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial);
let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro);
let enrolled_in_trial = matches!(
self.user_store.read(cx).current_plan(),
Some(proto::Plan::ZedProTrial)
);
let is_pro_user = matches!(
self.user_store.read(cx).current_plan(),
Some(proto::Plan::ZedPro)
);
AgentPanelOnboardingCard::new()
.child(

View File

@@ -9,7 +9,6 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider
pub use agent_panel_onboarding_card::AgentPanelOnboardingCard;
pub use agent_panel_onboarding_content::AgentPanelOnboarding;
pub use ai_upsell_card::AiUpsellCard;
use cloud_llm_client::Plan;
pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
pub use young_account_banner::YoungAccountBanner;
@@ -80,7 +79,7 @@ impl From<client::Status> for SignInStatus {
pub struct ZedAiOnboarding {
pub sign_in_status: SignInStatus,
pub has_accepted_terms_of_service: bool,
pub plan: Option<Plan>,
pub plan: Option<proto::Plan>,
pub account_too_young: bool,
pub continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
@@ -100,8 +99,8 @@ impl ZedAiOnboarding {
Self {
sign_in_status: status.into(),
has_accepted_terms_of_service: store.has_accepted_terms_of_service(),
plan: store.plan(),
has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false),
plan: store.current_plan(),
account_too_young: store.account_too_young(),
continue_with_zed_ai,
accept_terms_of_service: Arc::new({
@@ -114,9 +113,11 @@ impl ZedAiOnboarding {
sign_in: Arc::new(move |_window, cx| {
cx.spawn({
let client = client.clone();
async move |cx| client.sign_in_with_optional_connect(true, cx).await
async move |cx| {
client.authenticate_and_connect(true, cx).await;
}
})
.detach_and_log_err(cx);
.detach();
}),
dismiss_onboarding: None,
}
@@ -410,9 +411,9 @@ impl RenderOnce for ZedAiOnboarding {
if matches!(self.sign_in_status, SignInStatus::SignedIn) {
if self.has_accepted_terms_of_service {
match self.plan {
None | Some(Plan::ZedFree) => self.render_free_plan_state(cx),
Some(Plan::ZedProTrial) => self.render_trial_state(cx),
Some(Plan::ZedPro) => self.render_pro_plan_state(cx),
None | Some(proto::Plan::Free) => self.render_free_plan_state(cx),
Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx),
Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx),
}
} else {
self.render_accept_terms_of_service()
@@ -432,7 +433,7 @@ impl Component for ZedAiOnboarding {
fn onboarding(
sign_in_status: SignInStatus,
has_accepted_terms_of_service: bool,
plan: Option<Plan>,
plan: Option<proto::Plan>,
account_too_young: bool,
) -> AnyElement {
ZedAiOnboarding {
@@ -467,15 +468,25 @@ impl Component for ZedAiOnboarding {
),
single_example(
"Free Plan",
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false),
onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false),
),
single_example(
"Pro Trial",
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false),
onboarding(
SignInStatus::SignedIn,
true,
Some(proto::Plan::ZedProTrial),
false,
),
),
single_example(
"Pro Plan",
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false),
onboarding(
SignInStatus::SignedIn,
true,
Some(proto::Plan::ZedPro),
false,
),
),
])
.into_any_element(),

View File

@@ -1,7 +1,6 @@
use std::sync::Arc;
use client::{Client, zed_urls};
use cloud_llm_client::Plan;
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use ui::{Divider, List, Vector, VectorName, prelude::*};
@@ -11,22 +10,22 @@ use crate::{BulletItem, SignInStatus};
pub struct AiUpsellCard {
pub sign_in_status: SignInStatus,
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
pub user_plan: Option<Plan>,
}
impl AiUpsellCard {
pub fn new(client: Arc<Client>, user_plan: Option<Plan>) -> Self {
pub fn new(client: Arc<Client>) -> Self {
let status = *client.status().borrow();
Self {
user_plan,
sign_in_status: status.into(),
sign_in: Arc::new(move |_window, cx| {
cx.spawn({
let client = client.clone();
async move |cx| client.sign_in_with_optional_connect(true, cx).await
async move |cx| {
client.authenticate_and_connect(true, cx).await;
}
})
.detach_and_log_err(cx);
.detach();
}),
}
}
@@ -35,7 +34,6 @@ impl AiUpsellCard {
impl RenderOnce for AiUpsellCard {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let pro_section = v_flex()
.flex_grow()
.w_full()
.gap_1()
.child(
@@ -58,7 +56,6 @@ impl RenderOnce for AiUpsellCard {
);
let free_section = v_flex()
.flex_grow()
.w_full()
.gap_1()
.child(
@@ -74,7 +71,7 @@ impl RenderOnce for AiUpsellCard {
)
.child(
List::new()
.child(BulletItem::new("50 prompts with Claude models"))
.child(BulletItem::new("50 prompts with the Claude models"))
.child(BulletItem::new("2,000 accepted edit predictions")),
);
@@ -135,28 +132,22 @@ impl RenderOnce for AiUpsellCard {
v_flex()
.relative()
.p_4()
.pt_3()
.p_6()
.pt_4()
.border_1()
.border_color(cx.theme().colors().border)
.rounded_lg()
.overflow_hidden()
.child(grid_bg)
.child(gradient_bg)
.child(Label::new("Try Zed AI").size(LabelSize::Large))
.child(
div()
.max_w_3_4()
.mb_2()
.child(Label::new(DESCRIPTION).color(Color::Muted)),
)
.child(Headline::new("Try Zed AI"))
.child(Label::new(DESCRIPTION).color(Color::Muted).mb_2())
.child(
h_flex()
.w_full()
.mt_1p5()
.mb_2p5()
.items_start()
.gap_6()
.gap_12()
.child(free_section)
.child(pro_section),
)
@@ -192,7 +183,6 @@ impl Component for AiUpsellCard {
AiUpsellCard {
sign_in_status: SignInStatus::SignedOut,
sign_in: Arc::new(|_, _| {}),
user_plan: None,
}
.into_any_element(),
),
@@ -201,7 +191,6 @@ impl Component for AiUpsellCard {
AiUpsellCard {
sign_in_status: SignInStatus::SignedIn,
sign_in: Arc::new(|_, _| {}),
user_plan: None,
}
.into_any_element(),
),

View File

@@ -19,7 +19,6 @@ assistant_slash_commands.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
context_server.workspace = true
fs.workspace = true
@@ -49,6 +48,7 @@ util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
indoc.workspace = true

View File

@@ -11,7 +11,6 @@ use assistant_slash_command::{
use assistant_slash_commands::FileCommandMetadata;
use client::{self, Client, proto, telemetry::Telemetry};
use clock::ReplicaId;
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet};
use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared};
@@ -47,6 +46,7 @@ use text::{BufferSnapshot, ToPoint};
use ui::IconName;
use util::{ResultExt, TryFutureExt, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionIntent;
pub use crate::context_store::*;

View File

@@ -21,11 +21,9 @@ assistant_tool.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
derive_more.workspace = true
diffy = "0.4.2"
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
@@ -65,6 +63,8 @@ web_search.workspace = true
which.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_llm_client.workspace = true
diffy = "0.4.2"
[dev-dependencies]
lsp = { workspace = true, features = ["test-support"] }

View File

@@ -7,7 +7,6 @@ mod streaming_fuzzy_matcher;
use crate::{Template, Templates};
use anyhow::Result;
use assistant_tool::ActionLog;
use cloud_llm_client::CompletionIntent;
use create_file_parser::{CreateFileParser, CreateFileParserEvent};
pub use edit_parser::EditFormat;
use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
@@ -30,6 +29,7 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::
use streaming_diff::{CharOperation, StreamingDiff};
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
use util::debug_panic;
use zed_llm_client::CompletionIntent;
#[derive(Serialize)]
struct CreateFilePromptTemplate {

View File

@@ -6,7 +6,6 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
};
use cloud_llm_client::{WebSearchResponse, WebSearchResult};
use futures::{Future, FutureExt, TryFutureExt};
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
@@ -18,6 +17,7 @@ use serde::{Deserialize, Serialize};
use ui::{IconName, Tooltip, prelude::*};
use web_search::WebSearchRegistry;
use workspace::Workspace;
use zed_llm_client::{WebSearchResponse, WebSearchResult};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {

View File

@@ -18,6 +18,6 @@ collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
parking_lot.workspace = true
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
rodio = { version = "0.20.0", default-features = false, features = ["wav"] }
util.workspace = true
workspace-hack.workspace = true

View File

@@ -3,9 +3,12 @@ use std::{io::Cursor, sync::Arc};
use anyhow::{Context as _, Result};
use collections::HashMap;
use gpui::{App, AssetSource, Global};
use rodio::{Decoder, Source, source::Buffered};
use rodio::{
Decoder, Source,
source::{Buffered, SamplesConverter},
};
type Sound = Buffered<Decoder<Cursor<Vec<u8>>>>;
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
pub struct SoundRegistry {
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
@@ -45,7 +48,7 @@ impl SoundRegistry {
.with_context(|| format!("No asset available for path {path}"))??
.into_owned();
let cursor = Cursor::new(bytes);
let source = Decoder::new(cursor)?.buffered();
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
self.cache.lock().insert(name.to_string(), source.clone());

View File

@@ -1,7 +1,7 @@
use assets::SoundRegistry;
use derive_more::{Deref, DerefMut};
use gpui::{App, AssetSource, BorrowAppContext, Global};
use rodio::{OutputStream, OutputStreamBuilder};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;
mod assets;
@@ -37,7 +37,8 @@ impl Sound {
#[derive(Default)]
pub struct Audio {
output_handle: Option<OutputStream>,
_output_stream: Option<OutputStream>,
output_handle: Option<OutputStreamHandle>,
}
#[derive(Deref, DerefMut)]
@@ -50,9 +51,11 @@ impl Audio {
Self::default()
}
fn ensure_output_exists(&mut self) -> Option<&OutputStream> {
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
if self.output_handle.is_none() {
self.output_handle = OutputStreamBuilder::open_default_stream().log_err();
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
self.output_handle = output_handle;
self._output_stream = _output_stream;
}
self.output_handle.as_ref()
@@ -66,7 +69,7 @@ impl Audio {
cx.update_global::<GlobalAudio, _>(|this, cx| {
let output_handle = this.ensure_output_exists()?;
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
output_handle.mixer().add(source);
output_handle.play_raw(source).log_err()?;
Some(())
});
}
@@ -77,6 +80,7 @@ impl Audio {
}
cx.update_global::<GlobalAudio, _>(|this, _| {
this._output_stream.take();
this.output_handle.take();
});
}

View File

@@ -126,7 +126,7 @@ impl ChannelMembership {
proto::channel_member::Kind::Member => 0,
proto::channel_member::Kind::Invitee => 1,
},
username_order: &self.user.github_login,
username_order: self.user.github_login.as_str(),
}
}
}

View File

@@ -259,6 +259,20 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx);
});
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![5]);
server.respond(
get_users.receipt(),
proto::UsersResponse {
users: vec![proto::User {
id: 5,
github_login: "nathansobo".into(),
avatar_url: "http://avatar.com/nathansobo".into(),
name: None,
}],
},
);
// Join a channel and populate its existing messages.
let channel = channel_store.update(cx, |store, cx| {
let channel_id = store.ordered_channels().next().unwrap().1.id;
@@ -320,7 +334,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[
("user-5".into(), "a".into()),
("nathansobo".into(), "a".into()),
("maxbrunsfeld".into(), "b".into())
]
);
@@ -423,7 +437,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
.collect::<Vec<_>>(),
&[
("user-5".into(), "y".into()),
("nathansobo".into(), "y".into()),
("maxbrunsfeld".into(), "z".into())
]
);

View File

@@ -17,12 +17,11 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
[dependencies]
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
base64.workspace = true
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
cloud_api_client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
derive_more.workspace = true
@@ -34,8 +33,8 @@ http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
log.workspace = true
parking_lot.workspace = true
paths.workspace = true
parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
regex.workspace = true
@@ -47,18 +46,19 @@ serde_json.workspace = true
settings.workspace = true
sha2.workspace = true
smol.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
tokio.workspace = true
url.workspace = true
util.workspace = true
workspace-hack.workspace = true
worktree.workspace = true
telemetry.workspace = true
tokio.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View File

@@ -6,21 +6,22 @@ pub mod telemetry;
pub mod user;
pub mod zed_urls;
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Context as _, Result, anyhow, bail};
use async_recursion::async_recursion;
use async_tungstenite::tungstenite::{
client::IntoClientRequest,
error::Error as WebsocketError,
http::{HeaderValue, Request, StatusCode},
};
use chrono::{DateTime, Utc};
use clock::SystemClock;
use cloud_api_client::CloudApiClient;
use credentials_provider::CredentialsProvider;
use futures::{
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
channel::oneshot, future::BoxFuture,
};
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{HttpClient, HttpClientWithUrl, http};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use parking_lot::RwLock;
use postage::watch;
use proxy::connect_proxy_stream;
@@ -30,6 +31,7 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use std::pin::Pin;
use std::{
any::TypeId,
convert::TryFrom,
@@ -43,7 +45,6 @@ use std::{
},
time::{Duration, Instant},
};
use std::{cmp, pin::Pin};
use telemetry::Telemetry;
use thiserror::Error;
use tokio::net::TcpStream;
@@ -77,7 +78,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
actions!(
@@ -161,8 +162,20 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
let client = client.clone();
move |_: &SignIn, cx| {
if let Some(client) = client.upgrade() {
cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await)
.detach_and_log_err(cx);
cx.spawn(
async move |cx| match client.authenticate_and_connect(true, &cx).await {
ConnectionResult::Timeout => {
log::error!("Initial authentication timed out");
}
ConnectionResult::ConnectionReset => {
log::error!("Initial authentication connection reset");
}
ConnectionResult::Result(r) => {
r.log_err();
}
},
)
.detach();
}
}
});
@@ -200,7 +213,6 @@ pub struct Client {
id: AtomicU64,
peer: Arc<Peer>,
http: Arc<HttpClientWithUrl>,
cloud_client: Arc<CloudApiClient>,
telemetry: Arc<Telemetry>,
credentials_provider: ClientCredentialsProvider,
state: RwLock<ClientState>,
@@ -271,8 +283,6 @@ pub enum Status {
SignedOut,
UpgradeRequired,
Authenticating,
Authenticated,
AuthenticationError,
Connecting,
ConnectionError,
Connected {
@@ -576,7 +586,6 @@ impl Client {
id: AtomicU64::new(0),
peer: Peer::new(0),
telemetry: Telemetry::new(clock, http.clone(), cx),
cloud_client: Arc::new(CloudApiClient::new(http.clone())),
http,
credentials_provider: ClientCredentialsProvider::new(cx),
state: Default::default(),
@@ -609,10 +618,6 @@ impl Client {
self.http.clone()
}
pub fn cloud_client(&self) -> Arc<CloudApiClient> {
self.cloud_client.clone()
}
pub fn set_id(&self, id: u64) -> &Self {
self.id.store(id, Ordering::SeqCst);
self
@@ -699,7 +704,7 @@ impl Client {
let mut delay = INITIAL_RECONNECTION_DELAY;
loop {
match client.connect(true, &cx).await {
match client.authenticate_and_connect(true, &cx).await {
ConnectionResult::Timeout => {
log::error!("client connect attempt timed out")
}
@@ -722,10 +727,11 @@ impl Client {
},
&cx,
);
let jitter =
Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64));
cx.background_executor().timer(delay + jitter).await;
delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
cx.background_executor().timer(delay).await;
delay = delay
.mul_f32(rng.gen_range(0.5..=2.5))
.max(INITIAL_RECONNECTION_DELAY)
.min(MAX_RECONNECTION_DELAY);
} else {
break;
}
@@ -869,122 +875,17 @@ impl Client {
.is_some()
}
pub async fn sign_in(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> Result<Credentials> {
if self.status().borrow().is_signed_out() {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx);
}
let mut credentials = None;
let old_credentials = self.state.read().credentials.clone();
if let Some(old_credentials) = old_credentials {
self.cloud_client.set_credentials(
old_credentials.user_id as u32,
old_credentials.access_token.clone(),
);
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
if self.cloud_client.get_authenticated_user().await.is_ok() {
credentials = Some(old_credentials);
}
}
if credentials.is_none() && try_provider {
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
self.cloud_client.set_credentials(
stored_credentials.user_id as u32,
stored_credentials.access_token.clone(),
);
// Fetch the authenticated user with the stored credentials, and
// clear them from the credentials provider if that fails.
if self.cloud_client.get_authenticated_user().await.is_ok() {
credentials = Some(stored_credentials);
} else {
self.credentials_provider
.delete_credentials(cx)
.await
.log_err();
}
}
}
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
futures::select_biased! {
authenticate = self.authenticate(cx).fuse() => {
match authenticate {
Ok(creds) => {
if IMPERSONATE_LOGIN.is_none() {
self.credentials_provider
.write_credentials(creds.user_id, creds.access_token.clone(), cx)
.await
.log_err();
}
credentials = Some(creds);
},
Err(err) => {
self.set_status(Status::AuthenticationError, cx);
return Err(err);
}
}
}
_ = status_rx.next().fuse() => {
return Err(anyhow!("authentication canceled"));
}
}
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
self.cloud_client
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
self.state.write().credentials = Some(credentials.clone());
self.set_status(Status::Authenticated, cx);
Ok(credentials)
}
/// Performs a sign-in and also connects to Collab.
///
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
/// to `sign_in` when we're ready to remove auto-connection to Collab.
pub async fn sign_in_with_optional_connect(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> Result<()> {
let credentials = self.sign_in(try_provider, cx).await?;
let connect_result = match self.connect_with_credentials(credentials, cx).await {
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
ConnectionResult::Result(result) => result.context("client auth and connect"),
};
connect_result.log_err();
Ok(())
}
pub async fn connect(
#[async_recursion(?Send)]
pub async fn authenticate_and_connect(
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> ConnectionResult<()> {
let was_disconnected = match *self.status().borrow() {
Status::SignedOut | Status::Authenticated => true,
Status::SignedOut => true,
Status::ConnectionError
| Status::ConnectionLost
| Status::Authenticating { .. }
| Status::AuthenticationError
| Status::Reauthenticating { .. }
| Status::ReconnectionError { .. } => false,
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
@@ -997,10 +898,39 @@ impl Client {
);
}
};
let credentials = match self.sign_in(try_provider, cx).await {
Ok(credentials) => credentials,
Err(err) => return ConnectionResult::Result(Err(err)),
};
if was_disconnected {
self.set_status(Status::Authenticating, cx);
} else {
self.set_status(Status::Reauthenticating, cx)
}
let mut read_from_provider = false;
let mut credentials = self.state.read().credentials.clone();
if credentials.is_none() && try_provider {
credentials = self.credentials_provider.read_credentials(cx).await;
read_from_provider = credentials.is_some();
}
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
futures::select_biased! {
authenticate = self.authenticate(cx).fuse() => {
match authenticate {
Ok(creds) => credentials = Some(creds),
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return ConnectionResult::Result(Err(err));
}
}
}
_ = status_rx.next().fuse() => {
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
}
}
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
if was_disconnected {
self.set_status(Status::Connecting, cx);
@@ -1008,20 +938,17 @@ impl Client {
self.set_status(Status::Reconnecting, cx);
}
self.connect_with_credentials(credentials, cx).await
}
async fn connect_with_credentials(
self: &Arc<Self>,
credentials: Credentials,
cx: &AsyncApp,
) -> ConnectionResult<()> {
let mut timeout =
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
futures::select_biased! {
connection = self.establish_connection(&credentials, cx).fuse() => {
match connection {
Ok(conn) => {
self.state.write().credentials = Some(credentials.clone());
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
}
futures::select_biased! {
result = self.set_connection(conn, cx).fuse() => {
match result.context("client auth and connect") {
@@ -1039,8 +966,15 @@ impl Client {
}
}
Err(EstablishConnectionError::Unauthorized) => {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
self.state.write().credentials.take();
if read_from_provider {
self.credentials_provider.delete_credentials(cx).await.log_err();
self.set_status(Status::SignedOut, cx);
self.authenticate_and_connect(false, cx).await
} else {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
}
}
Err(EstablishConnectionError::UpgradeRequired) => {
self.set_status(Status::UpgradeRequired, cx);
@@ -1204,7 +1138,7 @@ impl Client {
.to_str()
.map_err(EstablishConnectionError::other)?
.to_string();
Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
Url::parse(&collab_url).with_context(|| format!("parsing colab rpc url {collab_url}"))
}
}
@@ -1224,7 +1158,6 @@ impl Client {
let http = self.http.clone();
let proxy = http.proxy().cloned();
let user_agent = http.user_agent().cloned();
let credentials = credentials.clone();
let rpc_url = self.rpc_url(http, release_channel);
let system_id = self.telemetry.system_id();
@@ -1276,7 +1209,7 @@ impl Client {
// We then modify the request to add our desired headers.
let request_headers = request.headers_mut();
request_headers.insert(
http::header::AUTHORIZATION,
"Authorization",
HeaderValue::from_str(&credentials.authorization_header())?,
);
request_headers.insert(
@@ -1288,9 +1221,6 @@ impl Client {
"x-zed-release-channel",
HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
);
if let Some(user_agent) = user_agent {
request_headers.insert(http::header::USER_AGENT, user_agent);
}
if let Some(system_id) = system_id {
request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
}
@@ -1435,31 +1365,96 @@ impl Client {
self: &Arc<Self>,
http: Arc<HttpClientWithUrl>,
login: String,
api_token: String,
mut api_token: String,
) -> Result<Credentials> {
#[derive(Serialize)]
struct ImpersonateUserBody {
github_login: String,
#[derive(Deserialize)]
struct AuthenticatedUserResponse {
user: User,
}
#[derive(Deserialize)]
struct ImpersonateUserResponse {
user_id: u64,
access_token: String,
struct User {
id: u64,
}
let url = self
.http
.build_zed_cloud_url("/internal/users/impersonate", &[])?;
let request = Request::post(url.as_str())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {api_token}"))
.body(
serde_json::to_string(&ImpersonateUserBody {
github_login: login,
})?
.into(),
)?;
let github_user = {
#[derive(Deserialize)]
struct GithubUser {
id: i32,
login: String,
created_at: DateTime<Utc>,
}
let request = {
let mut request_builder =
Request::get(&format!("https://api.github.com/users/{login}"));
if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
request_builder =
request_builder.header("Authorization", format!("Bearer {}", github_token));
}
request_builder.body(AsyncBody::empty())?
};
let mut response = http
.send(request)
.await
.context("error fetching GitHub user")?;
let mut body = Vec::new();
response
.body_mut()
.read_to_end(&mut body)
.await
.context("error reading GitHub user")?;
if !response.status().is_success() {
let text = String::from_utf8_lossy(body.as_slice());
bail!(
"status error {}, response: {text:?}",
response.status().as_u16()
);
}
serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
log::error!("Error deserializing: {:?}", err);
log::error!(
"GitHub API response text: {:?}",
String::from_utf8_lossy(body.as_slice())
);
anyhow!("error deserializing GitHub user")
})?
};
let query_params = [
("github_login", &github_user.login),
("github_user_id", &github_user.id.to_string()),
(
"github_user_created_at",
&github_user.created_at.to_rfc3339(),
),
];
// Use the collab server's admin API to retrieve the ID
// of the impersonated user.
let mut url = self.rpc_url(http.clone(), None).await?;
url.set_path("/user");
url.set_query(Some(
&query_params
.iter()
.map(|(key, value)| {
format!(
"{}={}",
key,
url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
)
})
.collect::<Vec<String>>()
.join("&"),
));
let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
.header("Authorization", format!("token {api_token}"))
.body("".into())?;
let mut response = http.send(request).await?;
let mut body = String::new();
@@ -1470,17 +1465,18 @@ impl Client {
response.status().as_u16(),
body,
);
let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
// Use the admin API token to authenticate as the impersonated user.
api_token.insert_str(0, "ADMIN_TOKEN:");
Ok(Credentials {
user_id: response.user_id,
access_token: response.access_token,
user_id: response.user.id,
access_token: api_token,
})
}
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
self.state.write().credentials = None;
self.cloud_client.clear_credentials();
self.disconnect(cx);
if self.has_credentials(cx).await {
@@ -1790,7 +1786,7 @@ mod tests {
});
let auth_and_connect = cx.spawn({
let client = client.clone();
|cx| async move { client.connect(false, &cx).await }
|cx| async move { client.authenticate_and_connect(false, &cx).await }
});
executor.run_until_parked();
assert!(matches!(status.next().await, Some(Status::Connecting)));
@@ -1867,7 +1863,7 @@ mod tests {
let _authenticate = cx.spawn({
let client = client.clone();
move |cx| async move { client.connect(false, &cx).await }
move |cx| async move { client.authenticate_and_connect(false, &cx).await }
});
executor.run_until_parked();
assert_eq!(*auth_count.lock(), 1);
@@ -1875,7 +1871,7 @@ mod tests {
let _authenticate = cx.spawn({
let client = client.clone();
|cx| async move { client.connect(false, &cx).await }
|cx| async move { client.authenticate_and_connect(false, &cx).await }
});
executor.run_until_parked();
assert_eq!(*auth_count.lock(), 2);

View File

@@ -1,11 +1,8 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{Context as _, Result, anyhow};
use chrono::Duration;
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
use futures::{StreamExt, stream::BoxStream};
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
use http_client::{AsyncBody, Method, Request, http};
use parking_lot::Mutex;
use rpc::{
ConnectionId, Peer, Receipt, TypedEnvelope,
@@ -42,44 +39,6 @@ impl FakeServer {
executor: cx.executor(),
};
client.http_client().as_fake().replace_handler({
let state = server.state.clone();
move |old_handler, req| {
let state = state.clone();
let old_handler = old_handler.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => {
let credentials = parse_authorization_header(&req);
if credentials
!= Some(Credentials {
user_id: client_user_id,
access_token: state.lock().access_token.to_string(),
})
{
return Ok(http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap());
}
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response(
client_user_id as i32,
format!("user-{client_user_id}"),
))
.unwrap()
.into(),
)
.unwrap())
}
_ => old_handler(req).await,
}
}
}
});
client
.override_authenticate({
let state = Arc::downgrade(&server.state);
@@ -146,7 +105,7 @@ impl FakeServer {
});
client
.connect(false, &cx.to_async())
.authenticate_and_connect(false, &cx.to_async())
.await
.into_response()
.unwrap();
@@ -264,54 +223,3 @@ impl Drop for FakeServer {
self.disconnect();
}
}
pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
let mut auth_header = req
.headers()
.get(http::header::AUTHORIZATION)?
.to_str()
.ok()?
.split_whitespace();
let user_id = auth_header.next()?.parse().ok()?;
let access_token = auth_header.next()?;
Some(Credentials {
user_id,
access_token: access_token.to_string(),
})
}
pub fn make_get_authenticated_user_response(
user_id: i32,
github_login: String,
) -> GetAuthenticatedUserResponse {
GetAuthenticatedUserResponse {
user: AuthenticatedUser {
id: user_id,
metrics_id: format!("metrics-id-{user_id}"),
avatar_url: "".to_string(),
github_login,
name: None,
is_staff: false,
accepted_tos_at: None,
},
feature_flags: vec![],
plan: PlanInfo {
plan: Plan::ZedPro,
subscription_period: None,
usage: CurrentUsage {
model_requests: UsageData {
used: 0,
limit: UsageLimit::Limited(500),
},
edit_predictions: UsageData {
used: 250,
limit: UsageLimit::Unlimited,
},
},
trial_started_at: None,
is_usage_based_billing_enabled: false,
is_account_too_young: false,
has_overdue_invoices: false,
},
}
}

View File

@@ -1,11 +1,6 @@
use super::{Client, Status, TypedEnvelope, proto};
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
use cloud_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
use collections::{HashMap, HashSet, hash_map::Entry};
use derive_more::Deref;
use feature_flags::FeatureFlagAppExt;
@@ -21,7 +16,11 @@ use std::{
sync::{Arc, Weak},
};
use text::ReplicaId;
use util::{ResultExt, TryFutureExt as _};
use util::{TryFutureExt as _, maybe};
use zed_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
pub type UserId = u64;
@@ -56,7 +55,7 @@ pub struct ParticipantIndex(pub u32);
#[derive(Default, Debug)]
pub struct User {
pub id: UserId,
pub github_login: SharedString,
pub github_login: String,
pub avatar_uri: SharedUri,
pub name: Option<String>,
}
@@ -108,14 +107,19 @@ pub enum ContactRequestStatus {
pub struct UserStore {
users: HashMap<u64, Arc<User>>,
by_github_login: HashMap<SharedString, u64>,
by_github_login: HashMap<String, u64>,
participant_indices: HashMap<u64, ParticipantIndex>,
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_plan: Option<proto::Plan>,
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
trial_started_at: Option<DateTime<Utc>>,
model_request_usage: Option<ModelRequestUsage>,
edit_prediction_usage: Option<EditPredictionUsage>,
plan_info: Option<PlanInfo>,
is_usage_based_billing_enabled: Option<bool>,
account_too_young: Option<bool>,
has_overdue_invoices: Option<bool>,
current_user: watch::Receiver<Option<Arc<User>>>,
accepted_tos_at: Option<Option<cloud_api_client::Timestamp>>,
accepted_tos_at: Option<Option<DateTime<Utc>>>,
contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>,
@@ -141,7 +145,6 @@ pub enum Event {
ShowContacts,
ParticipantIndicesChanged,
PrivateUserInfoUpdated,
PlanUpdated,
}
#[derive(Clone, Copy)]
@@ -185,9 +188,14 @@ impl UserStore {
users: Default::default(),
by_github_login: Default::default(),
current_user: current_user_rx,
plan_info: None,
current_plan: None,
subscription_period: None,
trial_started_at: None,
model_request_usage: None,
edit_prediction_usage: None,
is_usage_based_billing_enabled: None,
account_too_young: None,
has_overdue_invoices: None,
accepted_tos_at: None,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
@@ -217,30 +225,53 @@ impl UserStore {
return Ok(());
};
match status {
Status::Authenticated | Status::Connected { .. } => {
Status::Connected { .. } => {
if let Some(user_id) = client.user_id() {
let response = client.cloud_client().get_authenticated_user().await;
let mut current_user = None;
let fetch_user = if let Ok(fetch_user) =
this.update(cx, |this, cx| this.get_user(user_id, cx).log_err())
{
fetch_user
} else {
break;
};
let fetch_private_user_info =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) =
futures::join!(fetch_user, fetch_private_user_info);
cx.update(|cx| {
if let Some(response) = response.log_err() {
let user = Arc::new(User {
id: user_id,
github_login: response.user.github_login.clone().into(),
avatar_uri: response.user.avatar_url.clone().into(),
name: response.user.name.clone(),
});
current_user = Some(user.clone());
if let Some(info) = info {
let staff =
info.staff && !*feature_flags::ZED_DISABLE_STAFF;
cx.update_flags(staff, info.flags);
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
staff,
);
this.update(cx, |this, cx| {
this.by_github_login
.insert(user.github_login.clone(), user_id);
this.users.insert(user_id, user);
this.update_authenticated_user(response, cx)
let accepted_tos_at = {
#[cfg(debug_assertions)]
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok()
{
None
} else {
info.accepted_tos_at
}
#[cfg(not(debug_assertions))]
info.accepted_tos_at
};
this.set_current_user_accepted_tos_at(accepted_tos_at);
cx.emit(Event::PrivateUserInfoUpdated);
})
} else {
anyhow::Ok(())
}
})??;
current_user_tx.send(current_user).await.ok();
current_user_tx.send(user).await.ok();
this.update(cx, |_, cx| cx.notify())?;
}
@@ -321,22 +352,59 @@ impl UserStore {
async fn handle_update_plan(
this: Entity<Self>,
_message: TypedEnvelope<proto::UpdateUserPlan>,
message: TypedEnvelope<proto::UpdateUserPlan>,
mut cx: AsyncApp,
) -> Result<()> {
let client = this
.read_with(&cx, |this, _| this.client.upgrade())?
.context("client was dropped")?;
let response = client
.cloud_client()
.get_authenticated_user()
.await
.context("failed to fetch authenticated user")?;
this.update(&mut cx, |this, cx| {
this.update_authenticated_user(response, cx);
})
this.current_plan = Some(message.payload.plan());
this.subscription_period = maybe!({
let period = message.payload.subscription_period?;
let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?;
let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?;
Some((started_at, ended_at))
});
this.trial_started_at = message
.payload
.trial_started_at
.and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
this.account_too_young = message.payload.account_too_young;
this.has_overdue_invoices = message.payload.has_overdue_invoices;
if let Some(usage) = message.payload.usage {
// limits are always present even though they are wrapped in Option
this.model_request_usage = usage
.model_requests_usage_limit
.and_then(|limit| {
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
})
.map(ModelRequestUsage);
this.edit_prediction_usage = usage
.edit_predictions_usage_limit
.and_then(|limit| {
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
})
.map(EditPredictionUsage);
}
cx.notify();
})?;
Ok(())
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn update_edit_prediction_usage(
&mut self,
usage: EditPredictionUsage,
cx: &mut Context<Self>,
) {
self.edit_prediction_usage = Some(usage);
cx.notify();
}
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
@@ -695,131 +763,59 @@ impl UserStore {
self.current_user.borrow().clone()
}
pub fn plan(&self) -> Option<cloud_llm_client::Plan> {
pub fn current_plan(&self) -> Option<proto::Plan> {
#[cfg(debug_assertions)]
if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
return match plan.as_str() {
"free" => Some(cloud_llm_client::Plan::ZedFree),
"trial" => Some(cloud_llm_client::Plan::ZedProTrial),
"pro" => Some(cloud_llm_client::Plan::ZedPro),
"free" => Some(proto::Plan::Free),
"trial" => Some(proto::Plan::ZedProTrial),
"pro" => Some(proto::Plan::ZedPro),
_ => {
panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
}
};
}
self.plan_info.as_ref().map(|info| info.plan)
self.current_plan
}
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
self.plan_info
.as_ref()
.and_then(|plan| plan.subscription_period)
.map(|subscription_period| {
(
subscription_period.started_at.0,
subscription_period.ended_at.0,
)
})
self.subscription_period
}
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
self.plan_info
.as_ref()
.and_then(|plan| plan.trial_started_at)
.map(|trial_started_at| trial_started_at.0)
self.trial_started_at
}
/// Returns whether the user's account is too new to use the service.
pub fn account_too_young(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_account_too_young)
.unwrap_or_default()
}
/// Returns whether the current user has overdue invoices and usage should be blocked.
pub fn has_overdue_invoices(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.has_overdue_invoices)
.unwrap_or_default()
}
pub fn is_usage_based_billing_enabled(&self) -> bool {
self.plan_info
.as_ref()
.map(|plan| plan.is_usage_based_billing_enabled)
.unwrap_or_default()
pub fn usage_based_billing_enabled(&self) -> Option<bool> {
self.is_usage_based_billing_enabled
}
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
self.model_request_usage
}
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
self.model_request_usage = Some(usage);
cx.notify();
}
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
self.edit_prediction_usage
}
pub fn update_edit_prediction_usage(
&mut self,
usage: EditPredictionUsage,
cx: &mut Context<Self>,
) {
self.edit_prediction_usage = Some(usage);
cx.notify();
}
fn update_authenticated_user(
&mut self,
response: GetAuthenticatedUserResponse,
cx: &mut Context<Self>,
) {
let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
cx.update_flags(staff, response.feature_flags);
if let Some(client) = self.client.upgrade() {
client
.telemetry
.set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
}
let accepted_tos_at = {
#[cfg(debug_assertions)]
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() {
None
} else {
response.user.accepted_tos_at
}
#[cfg(not(debug_assertions))]
response.user.accepted_tos_at
};
self.accepted_tos_at = Some(accepted_tos_at);
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
limit: response.plan.usage.model_requests.limit,
amount: response.plan.usage.model_requests.used as i32,
}));
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
limit: response.plan.usage.edit_predictions.limit,
amount: response.plan.usage.edit_predictions.used as i32,
}));
self.plan_info = Some(response.plan);
cx.emit(Event::PrivateUserInfoUpdated);
}
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
self.current_user.clone()
}
pub fn has_accepted_terms_of_service(&self) -> bool {
/// Returns whether the user's account is too new to use the service.
pub fn account_too_young(&self) -> bool {
self.account_too_young.unwrap_or(false)
}
/// Returns whether the current user has overdue invoices and usage should be blocked.
pub fn has_overdue_invoices(&self) -> bool {
self.has_overdue_invoices.unwrap_or(false)
}
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
self.accepted_tos_at
.map_or(false, |accepted_tos_at| accepted_tos_at.is_some())
.map(|accepted_tos_at| accepted_tos_at.is_some())
}
pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> {
@@ -831,18 +827,23 @@ impl UserStore {
cx.spawn(async move |this, cx| -> anyhow::Result<()> {
let client = client.upgrade().context("client not found")?;
let response = client
.cloud_client()
.accept_terms_of_service()
.request(proto::AcceptTermsOfService {})
.await
.context("error accepting tos")?;
this.update(cx, |this, cx| {
this.accepted_tos_at = Some(response.user.accepted_tos_at);
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at));
cx.emit(Event::PrivateUserInfoUpdated);
})?;
Ok(())
})
}
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
self.accepted_tos_at = Some(
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
);
}
fn load_users(
&self,
request: impl RequestMessage<Response = UsersResponse>,
@@ -901,7 +902,7 @@ impl UserStore {
let mut missing_user_ids = Vec::new();
for id in user_ids {
if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) {
ret.insert(id, github_login);
ret.insert(id, github_login.into());
} else {
missing_user_ids.push(id)
}
@@ -922,7 +923,7 @@ impl User {
fn new(message: proto::User) -> Arc<Self> {
Arc::new(User {
id: message.id,
github_login: message.github_login.into(),
github_login: message.github_login,
avatar_uri: message.avatar_url.into(),
name: message.name,
})

View File

@@ -1,21 +0,0 @@
[package]
name = "cloud_api_client"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "Apache-2.0"
[lints]
workspace = true
[lib]
path = "src/cloud_api_client.rs"
[dependencies]
anyhow.workspace = true
cloud_api_types.workspace = true
futures.workspace = true
http_client.workspace = true
parking_lot.workspace = true
serde_json.workspace = true
workspace-hack.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-APACHE

View File

@@ -1,155 +0,0 @@
use std::sync::Arc;
use anyhow::{Result, anyhow};
pub use cloud_api_types::*;
use futures::AsyncReadExt as _;
use http_client::http::request;
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request};
use parking_lot::RwLock;
struct Credentials {
user_id: u32,
access_token: String,
}
pub struct CloudApiClient {
credentials: RwLock<Option<Credentials>>,
http_client: Arc<HttpClientWithUrl>,
}
impl CloudApiClient {
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
Self {
credentials: RwLock::new(None),
http_client,
}
}
pub fn has_credentials(&self) -> bool {
self.credentials.read().is_some()
}
pub fn set_credentials(&self, user_id: u32, access_token: String) {
*self.credentials.write() = Some(Credentials {
user_id,
access_token,
});
}
pub fn clear_credentials(&self) {
*self.credentials.write() = None;
}
fn authorization_header(&self) -> Result<String> {
let guard = self.credentials.read();
let credentials = guard
.as_ref()
.ok_or_else(|| anyhow!("No credentials provided"))?;
Ok(format!(
"{} {}",
credentials.user_id, credentials.access_token
))
}
fn build_request(
&self,
req: request::Builder,
body: impl Into<AsyncBody>,
) -> Result<Request<AsyncBody>> {
Ok(req
.header("Content-Type", "application/json")
.header("Authorization", self.authorization_header()?)
.body(body.into())?)
}
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
let request = self.build_request(
Request::builder().method(Method::GET).uri(
self.http_client
.build_zed_cloud_url("/client/users/me", &[])?
.as_ref(),
),
AsyncBody::default(),
)?;
let mut response = self.http_client.send(request).await?;
if !response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
response.status()
)
}
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
Ok(serde_json::from_str(&body)?)
}
pub async fn accept_terms_of_service(&self) -> Result<AcceptTermsOfServiceResponse> {
let request = self.build_request(
Request::builder().method(Method::POST).uri(
self.http_client
.build_zed_cloud_url("/client/terms_of_service/accept", &[])?
.as_ref(),
),
AsyncBody::default(),
)?;
let mut response = self.http_client.send(request).await?;
if !response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"Failed to accept terms of service.\nStatus: {:?}\nBody: {body}",
response.status()
)
}
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
Ok(serde_json::from_str(&body)?)
}
pub async fn create_llm_token(
&self,
system_id: Option<String>,
) -> Result<CreateLlmTokenResponse> {
let mut request_builder = Request::builder().method(Method::POST).uri(
self.http_client
.build_zed_cloud_url("/client/llm_tokens", &[])?
.as_ref(),
);
if let Some(system_id) = system_id {
request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id);
}
let request = self.build_request(request_builder, AsyncBody::default())?;
let mut response = self.http_client.send(request).await?;
if !response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"Failed to create LLM token.\nStatus: {:?}\nBody: {body}",
response.status()
)
}
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
Ok(serde_json::from_str(&body)?)
}
}

View File

@@ -1,22 +0,0 @@
[package]
name = "cloud_api_types"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "Apache-2.0"
[lints]
workspace = true
[lib]
path = "src/cloud_api_types.rs"
[dependencies]
chrono.workspace = true
cloud_llm_client.workspace = true
serde.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true
serde_json.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-APACHE

View File

@@ -1,55 +0,0 @@
mod timestamp;
use serde::{Deserialize, Serialize};
pub use crate::timestamp::Timestamp;
pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id";
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct GetAuthenticatedUserResponse {
pub user: AuthenticatedUser,
pub feature_flags: Vec<String>,
pub plan: PlanInfo,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct AuthenticatedUser {
pub id: i32,
pub metrics_id: String,
pub avatar_url: String,
pub github_login: String,
pub name: Option<String>,
pub is_staff: bool,
pub accepted_tos_at: Option<Timestamp>,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct PlanInfo {
pub plan: cloud_llm_client::Plan,
pub subscription_period: Option<SubscriptionPeriod>,
pub usage: cloud_llm_client::CurrentUsage,
pub trial_started_at: Option<Timestamp>,
pub is_usage_based_billing_enabled: bool,
pub is_account_too_young: bool,
pub has_overdue_invoices: bool,
}
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub struct SubscriptionPeriod {
pub started_at: Timestamp,
pub ended_at: Timestamp,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct AcceptTermsOfServiceResponse {
pub user: AuthenticatedUser,
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct LlmToken(pub String);
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct CreateLlmTokenResponse {
pub token: LlmToken,
}

View File

@@ -1,166 +0,0 @@
use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A timestamp with a serialized representation in RFC 3339 format.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct Timestamp(pub DateTime<Utc>);
impl Timestamp {
pub fn new(datetime: DateTime<Utc>) -> Self {
Self(datetime)
}
}
impl From<DateTime<Utc>> for Timestamp {
fn from(value: DateTime<Utc>) -> Self {
Self(value)
}
}
impl From<NaiveDateTime> for Timestamp {
fn from(value: NaiveDateTime) -> Self {
Self(value.and_utc())
}
}
impl Serialize for Timestamp {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true);
serializer.serialize_str(&rfc3339_string)
}
}
impl<'de> Deserialize<'de> for Timestamp {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let value = String::deserialize(deserializer)?;
let datetime = DateTime::parse_from_rfc3339(&value)
.map_err(serde::de::Error::custom)?
.to_utc();
Ok(Self(datetime))
}
}
#[cfg(test)]
mod tests {
use chrono::NaiveDate;
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn test_timestamp_serialization() {
let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
.unwrap()
.to_utc();
let timestamp = Timestamp::new(datetime);
let json = serde_json::to_string(&timestamp).unwrap();
assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
}
#[test]
fn test_timestamp_deserialization() {
let json = "\"2023-12-25T14:30:45.123Z\"";
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
.unwrap()
.to_utc();
assert_eq!(timestamp.0, expected);
}
#[test]
fn test_timestamp_roundtrip() {
let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
.unwrap()
.to_utc();
let timestamp = Timestamp::new(original);
let json = serde_json::to_string(&timestamp).unwrap();
let deserialized: Timestamp = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.0, original);
}
#[test]
fn test_timestamp_from_datetime_utc() {
let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
.unwrap()
.to_utc();
let timestamp = Timestamp::from(datetime);
assert_eq!(timestamp.0, datetime);
}
#[test]
fn test_timestamp_from_naive_datetime() {
let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25)
.unwrap()
.and_hms_milli_opt(14, 30, 45, 123)
.unwrap();
let timestamp = Timestamp::from(naive_dt);
let expected = naive_dt.and_utc();
assert_eq!(timestamp.0, expected);
}
#[test]
fn test_timestamp_serialization_with_microseconds() {
// Test that microseconds are truncated to milliseconds
let datetime = NaiveDate::from_ymd_opt(2023, 12, 25)
.unwrap()
.and_hms_micro_opt(14, 30, 45, 123456)
.unwrap()
.and_utc();
let timestamp = Timestamp::new(datetime);
let json = serde_json::to_string(&timestamp).unwrap();
// Should be truncated to milliseconds
assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
}
#[test]
fn test_timestamp_deserialization_without_milliseconds() {
let json = "\"2023-12-25T14:30:45Z\"";
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
let expected = NaiveDate::from_ymd_opt(2023, 12, 25)
.unwrap()
.and_hms_opt(14, 30, 45)
.unwrap()
.and_utc();
assert_eq!(timestamp.0, expected);
}
#[test]
fn test_timestamp_deserialization_with_timezone() {
let json = "\"2023-12-25T14:30:45.123+05:30\"";
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
// Should be converted to UTC
let expected = NaiveDate::from_ymd_opt(2023, 12, 25)
.unwrap()
.and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 + 5:30 = 20:00:45, but we want UTC so subtract 5:30
.unwrap()
.and_utc();
assert_eq!(timestamp.0, expected);
}
#[test]
fn test_timestamp_deserialization_with_invalid_format() {
let json = "\"invalid-date\"";
let result: Result<Timestamp, _> = serde_json::from_str(json);
assert!(result.is_err());
}
}

View File

@@ -1,23 +0,0 @@
[package]
name = "cloud_llm_client"
version = "0.1.0"
publish.workspace = true
edition.workspace = true
license = "Apache-2.0"
[lints]
workspace = true
[lib]
path = "src/cloud_llm_client.rs"
[dependencies]
anyhow.workspace = true
serde = { workspace = true, features = ["derive", "rc"] }
serde_json.workspace = true
strum = { workspace = true, features = ["derive"] }
uuid = { workspace = true, features = ["serde"] }
workspace-hack.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-APACHE

View File

@@ -1,370 +0,0 @@
use std::str::FromStr;
use std::sync::Arc;
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumIter, EnumString};
use uuid::Uuid;
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
/// The name of the header used to indicate the the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
"x-zed-client-supports-status-messages";
/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
"x-zed-server-supports-status-messages";
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
Limited(i32),
Unlimited,
}
impl FromStr for UsageLimit {
type Err = anyhow::Error;
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value {
"unlimited" => Ok(Self::Unlimited),
limit => limit
.parse::<i32>()
.map(Self::Limited)
.context("failed to parse limit"),
}
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
#[default]
#[serde(alias = "Free")]
ZedFree,
#[serde(alias = "ZedPro")]
ZedPro,
#[serde(alias = "ZedProTrial")]
ZedProTrial,
}
impl Plan {
pub fn as_str(&self) -> &'static str {
match self {
Plan::ZedFree => "zed_free",
Plan::ZedPro => "zed_pro",
Plan::ZedProTrial => "zed_pro_trial",
}
}
pub fn model_requests_limit(&self) -> UsageLimit {
match self {
Plan::ZedPro => UsageLimit::Limited(500),
Plan::ZedProTrial => UsageLimit::Limited(150),
Plan::ZedFree => UsageLimit::Limited(50),
}
}
pub fn edit_predictions_limit(&self) -> UsageLimit {
match self {
Plan::ZedPro => UsageLimit::Unlimited,
Plan::ZedProTrial => UsageLimit::Unlimited,
Plan::ZedFree => UsageLimit::Limited(2_000),
}
}
}
impl FromStr for Plan {
type Err = anyhow::Error;
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value {
"zed_free" => Ok(Plan::ZedFree),
"zed_pro" => Ok(Plan::ZedPro),
"zed_pro_trial" => Ok(Plan::ZedProTrial),
plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
}
}
}
#[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
Anthropic,
OpenAi,
Google,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
#[serde(skip_serializing_if = "Option::is_none", default)]
pub outline: Option<String>,
pub input_events: String,
pub input_excerpt: String,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub speculated_output: Option<String>,
/// Whether the user provided consent for sampling this interaction.
#[serde(default, alias = "data_collection_permission")]
pub can_collect_data: bool,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
pub request_id: Uuid,
pub output_excerpt: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
pub request_id: Uuid,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
Normal,
Max,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
UserPrompt,
ToolResults,
ThreadSummarization,
ThreadContextSummarization,
CreateFile,
EditFile,
InlineAssist,
TerminalInlineAssist,
GenerateGitCommitMessage,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
#[serde(skip_serializing_if = "Option::is_none", default)]
pub thread_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub prompt_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub intent: Option<CompletionIntent>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub mode: Option<CompletionMode>,
pub provider: LanguageModelProvider,
pub model: String,
pub provider_request: serde_json::Value,
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
Queued {
position: usize,
},
Started,
Failed {
code: String,
message: String,
request_id: Uuid,
/// Retry duration in seconds.
retry_after: Option<f64>,
},
UsageUpdated {
amount: usize,
limit: UsageLimit,
},
ToolUseLimitReached,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent<T> {
Status(CompletionRequestStatus),
Event(T),
}
impl<T> CompletionEvent<T> {
pub fn into_status(self) -> Option<CompletionRequestStatus> {
match self {
Self::Status(status) => Some(status),
Self::Event(_) => None,
}
}
pub fn into_event(self) -> Option<T> {
match self {
Self::Event(event) => Some(event),
Self::Status(_) => None,
}
}
}
#[derive(Serialize, Deserialize)]
pub struct WebSearchBody {
pub query: String,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
pub results: Vec<WebSearchResult>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
pub title: String,
pub url: String,
pub text: String,
}
#[derive(Serialize, Deserialize)]
pub struct CountTokensBody {
pub provider: LanguageModelProvider,
pub model: String,
pub provider_request: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
pub struct CountTokensResponse {
pub tokens: usize,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
impl std::fmt::Display for LanguageModelId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
pub provider: LanguageModelProvider,
pub id: LanguageModelId,
pub display_name: String,
pub max_token_count: usize,
pub max_token_count_in_max_mode: Option<usize>,
pub max_output_tokens: usize,
pub supports_tools: bool,
pub supports_images: bool,
pub supports_thinking: bool,
pub supports_max_mode: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
pub models: Vec<LanguageModel>,
pub default_model: LanguageModelId,
pub default_fast_model: LanguageModelId,
pub recommended_models: Vec<LanguageModelId>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GetSubscriptionResponse {
pub plan: Plan,
pub usage: Option<CurrentUsage>,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct CurrentUsage {
pub model_requests: UsageData,
pub edit_predictions: UsageData,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct UsageData {
pub used: u32,
pub limit: UsageLimit,
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;
use super::*;
#[test]
fn test_plan_deserialize_snake_case() {
let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
assert_eq!(plan, Plan::ZedFree);
let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
assert_eq!(plan, Plan::ZedPro);
let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
assert_eq!(plan, Plan::ZedProTrial);
}
#[test]
fn test_plan_deserialize_aliases() {
let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
assert_eq!(plan, Plan::ZedFree);
let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
assert_eq!(plan, Plan::ZedPro);
let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
assert_eq!(plan, Plan::ZedProTrial);
}
#[test]
fn test_usage_limit_from_str() {
let limit = UsageLimit::from_str("unlimited").unwrap();
assert!(matches!(limit, UsageLimit::Unlimited));
let limit = UsageLimit::from_str(&0.to_string()).unwrap();
assert!(matches!(limit, UsageLimit::Limited(0)));
let limit = UsageLimit::from_str(&50.to_string()).unwrap();
assert!(matches!(limit, UsageLimit::Limited(50)));
for value in ["not_a_number", "50xyz"] {
let limit = UsageLimit::from_str(value);
assert!(limit.is_err());
}
}
}

View File

@@ -23,14 +23,13 @@ async-stripe.workspace = true
async-trait.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-kinesis = "1.51.0"
aws-sdk-s3 = { version = "1.15.0" }
aws-sdk-kinesis = "1.51.0"
axum = { version = "0.6", features = ["json", "headers", "ws"] }
axum-extra = { version = "0.4", features = ["erased-json"] }
base64.workspace = true
chrono.workspace = true
clock.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
dashmap.workspace = true
derive_more.workspace = true
@@ -76,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
agent_settings.workspace = true

View File

@@ -100,6 +100,7 @@ impl std::fmt::Display for SystemIdHeader {
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new()
.route("/user", get(update_or_create_authenticated_user))
.route("/users/look_up", get(look_up_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
@@ -144,6 +145,48 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
Ok::<_, Error>(next.run(req).await)
}
#[derive(Debug, Deserialize)]
struct AuthenticatedUserParams {
github_user_id: i32,
github_login: String,
github_email: Option<String>,
github_name: Option<String>,
github_user_created_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Serialize)]
struct AuthenticatedUserResponse {
user: User,
metrics_id: String,
feature_flags: Vec<String>,
}
async fn update_or_create_authenticated_user(
Query(params): Query<AuthenticatedUserParams>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<AuthenticatedUserResponse>> {
let initial_channel_id = app.config.auto_join_channel_id;
let user = app
.db
.update_or_create_user_by_github_account(
&params.github_login,
params.github_user_id,
params.github_email.as_deref(),
params.github_name.as_deref(),
params.github_user_created_at,
initial_channel_id,
)
.await?;
let metrics_id = app.db.get_user_metrics_id(user.id).await?;
let feature_flags = app.db.get_user_flags(user.id).await?;
Ok(Json(AuthenticatedUserResponse {
user,
metrics_id,
feature_flags,
}))
}
#[derive(Debug, Deserialize)]
struct LookUpUserParams {
identifier: String,
@@ -310,9 +353,9 @@ async fn refresh_llm_tokens(
#[derive(Debug, Serialize, Deserialize)]
struct UpdatePlanBody {
pub plan: cloud_llm_client::Plan,
pub plan: zed_llm_client::Plan,
pub subscription_period: SubscriptionPeriod,
pub usage: cloud_llm_client::CurrentUsage,
pub usage: zed_llm_client::CurrentUsage,
pub trial_started_at: Option<DateTime<Utc>>,
pub is_usage_based_billing_enabled: bool,
pub is_account_too_young: bool,
@@ -334,9 +377,9 @@ async fn update_plan(
extract::Json(body): extract::Json<UpdatePlanBody>,
) -> Result<Json<UpdatePlanResponse>> {
let plan = match body.plan {
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
zed_llm_client::Plan::ZedFree => proto::Plan::Free,
zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
};
let update_user_plan = proto::UpdateUserPlan {
@@ -368,15 +411,15 @@ async fn update_plan(
Ok(Json(UpdatePlanResponse {}))
}
fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
proto::UsageLimit {
variant: Some(match limit {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),

View File

@@ -1,11 +1,11 @@
use anyhow::{Context as _, bail};
use chrono::{DateTime, Utc};
use cloud_llm_client::LanguageModelProvider;
use collections::{HashMap, HashSet};
use sea_orm::ActiveValue;
use std::{sync::Arc, time::Duration};
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
use util::{ResultExt, maybe};
use zed_llm_client::LanguageModelProvider;
use crate::AppState;
use crate::db::billing_subscription::{
@@ -87,14 +87,6 @@ async fn poll_stripe_events(
stripe_client: &Arc<dyn StripeClient>,
real_stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
let feature_flags = app.db.list_feature_flags().await?;
let sync_events_using_cloud = feature_flags
.iter()
.any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
if sync_events_using_cloud {
return Ok(());
}
fn event_type_to_string(event_type: EventType) -> String {
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
// so we need to unquote it.
@@ -577,14 +569,6 @@ async fn sync_model_request_usage_with_stripe(
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
let feature_flags = app.db.list_feature_flags().await?;
let sync_model_request_usage_using_cloud = feature_flags
.iter()
.any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all);
if sync_model_request_usage_using_cloud {
return Ok(());
}
log::info!("Stripe usage sync: Starting");
let started_at = Utc::now();

View File

@@ -8,6 +8,7 @@ use axum::{
use chrono::{NaiveDateTime, SecondsFormat};
use serde::{Deserialize, Serialize};
use crate::api::AuthenticatedUserParams;
use crate::db::ContributorSelector;
use crate::{AppState, Result};
@@ -103,18 +104,9 @@ impl RenovateBot {
}
}
#[derive(Debug, Deserialize)]
struct AddContributorBody {
github_user_id: i32,
github_login: String,
github_email: Option<String>,
github_name: Option<String>,
github_user_created_at: chrono::DateTime<chrono::Utc>,
}
async fn add_contributor(
Extension(app): Extension<Arc<AppState>>,
extract::Json(params): extract::Json<AddContributorBody>,
extract::Json(params): extract::Json<AuthenticatedUserParams>,
) -> Result<()> {
let initial_channel_id = app.config.auto_join_channel_id;
app.db

View File

@@ -95,7 +95,7 @@ pub enum SubscriptionKind {
ZedFree,
}
impl From<SubscriptionKind> for cloud_llm_client::Plan {
impl From<SubscriptionKind> for zed_llm_client::Plan {
fn from(value: SubscriptionKind) -> Self {
match value {
SubscriptionKind::ZedPro => Self::ZedPro,

View File

@@ -6,11 +6,11 @@ mod tables;
#[cfg(test)]
mod tests;
use cloud_llm_client::LanguageModelProvider;
use collections::HashMap;
pub use ids::*;
pub use seed::*;
pub use tables::*;
use zed_llm_client::LanguageModelProvider;
#[cfg(test)]
pub use tests::TestLlmDb;

View File

@@ -1,5 +1,5 @@
use cloud_llm_client::LanguageModelProvider;
use pretty_assertions::assert_eq;
use zed_llm_client::LanguageModelProvider;
use crate::llm::db::LlmDatabase;
use crate::test_llm_db;

View File

@@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA
use crate::{Config, db::billing_preference};
use anyhow::{Context as _, Result};
use chrono::{NaiveDateTime, Utc};
use cloud_llm_client::Plan;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;
use zed_llm_client::Plan;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]

View File

@@ -23,7 +23,6 @@ use anyhow::{Context as _, anyhow, bail};
use async_tungstenite::tungstenite::{
Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
};
use axum::headers::UserAgent;
use axum::{
Extension, Router, TypedHeader,
body::Body,
@@ -42,7 +41,7 @@ use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use reqwest_client::ReqwestClient;
use rpc::proto::{MultiLspQuery, split_repository_update};
use rpc::proto::split_repository_update;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use futures::{
@@ -374,7 +373,7 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
.add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
.add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
.add_request_handler(multi_lsp_query)
.add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
.add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
.add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
.add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
@@ -751,7 +750,6 @@ impl Server {
address: String,
principal: Principal,
zed_version: ZedVersion,
user_agent: Option<String>,
geoip_country_code: Option<String>,
system_id: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
@@ -764,14 +762,9 @@ impl Server {
user_id=field::Empty,
login=field::Empty,
impersonator=field::Empty,
user_agent=field::Empty,
geoip_country_code=field::Empty
);
principal.update_span(&span);
if let Some(user_agent) = user_agent {
span.record("user_agent", user_agent);
}
if let Some(country_code) = geoip_country_code.as_ref() {
span.record("geoip_country_code", country_code);
}
@@ -838,7 +831,7 @@ impl Server {
// This arrangement ensures we will attempt to process earlier messages first, but fall
// back to processing messages arrived later in the spirit of making progress.
let mut foreground_message_handlers = FuturesUnordered::new();
let concurrent_handlers = Arc::new(Semaphore::new(256));
let concurrent_handlers = Arc::new(Semaphore::new(512));
loop {
let next_message = async {
let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
@@ -865,7 +858,6 @@ impl Server {
user_id=field::Empty,
login=field::Empty,
impersonator=field::Empty,
multi_lsp_query_request=field::Empty,
);
principal.update_span(&span);
let span_enter = span.enter();
@@ -1180,7 +1172,6 @@ pub async fn handle_websocket_request(
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
user_agent: Option<TypedHeader<UserAgent>>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
system_id_header: Option<TypedHeader<SystemIdHeader>>,
ws: WebSocketUpgrade,
@@ -1236,7 +1227,6 @@ pub async fn handle_websocket_request(
socket_address,
principal,
version,
user_agent.map(|header| header.to_string()),
country_code_header.map(|header| header.to_string()),
system_id_header.map(|header| header.to_string()),
None,
@@ -2330,15 +2320,6 @@ where
Ok(())
}
async fn multi_lsp_query(
request: MultiLspQuery,
response: Response<MultiLspQuery>,
session: Session,
) -> Result<()> {
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
forward_mutating_project_request(request, response, session).await
}
/// Notify other participants that a new buffer has been created
async fn create_buffer_for_peer(
request: proto::CreateBufferForPeer,
@@ -2878,12 +2859,12 @@ async fn make_update_user_plan_message(
}
fn model_requests_limit(
plan: cloud_llm_client::Plan,
plan: zed_llm_client::Plan,
feature_flags: &Vec<String>,
) -> cloud_llm_client::UsageLimit {
) -> zed_llm_client::UsageLimit {
match plan.model_requests_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == cloud_llm_client::Plan::ZedProTrial
zed_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == zed_llm_client::Plan::ZedProTrial
&& feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
@@ -2893,9 +2874,9 @@ fn model_requests_limit(
limit
};
cloud_llm_client::UsageLimit::Limited(limit)
zed_llm_client::UsageLimit::Limited(limit)
}
cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
}
}
@@ -2905,21 +2886,21 @@ fn subscription_usage_to_proto(
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
@@ -2927,12 +2908,12 @@ fn subscription_usage_to_proto(
edit_predictions_usage_amount: usage.edit_predictions as u32,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
@@ -2945,21 +2926,21 @@ fn make_default_subscription_usage(
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: 0,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
@@ -2967,12 +2948,12 @@ fn make_default_subscription_usage(
edit_predictions_usage_amount: 0,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),

View File

@@ -38,12 +38,12 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
let mut remote = room
.remote_participants()
.values()
.map(|participant| participant.user.github_login.clone().to_string())
.map(|participant| participant.user.github_login.clone())
.collect::<Vec<_>>();
let mut pending = room
.pending_participants()
.iter()
.map(|user| user.github_login.clone().to_string())
.map(|user| user.github_login.clone())
.collect::<Vec<_>>();
remote.sort();
pending.sort();

View File

@@ -842,7 +842,7 @@ async fn test_client_disconnecting_from_room(
// Allow user A to reconnect to the server.
server.allow_connections();
executor.advance_clock(RECONNECT_TIMEOUT);
executor.advance_clock(RECEIVE_TIMEOUT);
// Call user B again from client A.
active_call_a
@@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections(
client_b1.disconnect(&cx_b1.to_async());
executor.advance_clock(RECEIVE_TIMEOUT);
client_b1
.connect(false, &cx_b1.to_async())
.authenticate_and_connect(false, &cx_b1.to_async())
.await
.into_response()
.unwrap();
@@ -1358,7 +1358,7 @@ async fn test_calls_on_multiple_connections(
// User A reconnects automatically, then calls user B again.
server.allow_connections();
executor.advance_clock(RECONNECT_TIMEOUT);
executor.advance_clock(RECEIVE_TIMEOUT);
active_call_a
.update(cx_a, |call, cx| {
call.invite(client_b1.user_id().unwrap(), None, cx)
@@ -1667,7 +1667,7 @@ async fn test_project_reconnect(
// Client A reconnects. Their project is re-shared, and client B re-joins it.
server.allow_connections();
client_a
.connect(false, &cx_a.to_async())
.authenticate_and_connect(false, &cx_a.to_async())
.await
.into_response()
.unwrap();
@@ -1796,7 +1796,7 @@ async fn test_project_reconnect(
// Client B reconnects. They re-join the room and the remaining shared project.
server.allow_connections();
client_b
.connect(false, &cx_b.to_async())
.authenticate_and_connect(false, &cx_b.to_async())
.await
.into_response()
.unwrap();
@@ -1881,7 +1881,7 @@ async fn test_active_call_events(
vec![room::Event::RemoteProjectShared {
owner: Arc::new(User {
id: client_a.user_id().unwrap(),
github_login: "user_a".into(),
github_login: "user_a".to_string(),
avatar_uri: "avatar_a".into(),
name: None,
}),
@@ -1900,7 +1900,7 @@ async fn test_active_call_events(
vec![room::Event::RemoteProjectShared {
owner: Arc::new(User {
id: client_b.user_id().unwrap(),
github_login: "user_b".into(),
github_login: "user_b".to_string(),
avatar_uri: "avatar_b".into(),
name: None,
}),
@@ -5738,7 +5738,7 @@ async fn test_contacts(
server.allow_connections();
client_c
.connect(false, &cx_c.to_async())
.authenticate_and_connect(false, &cx_c.to_async())
.await
.into_response()
.unwrap();
@@ -6079,7 +6079,7 @@ async fn test_contacts(
.iter()
.map(|contact| {
(
contact.user.github_login.clone().to_string(),
contact.user.github_login.clone(),
if contact.online { "online" } else { "offline" },
if contact.busy { "busy" } else { "free" },
)
@@ -6269,7 +6269,7 @@ async fn test_contact_requests(
client.disconnect(&cx.to_async());
client.clear_contacts(cx).await;
client
.connect(false, &cx.to_async())
.authenticate_and_connect(false, &cx.to_async())
.await
.into_response()
.unwrap();

View File

@@ -3,7 +3,6 @@ use std::sync::Arc;
use gpui::{BackgroundExecutor, TestAppContext};
use notifications::NotificationEvent;
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use rpc::{Notification, proto};
use crate::tests::TestServer;
@@ -18,9 +17,6 @@ async fn test_notifications(
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
// Wait for authentication/connection to Collab to be established.
executor.run_until_parked();
let notification_events_a = Arc::new(Mutex::new(Vec::new()));
let notification_events_b = Arc::new(Mutex::new(Vec::new()));
client_a.notification_store().update(cx_a, |_, cx| {

View File

@@ -8,7 +8,6 @@ use crate::{
use anyhow::anyhow;
use call::ActiveCall;
use channel::{ChannelBuffer, ChannelStore};
use client::test::{make_get_authenticated_user_response, parse_authorization_header};
use client::{
self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
proto::PeerId,
@@ -21,7 +20,7 @@ use fs::FakeFs;
use futures::{StreamExt as _, channel::oneshot};
use git::GitHostingProviderRegistry;
use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext};
use http_client::{FakeHttpClient, Method};
use http_client::FakeHttpClient;
use language::LanguageRegistry;
use node_runtime::NodeRuntime;
use notifications::NotificationStore;
@@ -162,8 +161,6 @@ impl TestServer {
}
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
const ACCESS_TOKEN: &str = "the-token";
let fs = FakeFs::new(cx.executor());
cx.update(|cx| {
@@ -178,7 +175,7 @@ impl TestServer {
});
let clock = Arc::new(FakeSystemClock::new());
let http = FakeHttpClient::with_404_response();
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
{
user.id
@@ -200,47 +197,6 @@ impl TestServer {
.expect("creating user failed")
.user_id
};
let http = FakeHttpClient::create({
let name = name.to_string();
move |req| {
let name = name.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::GET, "/client/users/me") => {
let credentials = parse_authorization_header(&req);
if credentials
!= Some(Credentials {
user_id: user_id.to_proto(),
access_token: ACCESS_TOKEN.into(),
})
{
return Ok(http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap());
}
Ok(http_client::Response::builder()
.status(200)
.body(
serde_json::to_string(&make_get_authenticated_user_response(
user_id.0, name,
))
.unwrap()
.into(),
)
.unwrap())
}
_ => Ok(http_client::Response::builder()
.status(404)
.body("Not Found".into())
.unwrap()),
}
}
}
});
let client_name = name.to_string();
let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx));
let server = self.server.clone();
@@ -252,10 +208,11 @@ impl TestServer {
.unwrap()
.set_id(user_id.to_proto())
.override_authenticate(move |cx| {
let access_token = "the-token".to_string();
cx.spawn(async move |_| {
Ok(Credentials {
user_id: user_id.to_proto(),
access_token: ACCESS_TOKEN.into(),
access_token,
})
})
})
@@ -264,7 +221,7 @@ impl TestServer {
credentials,
&Credentials {
user_id: user_id.0 as u64,
access_token: ACCESS_TOKEN.into(),
access_token: "the-token".into()
}
);
@@ -299,7 +256,6 @@ impl TestServer {
ZedVersion(SemanticVersion::new(1, 0, 0)),
None,
None,
None,
Some(connection_id_tx),
Executor::Deterministic(cx.background_executor().clone()),
None,
@@ -362,7 +318,7 @@ impl TestServer {
});
client
.connect(false, &cx.to_async())
.authenticate_and_connect(false, &cx.to_async())
.await
.into_response()
.unwrap();
@@ -735,17 +691,17 @@ impl TestClient {
current: store
.contacts()
.iter()
.map(|contact| contact.user.github_login.clone().to_string())
.map(|contact| contact.user.github_login.clone())
.collect(),
outgoing_requests: store
.outgoing_contact_requests()
.iter()
.map(|user| user.github_login.clone().to_string())
.map(|user| user.github_login.clone())
.collect(),
incoming_requests: store
.incoming_contact_requests()
.iter()
.map(|user| user.github_login.clone().to_string())
.map(|user| user.github_login.clone())
.collect(),
})
}

View File

@@ -940,7 +940,7 @@ impl CollabPanel {
room.read(cx).local_participant().role == proto::ChannelRole::Admin
});
ListItem::new(user.github_login.clone())
ListItem::new(SharedString::from(user.github_login.clone()))
.start_slot(Avatar::new(user.avatar_uri.clone()))
.child(Label::new(user.github_login.clone()))
.toggle_state(is_selected)
@@ -2331,7 +2331,7 @@ impl CollabPanel {
let client = this.client.clone();
cx.spawn_in(window, async move |_, cx| {
client
.connect(true, &cx)
.authenticate_and_connect(true, &cx)
.await
.into_response()
.notify_async_err(cx);
@@ -2583,7 +2583,7 @@ impl CollabPanel {
) -> impl IntoElement {
let online = contact.online;
let busy = contact.busy || calling;
let github_login = contact.user.github_login.clone();
let github_login = SharedString::from(contact.user.github_login.clone());
let item = ListItem::new(github_login.clone())
.indent_level(1)
.indent_step_size(px(20.))
@@ -2662,7 +2662,7 @@ impl CollabPanel {
is_selected: bool,
cx: &mut Context<Self>,
) -> impl IntoElement {
let github_login = user.github_login.clone();
let github_login = SharedString::from(user.github_login.clone());
let user_id = user.id;
let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user);
let color = if is_response_pending {

View File

@@ -634,13 +634,13 @@ impl Render for NotificationPanel {
.child(Icon::new(IconName::Envelope)),
)
.map(|this| {
if !self.client.status().borrow().is_connected() {
if self.client.user_id().is_none() {
this.child(
v_flex()
.gap_2()
.p_4()
.child(
Button::new("connect_prompt_button", "Connect")
Button::new("sign_in_prompt_button", "Sign in")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
@@ -652,7 +652,10 @@ impl Render for NotificationPanel {
let client = client.clone();
window
.spawn(cx, async move |cx| {
match client.connect(true, &cx).await {
match client
.authenticate_and_connect(true, &cx)
.await
{
util::ConnectionResult::Timeout => {
log::error!("Connection timeout");
}
@@ -670,7 +673,7 @@ impl Render for NotificationPanel {
)
.child(
div().flex().w_full().items_center().child(
Label::new("Connect to view notifications.")
Label::new("Sign in to view notifications.")
.color(Color::Muted)
.size(LabelSize::Small),
),

View File

@@ -158,7 +158,6 @@ impl Client {
pub fn stdio(
server_id: ContextServerId,
binary: ModelContextServerBinary,
working_directory: &Option<PathBuf>,
cx: AsyncApp,
) -> Result<Self> {
log::info!(
@@ -173,7 +172,7 @@ impl Client {
.map(|name| name.to_string_lossy().to_string())
.unwrap_or_else(String::new);
let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?);
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
Self::new(server_id, server_name.into(), transport, cx)
}
@@ -441,12 +440,14 @@ impl Client {
Ok(())
}
pub fn on_notification(
&self,
method: &'static str,
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
) {
self.notification_handlers.lock().insert(method, f);
#[allow(unused)]
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
{
self.notification_handlers
.lock()
.insert(method, Box::new(f));
}
}

View File

@@ -53,7 +53,7 @@ impl std::fmt::Debug for ContextServerCommand {
}
enum ContextServerTransport {
Stdio(ContextServerCommand, Option<PathBuf>),
Stdio(ContextServerCommand),
Custom(Arc<dyn crate::transport::Transport>),
}
@@ -64,18 +64,11 @@ pub struct ContextServer {
}
impl ContextServer {
pub fn stdio(
id: ContextServerId,
command: ContextServerCommand,
working_directory: Option<Arc<Path>>,
) -> Self {
pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
Self {
id,
client: RwLock::new(None),
configuration: ContextServerTransport::Stdio(
command,
working_directory.map(|directory| directory.to_path_buf()),
),
configuration: ContextServerTransport::Stdio(command),
}
}
@@ -95,36 +88,15 @@ impl ContextServer {
self.client.read().clone()
}
pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
self.initialize(self.new_client(cx)?).await
}
/// Starts the context server, making sure handlers are registered before initialization happens
pub async fn start_with_handlers(
&self,
notification_handlers: Vec<(
&'static str,
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
)>,
cx: &AsyncApp,
) -> Result<()> {
let client = self.new_client(cx)?;
for (method, handler) in notification_handlers {
client.on_notification(method, handler);
}
self.initialize(client).await
}
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
Ok(match &self.configuration {
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
let client = match &self.configuration {
ContextServerTransport::Stdio(command) => Client::stdio(
client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary {
executable: Path::new(&command.path).to_path_buf(),
args: command.args.clone(),
env: command.env.clone(),
},
working_directory,
cx.clone(),
)?,
ContextServerTransport::Custom(transport) => Client::new(
@@ -133,7 +105,8 @@ impl ContextServer {
transport.clone(),
cx.clone(),
)?,
})
};
self.initialize(client).await
}
async fn initialize(&self, client: Client) -> Result<()> {

View File

@@ -83,18 +83,14 @@ impl McpServer {
}
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
let mut settings = schemars::generate::SchemaSettings::draft07();
settings.inline_subschemas = true;
let mut generator = settings.into_generator();
let output_schema = generator.root_schema_for::<T::Output>();
let unit_schema = generator.root_schema_for::<T::Output>();
let output_schema = schemars::schema_for!(T::Output);
let unit_schema = schemars::schema_for!(());
let registered_tool = RegisteredTool {
tool: Tool {
name: T::NAME.into(),
description: Some(tool.description().into()),
input_schema: generator.root_schema_for::<T::Input>().into(),
input_schema: schemars::schema_for!(T::Input).into(),
output_schema: if output_schema == unit_schema {
None
} else {

View File

@@ -115,11 +115,10 @@ impl InitializedContextServerProtocol {
self.inner.notify(T::METHOD, params)
}
pub fn on_notification(
&self,
method: &'static str,
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
) {
pub fn on_notification<F>(&self, method: &'static str, f: F)
where
F: 'static + Send + FnMut(Value, AsyncApp),
{
self.inner.on_notification(method, f);
}
}

View File

@@ -1,4 +1,3 @@
use std::path::PathBuf;
use std::pin::Pin;
use anyhow::{Context as _, Result};
@@ -23,11 +22,7 @@ pub struct StdioTransport {
}
impl StdioTransport {
pub fn new(
binary: ModelContextServerBinary,
working_directory: &Option<PathBuf>,
cx: &AsyncApp,
) -> Result<Self> {
pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
let mut command = util::command::new_smol_command(&binary.executable);
command
.args(&binary.args)
@@ -37,10 +32,6 @@ impl StdioTransport {
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
if let Some(working_directory) = working_directory {
command.current_dir(working_directory);
}
let mut server = command.spawn().with_context(|| {
format!(
"failed to spawn command. (path={:?}, args={:?})",

View File

@@ -295,7 +295,7 @@ mod tests {
request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
},
},
Box::new(|_| {}),
Box::new(|_| panic!("Did not expect to hit this code path")),
&mut cx.to_async(),
)
.await

Some files were not shown because too many files have changed in this diff Show More