Compare commits
207 Commits
new-acp-na
...
windows/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a3d56749e | ||
|
|
3519e8fd7c | ||
|
|
b3372e7eac | ||
|
|
4f7bb14acf | ||
|
|
99c5b72b3d | ||
|
|
c995dd2016 | ||
|
|
80be0e29b9 | ||
|
|
b75f6e2210 | ||
|
|
b8f85be372 | ||
|
|
34e433ad90 | ||
|
|
6a91ac26d7 | ||
|
|
345fd526fc | ||
|
|
7c8074ce5c | ||
|
|
0a0803e2a7 | ||
|
|
8c8b91470a | ||
|
|
98692cc928 | ||
|
|
d194bf4f52 | ||
|
|
cc763729a0 | ||
|
|
e370c3d601 | ||
|
|
1d60984cb6 | ||
|
|
eb3bb95c91 | ||
|
|
554b36fd3c | ||
|
|
89a863d012 | ||
|
|
441731de2e | ||
|
|
ead7a1e1f0 | ||
|
|
73eaee8f6f | ||
|
|
98f31172ab | ||
|
|
181f324473 | ||
|
|
a1f03ee42c | ||
|
|
741b38f906 | ||
|
|
599b82fc9d | ||
|
|
64b3b050e3 | ||
|
|
62d1b7e36f | ||
|
|
d7b14d8dc5 | ||
|
|
ce67ce1482 | ||
|
|
8eea9aad40 | ||
|
|
e9697e4639 | ||
|
|
92b0a7e760 | ||
|
|
d96bafb1e5 | ||
|
|
5f3a1bdbd1 | ||
|
|
1ee81a507b | ||
|
|
89d34e1513 | ||
|
|
a9058346bf | ||
|
|
ac1ea0f96d | ||
|
|
0065e5fd76 | ||
|
|
ca6aa25d1e | ||
|
|
6964cecc14 | ||
|
|
63daf44693 | ||
|
|
4de2ebf954 | ||
|
|
3277640f55 | ||
|
|
d192ac6b7f | ||
|
|
c67ddd7572 | ||
|
|
b54eaecbbc | ||
|
|
2744e6cb65 | ||
|
|
18937f5756 | ||
|
|
347b863ac6 | ||
|
|
9d8ef8156d | ||
|
|
9dbbee0334 | ||
|
|
32f2505fbf | ||
|
|
2711d8823c | ||
|
|
787fee8a1a | ||
|
|
be7d56e11b | ||
|
|
fcb77979f3 | ||
|
|
787c6382f9 | ||
|
|
74d953d024 | ||
|
|
b5377c56f2 | ||
|
|
275d84d566 | ||
|
|
3978bba5a7 | ||
|
|
52c0fa5ce9 | ||
|
|
d208f75f46 | ||
|
|
1b0a0aa58e | ||
|
|
5ff9114b18 | ||
|
|
d9c6d09545 | ||
|
|
61981aabb5 | ||
|
|
0b57c86e07 | ||
|
|
c7342a9df5 | ||
|
|
0e45ef7e43 | ||
|
|
0c40bb9b5f | ||
|
|
5058752f2d | ||
|
|
432d11f57b | ||
|
|
32488e1e2d | ||
|
|
9acee42c38 | ||
|
|
72c55b4653 | ||
|
|
fa1320d9aa | ||
|
|
eb310bcf7d | ||
|
|
8c1d9f75d1 | ||
|
|
499b3b6b50 | ||
|
|
c6e020f60f | ||
|
|
7ab2d0d800 | ||
|
|
c007121b41 | ||
|
|
22c9d133bd | ||
|
|
32758022df | ||
|
|
0d8600bf1e | ||
|
|
22cba07072 | ||
|
|
642d769502 | ||
|
|
bfdcc65801 | ||
|
|
54e2420405 | ||
|
|
b012246d2b | ||
|
|
667c19907a | ||
|
|
5261c02d18 | ||
|
|
204071e6bf | ||
|
|
5472c71f1a | ||
|
|
723712e3cf | ||
|
|
0c274370c3 | ||
|
|
31fab3a37a | ||
|
|
4f416d3818 | ||
|
|
ffef9fd25a | ||
|
|
2a6b83f190 | ||
|
|
1b12dd39cc | ||
|
|
9162583bac | ||
|
|
8075998c09 | ||
|
|
3b6105b713 | ||
|
|
2b53a2cb12 | ||
|
|
96d847b6d1 | ||
|
|
7fde34f85e | ||
|
|
401e0e6f41 | ||
|
|
201c274c4b | ||
|
|
ecde968a0c | ||
|
|
4a78ce7cfd | ||
|
|
fda3d56d87 | ||
|
|
9c3cfca835 | ||
|
|
1fb689bad3 | ||
|
|
238ccec5ee | ||
|
|
c8ae5a3b11 | ||
|
|
dbe2ce2464 | ||
|
|
5287183667 | ||
|
|
a48ae50e1a | ||
|
|
ca3d55ee4d | ||
|
|
c0bad42968 | ||
|
|
7186f1322e | ||
|
|
21e14b5f9a | ||
|
|
7d84014ad2 | ||
|
|
68780da673 | ||
|
|
1f55a0a358 | ||
|
|
ba80e16339 | ||
|
|
11dc14ad4d | ||
|
|
9f200ebf5a | ||
|
|
788865e892 | ||
|
|
e87ee91d8e | ||
|
|
b0e48d01ce | ||
|
|
825ee6233b | ||
|
|
154705e729 | ||
|
|
636a057373 | ||
|
|
df1f62477c | ||
|
|
6907064be6 | ||
|
|
c1eaf3317d | ||
|
|
6477a9b056 | ||
|
|
84f75fe683 | ||
|
|
7627097875 | ||
|
|
78824390d0 | ||
|
|
4d936845f3 | ||
|
|
76fb80eaeb | ||
|
|
29b5acf27b | ||
|
|
e560c6813f | ||
|
|
a57cbe4636 | ||
|
|
7cf10d110c | ||
|
|
1888f21a14 | ||
|
|
63727f99da | ||
|
|
602bd189f6 | ||
|
|
b8314e74db | ||
|
|
a486bb28f6 | ||
|
|
b1b5a383e0 | ||
|
|
b0fe5fd56f | ||
|
|
398d492f85 | ||
|
|
55edee58fb | ||
|
|
da3736bd5f | ||
|
|
4b2ff5e251 | ||
|
|
46fc76fdf8 | ||
|
|
ffbb47452d | ||
|
|
5ed8b13e4a | ||
|
|
1baafae3f7 | ||
|
|
2017ce3699 | ||
|
|
f715acc92a | ||
|
|
291691ca0e | ||
|
|
158732eb17 | ||
|
|
cdbaff8375 | ||
|
|
c014dbae8c | ||
|
|
83d942611f | ||
|
|
f16f07b36f | ||
|
|
85cf9e405e | ||
|
|
a1c00ed87f | ||
|
|
34d5926ebd | ||
|
|
6fc8d7746f | ||
|
|
2fb31a9157 | ||
|
|
a7e34ab0bc | ||
|
|
6928488aad | ||
|
|
8514850ad4 | ||
|
|
231c38aa41 | ||
|
|
8d538fad0c | ||
|
|
f5aa88ca6a | ||
|
|
b9eb18eb7f | ||
|
|
b130346ede | ||
|
|
e8bd47f668 | ||
|
|
6a918b64bf | ||
|
|
c82edc38a9 | ||
|
|
622a42e3aa | ||
|
|
dcdd7404e4 | ||
|
|
52c181328c | ||
|
|
2319cd8211 | ||
|
|
d0a2257472 | ||
|
|
af2009710a | ||
|
|
eec406bb36 | ||
|
|
83ea328be5 | ||
|
|
f2c847a1b0 | ||
|
|
5d03296dc2 | ||
|
|
b4771bc4f8 | ||
|
|
68192052fd |
8
.github/actions/build_docs/action.yml
vendored
@@ -19,7 +19,7 @@ runs:
|
||||
shell: bash -euxo pipefail {0}
|
||||
run: ./script/linux
|
||||
|
||||
- name: Check for broken links (in MD)
|
||||
- name: Check for broken links
|
||||
uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
|
||||
with:
|
||||
args: --no-progress --exclude '^http' './docs/src/**/*'
|
||||
@@ -30,9 +30,3 @@ runs:
|
||||
run: |
|
||||
mkdir -p target/deploy
|
||||
mdbook build ./docs --dest-dir=../target/deploy/docs/
|
||||
|
||||
- name: Check for broken links (in HTML)
|
||||
uses: lycheeverse/lychee-action@82202e5e9c2f4ef1a55a3d02563e1cb6041e5332 # v2.4.1
|
||||
with:
|
||||
args: --no-progress --exclude '^http' 'target/deploy/docs/'
|
||||
fail: true
|
||||
|
||||
23
.github/workflows/ci.yml
vendored
@@ -771,8 +771,7 @@ jobs:
|
||||
timeout-minutes: 120
|
||||
name: Create a Windows installer
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
# if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
|
||||
if: true && (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
|
||||
needs: [windows_tests]
|
||||
env:
|
||||
AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}
|
||||
@@ -808,16 +807,16 @@ jobs:
|
||||
name: ZedEditorUserSetup-x64-${{ github.event.pull_request.head.sha || github.sha }}.exe
|
||||
path: ${{ env.SETUP_PATH }}
|
||||
|
||||
- name: Upload Artifacts to release
|
||||
uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
|
||||
# Re-enable when we are ready to publish windows preview releases
|
||||
if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
|
||||
with:
|
||||
draft: true
|
||||
prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
|
||||
files: ${{ env.SETUP_PATH }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# - name: Upload Artifacts to release
|
||||
# uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 # v1
|
||||
# # Re-enable when we are ready to publish windows preview releases
|
||||
# if: ${{ !(contains(github.event.pull_request.labels.*.name, 'run-bundling')) && env.RELEASE_CHANNEL == 'preview' }} # upload only preview
|
||||
# with:
|
||||
# draft: true
|
||||
# prerelease: ${{ env.RELEASE_CHANNEL == 'preview' }}
|
||||
# files: ${{ env.SETUP_PATH }}
|
||||
# env:
|
||||
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
auto-release-preview:
|
||||
name: Auto release preview
|
||||
|
||||
574
Cargo.lock
generated
28
Cargo.toml
@@ -1,13 +1,13 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/acp_thread",
|
||||
"crates/activity_indicator",
|
||||
"crates/agent",
|
||||
"crates/agent_servers",
|
||||
"crates/agent_settings",
|
||||
"crates/acp_thread",
|
||||
"crates/agent_ui",
|
||||
"crates/agent",
|
||||
"crates/agent_settings",
|
||||
"crates/ai_onboarding",
|
||||
"crates/agent_servers",
|
||||
"crates/anthropic",
|
||||
"crates/askpass",
|
||||
"crates/assets",
|
||||
@@ -29,9 +29,6 @@ members = [
|
||||
"crates/cli",
|
||||
"crates/client",
|
||||
"crates/clock",
|
||||
"crates/cloud_api_client",
|
||||
"crates/cloud_api_types",
|
||||
"crates/cloud_llm_client",
|
||||
"crates/collab",
|
||||
"crates/collab_ui",
|
||||
"crates/collections",
|
||||
@@ -51,8 +48,8 @@ members = [
|
||||
"crates/diagnostics",
|
||||
"crates/docs_preprocessor",
|
||||
"crates/editor",
|
||||
"crates/eval",
|
||||
"crates/explorer_command_injector",
|
||||
"crates/eval",
|
||||
"crates/extension",
|
||||
"crates/extension_api",
|
||||
"crates/extension_cli",
|
||||
@@ -73,6 +70,7 @@ members = [
|
||||
"crates/gpui",
|
||||
"crates/gpui_macros",
|
||||
"crates/gpui_tokio",
|
||||
|
||||
"crates/html_to_markdown",
|
||||
"crates/http_client",
|
||||
"crates/http_client_tls",
|
||||
@@ -101,6 +99,7 @@ members = [
|
||||
"crates/markdown_preview",
|
||||
"crates/media",
|
||||
"crates/menu",
|
||||
"crates/svg_preview",
|
||||
"crates/migrator",
|
||||
"crates/mistral",
|
||||
"crates/multi_buffer",
|
||||
@@ -141,7 +140,6 @@ members = [
|
||||
"crates/semantic_version",
|
||||
"crates/session",
|
||||
"crates/settings",
|
||||
"crates/settings_profile_selector",
|
||||
"crates/settings_ui",
|
||||
"crates/snippet",
|
||||
"crates/snippet_provider",
|
||||
@@ -154,7 +152,6 @@ members = [
|
||||
"crates/sum_tree",
|
||||
"crates/supermaven",
|
||||
"crates/supermaven_api",
|
||||
"crates/svg_preview",
|
||||
"crates/tab_switcher",
|
||||
"crates/task",
|
||||
"crates/tasks_ui",
|
||||
@@ -189,7 +186,6 @@ members = [
|
||||
"crates/zed",
|
||||
"crates/zed_actions",
|
||||
"crates/zeta",
|
||||
"crates/zeta_cli",
|
||||
"crates/zlog",
|
||||
"crates/zlog_settings",
|
||||
|
||||
@@ -255,9 +251,6 @@ channel = { path = "crates/channel" }
|
||||
cli = { path = "crates/cli" }
|
||||
client = { path = "crates/client" }
|
||||
clock = { path = "crates/clock" }
|
||||
cloud_api_client = { path = "crates/cloud_api_client" }
|
||||
cloud_api_types = { path = "crates/cloud_api_types" }
|
||||
cloud_llm_client = { path = "crates/cloud_llm_client" }
|
||||
collab = { path = "crates/collab" }
|
||||
collab_ui = { path = "crates/collab_ui" }
|
||||
collections = { path = "crates/collections" }
|
||||
@@ -344,7 +337,6 @@ picker = { path = "crates/picker" }
|
||||
plugin = { path = "crates/plugin" }
|
||||
plugin_macros = { path = "crates/plugin_macros" }
|
||||
prettier = { path = "crates/prettier" }
|
||||
settings_profile_selector = { path = "crates/settings_profile_selector" }
|
||||
project = { path = "crates/project" }
|
||||
project_panel = { path = "crates/project_panel" }
|
||||
project_symbols = { path = "crates/project_symbols" }
|
||||
@@ -421,7 +413,7 @@ zlog_settings = { path = "crates/zlog_settings" }
|
||||
#
|
||||
|
||||
agentic-coding-protocol = "0.0.10"
|
||||
agent-client-protocol = {path="../agent-client-protocol"}
|
||||
agent-client-protocol = "0.0.11"
|
||||
aho-corasick = "1.1"
|
||||
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
|
||||
any_vec = "0.14"
|
||||
@@ -653,6 +645,7 @@ which = "6.0.0"
|
||||
windows-core = "0.61"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "= 0.8.6"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
@@ -679,6 +672,8 @@ features = [
|
||||
"UI_ViewManagement",
|
||||
"Wdk_System_SystemServices",
|
||||
"Win32_Globalization",
|
||||
"Win32_Graphics_Direct2D",
|
||||
"Win32_Graphics_Direct2D_Common",
|
||||
"Win32_Graphics_Direct3D",
|
||||
"Win32_Graphics_Direct3D11",
|
||||
"Win32_Graphics_Direct3D_Fxc",
|
||||
@@ -689,6 +684,7 @@ features = [
|
||||
"Win32_Graphics_Dxgi_Common",
|
||||
"Win32_Graphics_Gdi",
|
||||
"Win32_Graphics_Imaging",
|
||||
"Win32_Graphics_Imaging_D2D",
|
||||
"Win32_Networking_WinSock",
|
||||
"Win32_Security",
|
||||
"Win32_Security_Credentials",
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Zed
|
||||
|
||||
[](https://zed.dev)
|
||||
[](https://github.com/zed-industries/zed/actions/workflows/ci.yml)
|
||||
|
||||
Welcome to Zed, a high-performance, multiplayer code editor from the creators of [Atom](https://github.com/atom/atom) and [Tree-sitter](https://github.com/tree-sitter/tree-sitter).
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"label": "",
|
||||
"message": "Zed",
|
||||
"logoSvg": "<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 96 96\"><rect width=\"96\" height=\"96\" fill=\"#000\"/><path fill-rule=\"evenodd\" clip-rule=\"evenodd\" d=\"M9 6C7.34315 6 6 7.34315 6 9V75H0V9C0 4.02944 4.02944 0 9 0H89.3787C93.3878 0 95.3955 4.84715 92.5607 7.68198L43.0551 57.1875H57V51H63V58.6875C63 61.1728 60.9853 63.1875 58.5 63.1875H37.0551L26.7426 73.5H73.5V36H79.5V73.5C79.5 76.8137 76.8137 79.5 73.5 79.5H20.7426L10.2426 90H87C88.6569 90 90 88.6569 90 87V21H96V87C96 91.9706 91.9706 96 87 96H6.62132C2.61224 96 0.604504 91.1529 3.43934 88.318L52.7574 39H39V45H33V37.5C33 35.0147 35.0147 33 37.5 33H58.7574L69.2574 22.5H22.5V60H16.5V22.5C16.5 19.1863 19.1863 16.5 22.5 16.5H75.2574L85.7574 6H9Z\" fill=\"#fff\"/></svg>",
|
||||
"logoWidth": 16,
|
||||
"labelColor": "black",
|
||||
"color": "white"
|
||||
}
|
||||
|
Before Width: | Height: | Size: 6.3 KiB |
@@ -1,9 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path opacity="0.6" d="M3.5 11V5.5L8.5 8L3.5 11Z" fill="black"/>
|
||||
<path opacity="0.4" d="M8.5 14L3.5 11L8.5 8V14Z" fill="black"/>
|
||||
<path opacity="0.6" d="M8.5 5.5H3.5L8.5 2.5L8.5 5.5Z" fill="black"/>
|
||||
<path opacity="0.8" d="M8.5 5.5V2.5L13.5 5.5H8.5Z" fill="black"/>
|
||||
<path opacity="0.2" d="M13.5 11L8.5 14L11 9.5L13.5 11Z" fill="black"/>
|
||||
<path opacity="0.5" d="M13.5 11L11 9.5L13.5 5V11Z" fill="black"/>
|
||||
<path d="M3.5 11V5L8.5 2.11325L13.5 5V11L8.5 13.8868L3.5 11Z" stroke="black"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 583 B |
@@ -1,10 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_2716_663)">
|
||||
<path d="M8.47552 2.45453C11.5167 2.45457 13.9814 4.94501 13.9814 8.01623C13.9814 11.0875 11.5167 13.578 8.47552 13.5781C5.43427 13.5781 2.96948 11.0875 2.96948 8.01623C2.9695 4.94498 5.43429 2.45453 8.47552 2.45453ZM10.8795 4.70348C10.7605 4.16887 10.1328 3.85468 9.53627 3.96342C8.97622 4.06552 7.62871 4.45681 7.62057 4.45916C9.29414 4.44469 9.57429 4.4726 9.69939 4.64751C9.77324 4.7508 9.66576 4.89248 9.21944 4.96538C8.73515 5.04447 7.73014 5.13958 7.72343 5.14022C6.75441 5.19776 6.07177 5.20168 5.86705 5.63512C5.73334 5.91827 6.00968 6.16857 6.13082 6.32527C6.64271 6.89455 7.38215 7.20158 7.85809 7.42767C8.03716 7.51274 8.56257 7.67345 8.56257 7.67345C7.01855 7.58853 5.90474 8.06267 5.2514 8.60855C4.51246 9.29204 4.83937 10.1067 6.35327 10.6084C7.24742 10.9047 7.69094 11.0439 9.02473 10.9238C9.81031 10.8815 9.9342 10.9068 9.94203 10.9712C9.95275 11.062 9.06932 11.2874 8.82812 11.357C8.21455 11.534 6.60645 11.8913 6.59758 11.8932C6.60115 11.8935 7.06249 11.9257 7.65531 11.8735C7.89632 11.8522 8.81142 11.7624 9.49557 11.6123C9.49557 11.6123 10.3297 11.4338 10.7759 11.2693C11.2429 11.0973 11.497 10.9512 11.6113 10.7443C11.6063 10.7019 11.6465 10.5516 11.4313 10.4613C10.8807 10.2304 10.2423 10.2721 8.9789 10.2453C7.57789 10.1972 7.11184 9.9626 6.86356 9.77373C6.62548 9.58212 6.74518 9.05204 7.76528 8.5851C8.27917 8.33646 10.2935 7.87759 10.2935 7.87759C9.61511 7.54227 8.35014 6.95284 8.09005 6.82552C7.86199 6.71388 7.49701 6.54572 7.4179 6.34233C7.32824 6.14709 7.6297 5.97888 7.79813 5.9307C8.34057 5.77424 9.10635 5.67701 9.8033 5.66609C10.1536 5.66061 10.2105 5.63806 10.2105 5.63806C10.6939 5.55787 11.0121 5.22722 10.8795 4.70348Z" fill="black"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_2716_663">
|
||||
<rect width="12" height="12" fill="white" transform="translate(2.5 2)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.9 KiB |
@@ -1,3 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M3.6725 13.9985C3.36161 13.9982 3.06354 13.8746 2.84371 13.6548C2.62388 13.435 2.50026 13.1369 2.5 12.826V7.494C2.5 6.8325 2.7675 6.185 3.2365 5.7165L6.219 2.736C6.45192 2.50247 6.72867 2.31724 7.03335 2.19094C7.33804 2.06464 7.66467 1.99975 7.9945 2H13.3275C13.6384 2.00027 13.9365 2.12388 14.1563 2.34371C14.3761 2.56354 14.4997 2.86162 14.5 3.1725V8.5045C14.4983 9.17074 14.2336 9.80936 13.7635 10.2815L10.781 13.264C10.5477 13.4976 10.2706 13.6829 9.96561 13.8092C9.66059 13.9355 9.33364 14.0003 9.0035 14V13.9985H3.6725ZM8.157 10.5715H5.243V11.257H8.157V10.5715ZM4.4815 5.257H11.243V12.0165L13.3715 9.888C13.7373 9.52036 13.9433 9.02316 13.9445 8.5045V3.1725C13.9445 2.8335 13.6685 2.5555 13.3275 2.5555H7.9945C7.73753 2.55499 7.483 2.6053 7.24556 2.70356C7.00813 2.80181 6.79246 2.94606 6.611 3.128L4.4815 5.257ZM4.3855 5.353L3.628 6.11C3.26258 6.47809 3.0569 6.97533 3.0555 7.494V12.826C3.0555 13.165 3.3315 13.443 3.6725 13.443H9.0055C9.26249 13.4434 9.51701 13.3929 9.75445 13.2946C9.99188 13.1963 10.2075 13.052 10.389 12.87L11.145 12.1145H4.3855V5.353Z" fill="black"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.2 KiB |
@@ -1,5 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M13.0945 8.01611C13.0945 7.87619 12.9911 7.79551 12.8642 7.8356L4.13456 10.6038C4.00742 10.6441 3.90427 10.7904 3.90427 10.9301V13.7593C3.90427 13.8992 4.00742 13.9801 4.13456 13.9398L12.8642 11.1719C12.9911 11.1315 13.0945 10.9852 13.0945 10.8453V8.01611Z" fill="black"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M3.90427 7.92597C3.90427 8.06588 4.00742 8.21218 4.13456 8.25252L12.8655 11.0209C12.9926 11.0613 13.0958 10.9803 13.0958 10.8407V8.01124C13.0958 7.87158 12.9926 7.72529 12.8655 7.68494L4.13456 4.91652C4.00742 4.87618 3.90427 4.95686 3.90427 5.09677V7.92597Z" fill="black"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M13.0945 2.20248C13.0945 2.06256 12.9911 1.98163 12.8642 2.02197L4.13456 4.78988C4.00742 4.83022 3.90427 4.97652 3.90427 5.11644V7.94563C3.90427 8.08554 4.00742 8.16622 4.13456 8.12614L12.8642 5.35797C12.9911 5.31763 13.0945 5.17133 13.0945 5.03167V2.20248Z" fill="black"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.0 KiB |
@@ -1,3 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M11.0094 13.9181C11.1984 13.9917 11.4139 13.987 11.6047 13.8952L14.0753 12.7064C14.3349 12.5814 14.5 12.3187 14.5 12.0305V3.9696C14.5 3.68136 14.3349 3.41862 14.0753 3.2937L11.6047 2.10485C11.3543 1.98438 11.0614 2.01389 10.8416 2.17363C10.8102 2.19645 10.7803 2.22193 10.7523 2.25001L6.02261 6.56498L3.96246 5.00115C3.77068 4.85558 3.50244 4.86751 3.32432 5.02953L2.66356 5.63059C2.44569 5.82877 2.44544 6.17152 2.66302 6.37004L4.44965 8.00001L2.66302 9.62998C2.44544 9.82849 2.44569 10.1713 2.66356 10.3694L3.32432 10.9705C3.50244 11.1325 3.77068 11.1444 3.96246 10.9989L6.02261 9.43504L10.7523 13.75C10.8271 13.8249 10.915 13.8812 11.0094 13.9181ZM11.5018 5.27587L7.91309 8.00001L11.5018 10.7241V5.27587Z" fill="black"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 876 B |
@@ -1,4 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M13.0001 8.62505C13.0001 11.75 10.8126 13.3125 8.21266 14.2187C8.07651 14.2648 7.92862 14.2626 7.79392 14.2125C5.18771 13.3125 3.00024 11.75 3.00024 8.62505V4.25012C3.00024 4.08436 3.06609 3.92539 3.1833 3.80818C3.30051 3.69098 3.45948 3.62513 3.62523 3.62513C4.87521 3.62513 6.43769 2.87514 7.52517 1.92516C7.65758 1.81203 7.82601 1.74988 8.00016 1.74988C8.17431 1.74988 8.34275 1.81203 8.47515 1.92516C9.56889 2.88139 11.1251 3.62513 12.3751 3.62513C12.5408 3.62513 12.6998 3.69098 12.817 3.80818C12.9342 3.92539 13.0001 4.08436 13.0001 4.25012V8.62505Z" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M6 8.00002L7.33333 9.33335L10 6.66669" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 883 B |
@@ -232,7 +232,7 @@
|
||||
"ctrl-n": "agent::NewThread",
|
||||
"ctrl-alt-n": "agent::NewTextThread",
|
||||
"ctrl-shift-h": "agent::OpenHistory",
|
||||
"ctrl-alt-c": "agent::OpenSettings",
|
||||
"ctrl-alt-c": "agent::OpenConfiguration",
|
||||
"ctrl-alt-p": "agent::OpenRulesLibrary",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "agent::ToggleModelSelector",
|
||||
@@ -495,7 +495,7 @@
|
||||
"shift-f12": "editor::GoToImplementation",
|
||||
"alt-ctrl-f12": "editor::GoToTypeDefinitionSplit",
|
||||
"alt-shift-f12": "editor::FindAllReferences",
|
||||
"ctrl-m": "editor::MoveToEnclosingBracket", // from jetbrains
|
||||
"ctrl-m": "editor::MoveToEnclosingBracket",
|
||||
"ctrl-|": "editor::MoveToEnclosingBracket",
|
||||
"ctrl-{": "editor::Fold",
|
||||
"ctrl-}": "editor::UnfoldLines",
|
||||
@@ -598,7 +598,6 @@
|
||||
"ctrl-shift-t": "pane::ReopenClosedItem",
|
||||
"ctrl-k ctrl-s": "zed::OpenKeymapEditor",
|
||||
"ctrl-k ctrl-t": "theme_selector::Toggle",
|
||||
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
|
||||
"ctrl-t": "project_symbols::Toggle",
|
||||
"ctrl-p": "file_finder::Toggle",
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
@@ -1168,14 +1167,5 @@
|
||||
"up": "menu::SelectPrevious",
|
||||
"down": "menu::SelectNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Onboarding",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"ctrl-1": "onboarding::ActivateBasicsPage",
|
||||
"ctrl-2": "onboarding::ActivateEditingPage",
|
||||
"ctrl-3": "onboarding::ActivateAISetupPage"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -272,7 +272,7 @@
|
||||
"cmd-n": "agent::NewThread",
|
||||
"cmd-alt-n": "agent::NewTextThread",
|
||||
"cmd-shift-h": "agent::OpenHistory",
|
||||
"cmd-alt-c": "agent::OpenSettings",
|
||||
"cmd-alt-c": "agent::OpenConfiguration",
|
||||
"cmd-alt-p": "agent::OpenRulesLibrary",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||
@@ -549,7 +549,7 @@
|
||||
"alt-cmd-f12": "editor::GoToTypeDefinitionSplit",
|
||||
"alt-shift-f12": "editor::FindAllReferences",
|
||||
"cmd-|": "editor::MoveToEnclosingBracket",
|
||||
"ctrl-m": "editor::MoveToEnclosingBracket", // From Jetbrains
|
||||
"ctrl-m": "editor::MoveToEnclosingBracket",
|
||||
"alt-cmd-[": "editor::Fold",
|
||||
"alt-cmd-]": "editor::UnfoldLines",
|
||||
"cmd-k cmd-l": "editor::ToggleFold",
|
||||
@@ -665,7 +665,6 @@
|
||||
"cmd-shift-t": "pane::ReopenClosedItem",
|
||||
"cmd-k cmd-s": "zed::OpenKeymapEditor",
|
||||
"cmd-k cmd-t": "theme_selector::Toggle",
|
||||
"ctrl-alt-cmd-p": "settings_profile_selector::Toggle",
|
||||
"cmd-t": "project_symbols::Toggle",
|
||||
"cmd-p": "file_finder::Toggle",
|
||||
"ctrl-tab": "tab_switcher::Toggle",
|
||||
@@ -1270,14 +1269,5 @@
|
||||
"up": "menu::SelectPrevious",
|
||||
"down": "menu::SelectNext"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Onboarding",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-1": "onboarding::ActivateBasicsPage",
|
||||
"cmd-2": "onboarding::ActivateEditingPage",
|
||||
"cmd-3": "onboarding::ActivateAISetupPage"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"ctrl-shift-i": "agent::ToggleFocus",
|
||||
"ctrl-l": "agent::ToggleFocus",
|
||||
"ctrl-shift-l": "agent::ToggleFocus",
|
||||
"ctrl-shift-j": "agent::OpenSettings"
|
||||
"ctrl-shift-j": "agent::OpenConfiguration"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -95,7 +95,7 @@
|
||||
"ctrl-shift-r": ["pane::DeploySearch", { "replace_enabled": true }],
|
||||
"alt-shift-f10": "task::Spawn",
|
||||
"ctrl-e": "file_finder::Toggle",
|
||||
// "ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
"ctrl-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
"ctrl-shift-n": "file_finder::Toggle",
|
||||
"ctrl-shift-a": "command_palette::Toggle",
|
||||
"shift shift": "command_palette::Toggle",
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"cmd-shift-i": "agent::ToggleFocus",
|
||||
"cmd-l": "agent::ToggleFocus",
|
||||
"cmd-shift-l": "agent::ToggleFocus",
|
||||
"cmd-shift-j": "agent::OpenSettings"
|
||||
"cmd-shift-j": "agent::OpenConfiguration"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -97,7 +97,7 @@
|
||||
"cmd-shift-r": ["pane::DeploySearch", { "replace_enabled": true }],
|
||||
"ctrl-alt-r": "task::Spawn",
|
||||
"cmd-e": "file_finder::Toggle",
|
||||
// "cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
"cmd-k": "git_panel::ToggleFocus", // bug: This should also focus commit editor
|
||||
"cmd-shift-o": "file_finder::Toggle",
|
||||
"cmd-shift-a": "command_palette::Toggle",
|
||||
"shift shift": "command_palette::Toggle",
|
||||
|
||||
@@ -1877,25 +1877,5 @@
|
||||
"save_breakpoints": true,
|
||||
"dock": "bottom",
|
||||
"button": true
|
||||
},
|
||||
// Configures any number of settings profiles that are temporarily applied on
|
||||
// top of your existing user settings when selected from
|
||||
// `settings profile selector: toggle`.
|
||||
// Examples:
|
||||
// "profiles": {
|
||||
// "Presenting": {
|
||||
// "agent_font_size": 20.0,
|
||||
// "buffer_font_size": 20.0,
|
||||
// "theme": "One Light",
|
||||
// "ui_font_size": 20.0
|
||||
// },
|
||||
// "Python (ty)": {
|
||||
// "languages": {
|
||||
// "Python": {
|
||||
// "language_servers": ["ty"]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
"profiles": []
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,7 +391,7 @@ impl ToolCallContent {
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
match content {
|
||||
acp::ToolCallContent::Content { content } => Self::ContentBlock {
|
||||
acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
|
||||
content: ContentBlock::new(content, &language_registry, cx),
|
||||
},
|
||||
acp::ToolCallContent::Diff { diff } => Self::Diff {
|
||||
@@ -580,9 +580,6 @@ pub struct AcpThread {
|
||||
pub enum AcpThreadEvent {
|
||||
NewEntry,
|
||||
EntryUpdated(usize),
|
||||
ToolAuthorizationRequired,
|
||||
Stopped,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl EventEmitter<AcpThreadEvent> for AcpThread {}
|
||||
@@ -619,7 +616,6 @@ impl Error for LoadError {}
|
||||
|
||||
impl AcpThread {
|
||||
pub fn new(
|
||||
title: impl Into<SharedString>,
|
||||
connection: Rc<dyn AgentConnection>,
|
||||
project: Entity<Project>,
|
||||
session_id: acp::SessionId,
|
||||
@@ -632,7 +628,7 @@ impl AcpThread {
|
||||
shared_buffers: Default::default(),
|
||||
entries: Default::default(),
|
||||
plan: Default::default(),
|
||||
title: title.into(),
|
||||
title: connection.name().into(),
|
||||
project,
|
||||
send_task: None,
|
||||
connection,
|
||||
@@ -680,32 +676,20 @@ impl AcpThread {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn used_tools_since_last_user_message(&self) -> bool {
|
||||
for entry in self.entries.iter().rev() {
|
||||
match entry {
|
||||
AgentThreadEntry::UserMessage(..) => return false,
|
||||
AgentThreadEntry::AssistantMessage(..) => continue,
|
||||
AgentThreadEntry::ToolCall(..) => return true,
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn handle_session_update(
|
||||
&mut self,
|
||||
update: acp::SessionUpdate,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Result<()> {
|
||||
match update {
|
||||
acp::SessionUpdate::UserMessageChunk { content } => {
|
||||
self.push_user_content_block(content, cx);
|
||||
acp::SessionUpdate::UserMessage(content_block) => {
|
||||
self.push_user_content_block(content_block, cx);
|
||||
}
|
||||
acp::SessionUpdate::AgentMessageChunk { content } => {
|
||||
self.push_assistant_content_block(content, false, cx);
|
||||
acp::SessionUpdate::AgentMessageChunk(content_block) => {
|
||||
self.push_assistant_content_block(content_block, false, cx);
|
||||
}
|
||||
acp::SessionUpdate::AgentThoughtChunk { content } => {
|
||||
self.push_assistant_content_block(content, true, cx);
|
||||
acp::SessionUpdate::AgentThoughtChunk(content_block) => {
|
||||
self.push_assistant_content_block(content_block, true, cx);
|
||||
}
|
||||
acp::SessionUpdate::ToolCall(tool_call) => {
|
||||
self.upsert_tool_call(tool_call, cx);
|
||||
@@ -895,7 +879,6 @@ impl AcpThread {
|
||||
};
|
||||
|
||||
self.upsert_tool_call_inner(tool_call, status, cx);
|
||||
cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
|
||||
rx
|
||||
}
|
||||
|
||||
@@ -974,6 +957,10 @@ impl AcpThread {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
|
||||
self.connection.authenticate(cx)
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn send_raw(
|
||||
&mut self,
|
||||
@@ -1015,7 +1002,7 @@ impl AcpThread {
|
||||
let result = this
|
||||
.update(cx, |this, cx| {
|
||||
this.connection.prompt(
|
||||
acp::PromptRequest {
|
||||
acp::PromptArguments {
|
||||
prompt: message,
|
||||
session_id: this.session_id.clone(),
|
||||
},
|
||||
@@ -1031,18 +1018,12 @@ impl AcpThread {
|
||||
.log_err();
|
||||
}));
|
||||
|
||||
cx.spawn(async move |this, cx| match rx.await {
|
||||
Ok(Err(e)) => {
|
||||
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
|
||||
.log_err();
|
||||
Err(e)?
|
||||
async move {
|
||||
match rx.await {
|
||||
Ok(Err(e)) => Err(e)?,
|
||||
_ => Ok(()),
|
||||
}
|
||||
_ => {
|
||||
this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
|
||||
.log_err();
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
@@ -1616,16 +1597,9 @@ mod tests {
|
||||
name: "test",
|
||||
connection,
|
||||
child_status: io_task,
|
||||
current_thread: thread_rc,
|
||||
auth_methods: [acp::AuthMethod {
|
||||
id: acp::AuthMethodId("acp-old-no-id".into()),
|
||||
label: "Log in".into(),
|
||||
description: None,
|
||||
}],
|
||||
};
|
||||
|
||||
AcpThread::new(
|
||||
"test",
|
||||
Rc::new(connection),
|
||||
project,
|
||||
acp::SessionId("test".into()),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc};
|
||||
use std::{path::Path, rc::Rc};
|
||||
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use project::Project;
|
||||
@@ -9,6 +9,8 @@ use ui::App;
|
||||
use crate::AcpThread;
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
@@ -16,21 +18,9 @@ pub trait AgentConnection {
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,10 @@ use anyhow::{Context as _, Result};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use std::{cell::RefCell, path::Path, rc::Rc};
|
||||
use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
|
||||
use ui::App;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AcpThread, AgentConnection, AuthRequired};
|
||||
use crate::{AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OldAcpClientDelegate {
|
||||
@@ -47,7 +46,7 @@ impl acp_old::Client for OldAcpClientDelegate {
|
||||
thread.push_assistant_content_block(thought.into(), true, cx)
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
.ok();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
@@ -351,15 +350,27 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Unauthenticated;
|
||||
|
||||
impl Error for Unauthenticated {}
|
||||
impl fmt::Display for Unauthenticated {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Unauthenticated")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OldAcpAgentConnection {
|
||||
pub name: &'static str,
|
||||
pub connection: acp_old::AgentConnection,
|
||||
pub child_status: Task<Result<()>>,
|
||||
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
|
||||
pub auth_methods: [acp::AuthMethod; 1],
|
||||
}
|
||||
|
||||
impl AgentConnection for OldAcpAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
@@ -372,31 +383,25 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||
}
|
||||
.into_any(),
|
||||
);
|
||||
let current_thread = self.current_thread.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let result = task.await?;
|
||||
let result = acp_old::InitializeParams::response_from_any(result)?;
|
||||
|
||||
if !result.is_authenticated {
|
||||
anyhow::bail!(AuthRequired)
|
||||
anyhow::bail!(Unauthenticated)
|
||||
}
|
||||
|
||||
cx.update(|cx| {
|
||||
let thread = cx.new(|cx| {
|
||||
let session_id = acp::SessionId("acp-old-no-id".into());
|
||||
AcpThread::new("Gemini", self.clone(), project, session_id, cx)
|
||||
AcpThread::new(self.clone(), project, session_id, cx)
|
||||
});
|
||||
current_thread.replace(thread.downgrade());
|
||||
thread
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&self.auth_methods
|
||||
}
|
||||
|
||||
fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
let task = self
|
||||
.connection
|
||||
.request_any(acp_old::AuthenticateParams.into_any());
|
||||
@@ -406,7 +411,7 @@ impl AgentConnection for OldAcpAgentConnection {
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
let chunks = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
|
||||
@@ -25,7 +25,6 @@ assistant_context.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
context_server.workspace = true
|
||||
@@ -36,9 +35,9 @@ futures.workspace = true
|
||||
git.workspace = true
|
||||
gpui.workspace = true
|
||||
heed.workspace = true
|
||||
http_client.workspace = true
|
||||
icons.workspace = true
|
||||
indoc.workspace = true
|
||||
http_client.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
@@ -47,6 +46,7 @@ paths.workspace = true
|
||||
postage.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
ref-cast.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
@@ -63,6 +63,7 @@ time.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
zstd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -13,7 +13,6 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
|
||||
use collections::HashMap;
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::{FutureExt, StreamExt as _, future::Shared};
|
||||
@@ -37,6 +36,7 @@ use project::{
|
||||
git_store::{GitStore, GitStoreCheckpoint, RepositoryState},
|
||||
};
|
||||
use prompt_store::{ModelContext, PromptBuilder};
|
||||
use proto::Plan;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
@@ -49,6 +49,7 @@ use std::{
|
||||
use thiserror::Error;
|
||||
use util::{ResultExt as _, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
|
||||
const MAX_RETRY_ATTEMPTS: u8 = 4;
|
||||
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
|
||||
@@ -1680,7 +1681,7 @@ impl Thread {
|
||||
|
||||
let completion_mode = request
|
||||
.mode
|
||||
.unwrap_or(cloud_llm_client::CompletionMode::Normal);
|
||||
.unwrap_or(zed_llm_client::CompletionMode::Normal);
|
||||
|
||||
self.last_received_chunk_at = Some(Instant::now());
|
||||
|
||||
@@ -3254,10 +3255,8 @@ impl Thread {
|
||||
}
|
||||
|
||||
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context<Self>) {
|
||||
self.project
|
||||
.read(cx)
|
||||
.user_store()
|
||||
.update(cx, |user_store, cx| {
|
||||
self.project.update(cx, |project, cx| {
|
||||
project.user_store().update(cx, |user_store, cx| {
|
||||
user_store.update_model_request_usage(
|
||||
ModelRequestUsage(RequestUsage {
|
||||
amount: amount as i32,
|
||||
@@ -3265,7 +3264,8 @@ impl Thread {
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn deny_tool_use(
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
use agent_client_protocol::{self as acp, Agent as _};
|
||||
use collections::HashMap;
|
||||
use futures::channel::oneshot;
|
||||
use project::Project;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::AgentServerCommand;
|
||||
use acp_thread::{AcpThread, AgentConnection, AuthRequired};
|
||||
|
||||
pub struct AcpConnection {
|
||||
server_name: &'static str,
|
||||
connection: Rc<acp::ClientSideConnection>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
auth_methods: Vec<acp::AuthMethod>,
|
||||
_io_task: Task<Result<()>>,
|
||||
}
|
||||
|
||||
pub struct AcpSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
}
|
||||
|
||||
impl AcpConnection {
|
||||
pub async fn stdio(
|
||||
server_name: &'static str,
|
||||
command: AgentServerCommand,
|
||||
root_dir: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter().map(|arg| arg.as_str()))
|
||||
.envs(command.env.iter().flatten())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdout = child.stdout.take().expect("Failed to take stdout");
|
||||
let stdin = child.stdin.take().expect("Failed to take stdin");
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let client = ClientDelegate {
|
||||
sessions: sessions.clone(),
|
||||
cx: cx.clone(),
|
||||
};
|
||||
let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
move |fut| {
|
||||
foreground_executor.spawn(fut).detach();
|
||||
}
|
||||
});
|
||||
|
||||
let io_task = cx.background_spawn(io_task);
|
||||
|
||||
let response = connection
|
||||
.initialize(acp::InitializeRequest {
|
||||
protocol_version: acp::VERSION,
|
||||
client_capabilities: acp::ClientCapabilities {
|
||||
fs: acp::FileSystemCapability {
|
||||
read_text_file: true,
|
||||
write_text_file: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
|
||||
// todo! check version
|
||||
|
||||
Ok(Self {
|
||||
auth_methods: response.auth_methods,
|
||||
connection: connection.into(),
|
||||
server_name,
|
||||
sessions,
|
||||
_io_task: io_task,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for AcpConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let conn = self.connection.clone();
|
||||
let sessions = self.sessions.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let response = conn
|
||||
.new_session(acp::NewSessionRequest {
|
||||
// todo! Zed MCP server?
|
||||
mcp_servers: vec![],
|
||||
cwd,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let Some(session_id) = response.session_id else {
|
||||
anyhow::bail!(AuthRequired);
|
||||
};
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new(
|
||||
self.server_name,
|
||||
self.clone(),
|
||||
project,
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let session = AcpSession {
|
||||
thread: thread.downgrade(),
|
||||
};
|
||||
sessions.borrow_mut().insert(session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&self.auth_methods
|
||||
}
|
||||
|
||||
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
|
||||
let conn = self.connection.clone();
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let result = conn
|
||||
.authenticate(acp::AuthenticateRequest {
|
||||
method_id: method_id.clone(),
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
let conn = self.connection.clone();
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { Ok(conn.prompt(params).await?) })
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
|
||||
let conn = self.connection.clone();
|
||||
let params = acp::CancelledNotification {
|
||||
session_id: session_id.clone(),
|
||||
};
|
||||
cx.foreground_executor()
|
||||
.spawn(async move { conn.cancelled(params).await })
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientDelegate {
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
cx: AsyncApp,
|
||||
}
|
||||
|
||||
impl acp::Client for ClientDelegate {
|
||||
async fn request_permission(
|
||||
&self,
|
||||
arguments: acp::RequestPermissionRequest,
|
||||
) -> Result<acp::RequestPermissionResponse, acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let result = self
|
||||
.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
|
||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
|
||||
};
|
||||
|
||||
Ok(acp::RequestPermissionResponse { outcome })
|
||||
}
|
||||
|
||||
async fn write_text_file(
|
||||
&self,
|
||||
arguments: acp::WriteTextFileRequest,
|
||||
) -> Result<(), acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
self.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(arguments.path, arguments.content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_text_file(
|
||||
&self,
|
||||
arguments: acp::ReadTextFileRequest,
|
||||
) -> Result<acp::ReadTextFileResponse, acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let content = self
|
||||
.sessions
|
||||
.borrow()
|
||||
.get(&arguments.session_id)
|
||||
.context("Failed to get session")?
|
||||
.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(acp::ReadTextFileResponse { content })
|
||||
}
|
||||
|
||||
async fn session_notification(
|
||||
&self,
|
||||
notification: acp::SessionNotification,
|
||||
) -> Result<(), acp::Error> {
|
||||
let cx = &mut self.cx.clone();
|
||||
let sessions = self.sessions.borrow();
|
||||
let session = sessions
|
||||
.get(¬ification.session_id)
|
||||
.context("Failed to get session")?;
|
||||
|
||||
session.thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(notification.update, cx)
|
||||
})??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,14 @@
|
||||
mod acp_connection;
|
||||
mod claude;
|
||||
mod codex;
|
||||
mod gemini;
|
||||
mod mcp_server;
|
||||
mod settings;
|
||||
|
||||
#[cfg(test)]
|
||||
mod e2e_tests;
|
||||
|
||||
pub use claude::*;
|
||||
pub use codex::*;
|
||||
pub use gemini::*;
|
||||
pub use settings::*;
|
||||
|
||||
@@ -36,6 +38,7 @@ pub trait AgentServer: Send {
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
// these will go away when old_acp is fully removed
|
||||
root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
|
||||
@@ -70,6 +70,10 @@ struct ClaudeAgentConnection {
|
||||
}
|
||||
|
||||
impl AgentConnection for ClaudeAgentConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
ClaudeCode.name()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
@@ -164,9 +168,8 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
}
|
||||
});
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
|
||||
})?;
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
@@ -183,15 +186,11 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
})
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
&[]
|
||||
}
|
||||
|
||||
fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
|
||||
fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
|
||||
let sessions = self.sessions.borrow();
|
||||
let Some(session) = sessions.get(¶ms.session_id) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
|
||||
317
crates/agent_servers/src/codex.rs
Normal file
@@ -0,0 +1,317 @@
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use context_server::listener::McpServerTool;
|
||||
use context_server::types::requests;
|
||||
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use smol::stream::StreamExt as _;
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, sync::Arc};
|
||||
use util::ResultExt;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
|
||||
use crate::mcp_server::ZedMcpServer;
|
||||
use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
|
||||
use acp_thread::{AcpThread, AgentConnection};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Codex;
|
||||
|
||||
impl AgentServer for Codex {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
"Welcome to Codex"
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
"What can I help with?"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::AiOpenAi
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let project = project.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).codex.clone()
|
||||
})?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find codex binary");
|
||||
};
|
||||
|
||||
let client: Arc<ContextServer> = ContextServer::stdio(
|
||||
ContextServerId("codex-mcp-server".into()),
|
||||
ContextServerCommand {
|
||||
path: command.path,
|
||||
args: command.args,
|
||||
env: command.env,
|
||||
},
|
||||
)
|
||||
.into();
|
||||
ContextServer::start(client.clone(), cx).await?;
|
||||
|
||||
let (notification_tx, mut notification_rx) = mpsc::unbounded();
|
||||
client
|
||||
.client()
|
||||
.context("Failed to subscribe")?
|
||||
.on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
|
||||
move |notification, _cx| {
|
||||
let notification_tx = notification_tx.clone();
|
||||
log::trace!(
|
||||
"ACP Notification: {}",
|
||||
serde_json::to_string_pretty(¬ification).unwrap()
|
||||
);
|
||||
|
||||
if let Some(notification) =
|
||||
serde_json::from_value::<acp::SessionNotification>(notification)
|
||||
.log_err()
|
||||
{
|
||||
notification_tx.unbounded_send(notification).ok();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sessions = Rc::new(RefCell::new(HashMap::default()));
|
||||
|
||||
let notification_handler_task = cx.spawn({
|
||||
let sessions = sessions.clone();
|
||||
async move |cx| {
|
||||
while let Some(notification) = notification_rx.next().await {
|
||||
CodexConnection::handle_session_notification(
|
||||
notification,
|
||||
sessions.clone(),
|
||||
cx,
|
||||
)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let connection = CodexConnection {
|
||||
client,
|
||||
sessions,
|
||||
_notification_handler_task: notification_handler_task,
|
||||
};
|
||||
Ok(Rc::new(connection) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct CodexConnection {
|
||||
client: Arc<context_server::ContextServer>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
_notification_handler_task: Task<()>,
|
||||
}
|
||||
|
||||
struct CodexSession {
|
||||
thread: WeakEntity<AcpThread>,
|
||||
cancel_tx: Option<oneshot::Sender<()>>,
|
||||
_mcp_server: ZedMcpServer,
|
||||
}
|
||||
|
||||
impl AgentConnection for CodexConnection {
|
||||
fn name(&self) -> &'static str {
|
||||
"Codex"
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
let cwd = cwd.to_path_buf();
|
||||
cx.spawn(async move |cx| {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
|
||||
|
||||
let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
|
||||
|
||||
let response = client
|
||||
.request::<requests::CallTool>(context_server::types::CallToolParams {
|
||||
name: acp::NEW_SESSION_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(acp::NewSessionArguments {
|
||||
mcp_servers: [(
|
||||
mcp_server::SERVER_NAME.to_string(),
|
||||
mcp_server.server_config()?,
|
||||
)]
|
||||
.into(),
|
||||
client_tools: acp::ClientTools {
|
||||
request_permission: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::RequestPermissionTool::NAME.into(),
|
||||
}),
|
||||
read_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::ReadTextFileTool::NAME.into(),
|
||||
}),
|
||||
write_text_file: Some(acp::McpToolId {
|
||||
mcp_server: mcp_server::SERVER_NAME.into(),
|
||||
tool_name: mcp_server::WriteTextFileTool::NAME.into(),
|
||||
}),
|
||||
},
|
||||
cwd,
|
||||
})?),
|
||||
meta: None,
|
||||
})
|
||||
.await?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
let result = serde_json::from_value::<acp::NewSessionOutput>(
|
||||
response.structured_content.context("Empty response")?,
|
||||
)?;
|
||||
|
||||
let thread =
|
||||
cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
|
||||
|
||||
thread_tx.send(thread.downgrade())?;
|
||||
|
||||
let session = CodexSession {
|
||||
thread: thread.downgrade(),
|
||||
cancel_tx: None,
|
||||
_mcp_server: mcp_server,
|
||||
};
|
||||
sessions.borrow_mut().insert(result.session_id, session);
|
||||
|
||||
Ok(thread)
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(Err(anyhow!("Authentication not supported")))
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
params: agent_client_protocol::PromptArguments,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
let client = self.client.client();
|
||||
let sessions = self.sessions.clone();
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let client = client.context("MCP server is not initialized yet")?;
|
||||
|
||||
let (new_cancel_tx, cancel_rx) = oneshot::channel();
|
||||
{
|
||||
let mut sessions = sessions.borrow_mut();
|
||||
let session = sessions
|
||||
.get_mut(¶ms.session_id)
|
||||
.context("Session not found")?;
|
||||
session.cancel_tx.replace(new_cancel_tx);
|
||||
}
|
||||
|
||||
let result = client
|
||||
.request_with::<requests::CallTool>(
|
||||
context_server::types::CallToolParams {
|
||||
name: acp::PROMPT_TOOL_NAME.into(),
|
||||
arguments: Some(serde_json::to_value(params)?),
|
||||
meta: None,
|
||||
},
|
||||
Some(cancel_rx),
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(err) = &result
|
||||
&& err.is::<context_server::client::RequestCanceled>()
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let response = result?;
|
||||
|
||||
if response.is_error.unwrap_or_default() {
|
||||
return Err(anyhow!(response.text_contents()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
|
||||
let mut sessions = self.sessions.borrow_mut();
|
||||
|
||||
if let Some(cancel_tx) = sessions
|
||||
.get_mut(session_id)
|
||||
.and_then(|session| session.cancel_tx.take())
|
||||
{
|
||||
cancel_tx.send(()).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CodexConnection {
|
||||
pub fn handle_session_notification(
|
||||
notification: acp::SessionNotification,
|
||||
threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let threads = threads.borrow();
|
||||
let Some(thread) = threads
|
||||
.get(¬ification.session_id)
|
||||
.and_then(|session| session.thread.upgrade())
|
||||
else {
|
||||
log::error!(
|
||||
"Thread not found for session ID: {}",
|
||||
notification.session_id
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(notification.update, cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CodexConnection {
|
||||
fn drop(&mut self) {
|
||||
self.client.stop().log_err();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
use crate::AgentServerCommand;
|
||||
use std::path::Path;
|
||||
|
||||
crate::common_e2e_tests!(Codex, allow_option_id = "approve");
|
||||
|
||||
pub fn local_command() -> AgentServerCommand {
|
||||
let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../../../codex/codex-rs/target/debug/codex");
|
||||
|
||||
AgentServerCommand {
|
||||
path: cli_path,
|
||||
args: vec!["mcp".into()],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::{Entity, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use util::path;
|
||||
|
||||
@@ -26,11 +27,7 @@ pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppCont
|
||||
.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(
|
||||
thread.entries().len() >= 2,
|
||||
"Expected at least 2 entries. Got: {:?}",
|
||||
thread.entries()
|
||||
);
|
||||
assert_eq!(thread.entries().len(), 2);
|
||||
assert!(matches!(
|
||||
thread.entries()[0],
|
||||
AgentThreadEntry::UserMessage(_)
|
||||
@@ -111,19 +108,19 @@ pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut Tes
|
||||
}
|
||||
|
||||
pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let _fs = init_test(cx).await;
|
||||
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
let foo_path = tempdir.path().join("foo");
|
||||
std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
|
||||
|
||||
let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
|
||||
let fs = init_test(cx).await;
|
||||
fs.insert_tree(
|
||||
path!("/private/tmp"),
|
||||
json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
&format!("Read {} and tell me what you see.", foo_path.display()),
|
||||
"Read the '/private/tmp/foo' file and tell me what you see.",
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -146,8 +143,6 @@ pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestApp
|
||||
.any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
|
||||
);
|
||||
});
|
||||
|
||||
drop(tempdir);
|
||||
}
|
||||
|
||||
pub async fn test_tool_call_with_confirmation(
|
||||
@@ -160,7 +155,7 @@ pub async fn test_tool_call_with_confirmation(
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
|
||||
r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
@@ -180,10 +175,10 @@ pub async fn test_tool_call_with_confirmation(
|
||||
)
|
||||
.await;
|
||||
|
||||
let tool_call_id = thread.read_with(cx, |thread, cx| {
|
||||
let tool_call_id = thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
label,
|
||||
content,
|
||||
status: ToolCallStatus::WaitingForConfirmation { .. },
|
||||
..
|
||||
}) = &thread
|
||||
@@ -195,8 +190,7 @@ pub async fn test_tool_call_with_confirmation(
|
||||
panic!();
|
||||
};
|
||||
|
||||
let label = label.read(cx).source();
|
||||
assert!(label.contains("touch"), "Got: {}", label);
|
||||
assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
|
||||
|
||||
id.clone()
|
||||
});
|
||||
@@ -248,7 +242,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
|
||||
r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
@@ -268,10 +262,10 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||
)
|
||||
.await;
|
||||
|
||||
thread.read_with(cx, |thread, cx| {
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let AgentThreadEntry::ToolCall(ToolCall {
|
||||
id,
|
||||
label,
|
||||
content,
|
||||
status: ToolCallStatus::WaitingForConfirmation { .. },
|
||||
..
|
||||
}) = &thread.entries()[first_tool_call_ix]
|
||||
@@ -279,8 +273,7 @@ pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppCon
|
||||
panic!("{:?}", thread.entries()[1]);
|
||||
};
|
||||
|
||||
let label = label.read(cx).source();
|
||||
assert!(label.contains("touch"), "Got: {}", label);
|
||||
assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
|
||||
|
||||
id.clone()
|
||||
});
|
||||
@@ -375,6 +368,9 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||
gemini: Some(AgentServerSettings {
|
||||
command: crate::gemini::tests::local_command(),
|
||||
}),
|
||||
codex: Some(AgentServerSettings {
|
||||
command: crate::codex::tests::local_command(),
|
||||
}),
|
||||
},
|
||||
cx,
|
||||
);
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
use anyhow::anyhow;
|
||||
use std::cell::RefCell;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{AgentServer, AgentServerCommand, acp_connection::AcpConnection};
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::Result;
|
||||
use gpui::{Entity, Task};
|
||||
use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
|
||||
use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
|
||||
use agentic_coding_protocol as acp_old;
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
use ui::App;
|
||||
@@ -39,27 +43,145 @@ impl AgentServer for Gemini {
|
||||
project: &Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Rc<dyn AgentConnection>>> {
|
||||
let project = project.clone();
|
||||
let server_name = self.name();
|
||||
let root_dir = root_dir.to_path_buf();
|
||||
let project = project.clone();
|
||||
let this = self.clone();
|
||||
let name = self.name();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||
})?;
|
||||
let command = this.command(&project, cx).await?;
|
||||
|
||||
let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
else {
|
||||
anyhow::bail!("Failed to find gemini binary");
|
||||
};
|
||||
// todo! check supported version
|
||||
let mut child = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.current_dir(root_dir)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let conn = AcpConnection::stdio(server_name, command, &root_dir, cx).await?;
|
||||
Ok(Rc::new(conn) as _)
|
||||
let stdin = child.stdin.take().unwrap();
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
|
||||
let foreground_executor = cx.foreground_executor().clone();
|
||||
|
||||
let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
|
||||
|
||||
let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
|
||||
OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
|
||||
stdin,
|
||||
stdout,
|
||||
move |fut| foreground_executor.spawn(fut).detach(),
|
||||
);
|
||||
|
||||
let io_task = cx.background_spawn(async move {
|
||||
io_fut.await.log_err();
|
||||
});
|
||||
|
||||
let child_status = cx.background_spawn(async move {
|
||||
let result = match child.status().await {
|
||||
Err(e) => Err(anyhow!(e)),
|
||||
Ok(result) if result.success() => Ok(()),
|
||||
Ok(result) => {
|
||||
if let Some(AgentServerVersion::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command,
|
||||
}) = this.version(&command).await.log_err()
|
||||
{
|
||||
Err(anyhow!(LoadError::Unsupported {
|
||||
error_message,
|
||||
upgrade_message,
|
||||
upgrade_command
|
||||
}))
|
||||
} else {
|
||||
Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
|
||||
}
|
||||
}
|
||||
};
|
||||
drop(io_task);
|
||||
result
|
||||
});
|
||||
|
||||
let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
|
||||
name,
|
||||
connection,
|
||||
child_status,
|
||||
});
|
||||
|
||||
Ok(connection)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Gemini {
|
||||
async fn command(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<AgentServerCommand> {
|
||||
let settings = cx.read_global(|settings: &SettingsStore, _| {
|
||||
settings.get::<AllAgentServersSettings>(None).gemini.clone()
|
||||
})?;
|
||||
|
||||
if let Some(command) =
|
||||
AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
|
||||
{
|
||||
return Ok(command);
|
||||
};
|
||||
|
||||
let (fs, node_runtime) = project.update(cx, |project, _| {
|
||||
(project.fs().clone(), project.node_runtime().cloned())
|
||||
})?;
|
||||
let node_runtime = node_runtime.context("gemini not found on path")?;
|
||||
|
||||
let directory = ::paths::agent_servers_dir().join("gemini");
|
||||
fs.create_dir(&directory).await?;
|
||||
node_runtime
|
||||
.npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
|
||||
.await?;
|
||||
let path = directory.join("node_modules/.bin/gemini");
|
||||
|
||||
Ok(AgentServerCommand {
|
||||
path,
|
||||
args: vec![ACP_ARG.into()],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
|
||||
let version_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--version")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let help_fut = util::command::new_smol_command(&command.path)
|
||||
.args(command.args.iter())
|
||||
.arg("--help")
|
||||
.kill_on_drop(true)
|
||||
.output();
|
||||
|
||||
let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
|
||||
|
||||
let current_version = String::from_utf8(version_output?.stdout)?;
|
||||
let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
|
||||
|
||||
if supported {
|
||||
Ok(AgentServerVersion::Supported)
|
||||
} else {
|
||||
Ok(AgentServerVersion::Unsupported {
|
||||
error_message: format!(
|
||||
"Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
|
||||
current_version
|
||||
).into(),
|
||||
upgrade_message: "Upgrade Gemini to Latest".into(),
|
||||
upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use super::*;
|
||||
@@ -76,7 +198,7 @@ pub(crate) mod tests {
|
||||
|
||||
AgentServerCommand {
|
||||
path: "node".into(),
|
||||
args: vec![cli_path],
|
||||
args: vec![cli_path, ACP_ARG.into()],
|
||||
env: None,
|
||||
}
|
||||
}
|
||||
|
||||
207
crates/agent_servers/src/mcp_server.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
use context_server::types::{
|
||||
Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
|
||||
ToolsCapabilities, requests,
|
||||
};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use indoc::indoc;
|
||||
|
||||
pub struct ZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
}
|
||||
|
||||
pub const SERVER_NAME: &str = "zed";
|
||||
|
||||
impl ZedMcpServer {
|
||||
pub async fn new(
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
|
||||
mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
|
||||
|
||||
mcp_server.add_tool(RequestPermissionTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(ReadTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
mcp_server.add_tool(WriteTextFileTool {
|
||||
thread_rx: thread_rx.clone(),
|
||||
});
|
||||
|
||||
Ok(Self { server: mcp_server })
|
||||
}
|
||||
|
||||
pub fn server_config(&self) -> Result<acp::McpServerConfig> {
|
||||
#[cfg(not(test))]
|
||||
let zed_path = anyhow::Context::context(
|
||||
std::env::current_exe(),
|
||||
"finding current executable path for use in mcp_server",
|
||||
)?;
|
||||
|
||||
#[cfg(test)]
|
||||
let zed_path = crate::e2e_tests::get_zed_path();
|
||||
|
||||
Ok(acp::McpServerConfig {
|
||||
command: zed_path,
|
||||
args: vec![
|
||||
"--nc".into(),
|
||||
self.server.socket_path().display().to_string(),
|
||||
],
|
||||
env: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
Ok(InitializeResponse {
|
||||
protocol_version: ProtocolVersion("2025-06-18".into()),
|
||||
capabilities: ServerCapabilities {
|
||||
experimental: None,
|
||||
logging: None,
|
||||
completions: None,
|
||||
prompts: None,
|
||||
resources: None,
|
||||
tools: Some(ToolsCapabilities {
|
||||
list_changed: Some(false),
|
||||
}),
|
||||
},
|
||||
server_info: Implementation {
|
||||
name: SERVER_NAME.into(),
|
||||
version: "0.1.0".into(),
|
||||
},
|
||||
meta: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Tools
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RequestPermissionTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for RequestPermissionTool {
|
||||
type Input = acp::RequestPermissionArguments;
|
||||
type Output = acp::RequestPermissionOutput;
|
||||
|
||||
const NAME: &'static str = "Confirmation";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
indoc! {"
|
||||
Request permission for tool calls.
|
||||
|
||||
This tool is meant to be called programmatically by the agent loop, not the LLM.
|
||||
"}
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let result = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(input.tool_call, input.options, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let outcome = match result {
|
||||
Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
|
||||
Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
|
||||
};
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::RequestPermissionOutput { outcome },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ReadTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for ReadTextFileTool {
|
||||
type Input = acp::ReadTextFileArguments;
|
||||
type Output = acp::ReadTextFileOutput;
|
||||
|
||||
const NAME: &'static str = "Read";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Reads the content of the given file in the project including unsaved changes."
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
let content = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.read_text_file(input.path, input.line, input.limit, false, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: acp::ReadTextFileOutput { content },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WriteTextFileTool {
|
||||
thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
|
||||
}
|
||||
|
||||
impl McpServerTool for WriteTextFileTool {
|
||||
type Input = acp::WriteTextFileArguments;
|
||||
type Output = ();
|
||||
|
||||
const NAME: &'static str = "Write";
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Write to a file replacing its contents"
|
||||
}
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
};
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.write_text_file(input.path, input.content, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(ToolResponse {
|
||||
content: vec![],
|
||||
structured_content: (),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ pub fn init(cx: &mut App) {
|
||||
pub struct AllAgentServersSettings {
|
||||
pub gemini: Option<AgentServerSettings>,
|
||||
pub claude: Option<AgentServerSettings>,
|
||||
pub codex: Option<AgentServerSettings>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
|
||||
@@ -29,13 +30,21 @@ impl settings::Settings for AllAgentServersSettings {
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
let mut settings = AllAgentServersSettings::default();
|
||||
|
||||
for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
|
||||
for AllAgentServersSettings {
|
||||
gemini,
|
||||
claude,
|
||||
codex,
|
||||
} in sources.defaults_and_customizations()
|
||||
{
|
||||
if gemini.is_some() {
|
||||
settings.gemini = gemini.clone();
|
||||
}
|
||||
if claude.is_some() {
|
||||
settings.claude = claude.clone();
|
||||
}
|
||||
if codex.is_some() {
|
||||
settings.codex = codex.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(settings)
|
||||
|
||||
@@ -13,7 +13,6 @@ path = "src/agent_settings.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
@@ -21,6 +20,7 @@ schemars.workspace = true
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
fs.workspace = true
|
||||
|
||||
@@ -321,11 +321,11 @@ pub enum CompletionMode {
|
||||
Burn,
|
||||
}
|
||||
|
||||
impl From<CompletionMode> for cloud_llm_client::CompletionMode {
|
||||
impl From<CompletionMode> for zed_llm_client::CompletionMode {
|
||||
fn from(value: CompletionMode) -> Self {
|
||||
match value {
|
||||
CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal,
|
||||
CompletionMode::Burn => cloud_llm_client::CompletionMode::Max,
|
||||
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
|
||||
CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ audio.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
@@ -47,9 +46,9 @@ futures.workspace = true
|
||||
fuzzy.workspace = true
|
||||
gpui.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
indoc.workspace = true
|
||||
http_client.workspace = true
|
||||
indexed_docs.workspace = true
|
||||
indoc.workspace = true
|
||||
inventory.workspace = true
|
||||
itertools.workspace = true
|
||||
jsonschema.workspace = true
|
||||
@@ -98,6 +97,7 @@ watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
assistant_tools.workspace = true
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
use acp_thread::{AgentConnection, Plan};
|
||||
use agent_servers::AgentServer;
|
||||
use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
|
||||
use audio::{Audio, Sound};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
@@ -20,10 +18,10 @@ use editor::{
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, App, BorderStyle, EdgesRefinement, Empty, Entity, EntityId,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, PlatformDisplay, SharedString,
|
||||
StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement, Transformation,
|
||||
UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop, linear_gradient,
|
||||
list, percentage, point, prelude::*, pulsating_between,
|
||||
FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, SharedString, StyleRefinement,
|
||||
Subscription, Task, TextStyle, TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity,
|
||||
Window, div, linear_color_stop, linear_gradient, list, percentage, point, prelude::*,
|
||||
pulsating_between,
|
||||
};
|
||||
use language::language_settings::SoftWrap;
|
||||
use language::{Buffer, Language};
|
||||
@@ -47,10 +45,7 @@ use crate::acp::completion_provider::{ContextPickerCompletionProvider, MentionSe
|
||||
use crate::acp::message_history::MessageHistory;
|
||||
use crate::agent_diff::AgentDiff;
|
||||
use crate::message_editor::{MAX_EDITOR_LINES, MIN_EDITOR_LINES};
|
||||
use crate::ui::{AgentNotification, AgentNotificationEvent};
|
||||
use crate::{
|
||||
AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll,
|
||||
};
|
||||
use crate::{AgentDiffPane, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll};
|
||||
|
||||
const RESPONSE_PADDING_X: Pixels = px(19.);
|
||||
|
||||
@@ -64,8 +59,6 @@ pub struct AcpThreadView {
|
||||
message_set_from_history: bool,
|
||||
_message_editor_subscription: Subscription,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
notifications: Vec<WindowHandle<AgentNotification>>,
|
||||
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
|
||||
last_error: Option<Entity<Markdown>>,
|
||||
list_state: ListState,
|
||||
auth_task: Option<Task<()>>,
|
||||
@@ -181,8 +174,6 @@ impl AcpThreadView {
|
||||
message_set_from_history: false,
|
||||
_message_editor_subscription: message_editor_subscription,
|
||||
mention_set,
|
||||
notifications: Vec::new(),
|
||||
notification_subscriptions: HashMap::default(),
|
||||
diff_editors: Default::default(),
|
||||
list_state: list_state,
|
||||
last_error: None,
|
||||
@@ -232,8 +223,7 @@ impl AcpThreadView {
|
||||
{
|
||||
Err(e) => {
|
||||
let mut cx = cx.clone();
|
||||
// todo! remove duplication
|
||||
if e.downcast_ref::<acp_thread::AuthRequired>().is_some() {
|
||||
if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.thread_state = ThreadState::Unauthenticated { connection };
|
||||
cx.notify();
|
||||
@@ -391,9 +381,7 @@ impl AcpThreadView {
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
let Some(thread) = self.thread() else { return };
|
||||
let task = thread.update(cx, |thread, cx| thread.send(chunks.clone(), cx));
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
@@ -576,30 +564,6 @@ impl AcpThreadView {
|
||||
self.sync_thread_entry_view(index, window, cx);
|
||||
self.list_state.splice(index..index + 1, 1);
|
||||
}
|
||||
AcpThreadEvent::ToolAuthorizationRequired => {
|
||||
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
|
||||
}
|
||||
AcpThreadEvent::Stopped => {
|
||||
let used_tools = thread.read(cx).used_tools_since_last_user_message();
|
||||
self.notify_with_sound(
|
||||
if used_tools {
|
||||
"Finished running tools"
|
||||
} else {
|
||||
"New message"
|
||||
},
|
||||
IconName::ZedAssistant,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
AcpThreadEvent::Error => {
|
||||
self.notify_with_sound(
|
||||
"Agent stopped due to an error",
|
||||
IconName::Warning,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
@@ -676,18 +640,13 @@ impl AcpThreadView {
|
||||
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&mut self,
|
||||
method: acp::AuthMethodId,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.last_error.take();
|
||||
let authenticate = connection.authenticate(method, cx);
|
||||
let authenticate = connection.authenticate(cx);
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
let project = self.project.clone();
|
||||
let agent = self.agent.clone();
|
||||
@@ -2201,154 +2160,6 @@ impl AcpThreadView {
|
||||
self.list_state.scroll_to(ListOffset::default());
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn notify_with_sound(
|
||||
&mut self,
|
||||
caption: impl Into<SharedString>,
|
||||
icon: IconName,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.play_notification_sound(window, cx);
|
||||
self.show_notification(caption, icon, window, cx);
|
||||
}
|
||||
|
||||
fn play_notification_sound(&self, window: &Window, cx: &mut App) {
|
||||
let settings = AgentSettings::get_global(cx);
|
||||
if settings.play_sound_when_agent_done && !window.is_window_active() {
|
||||
Audio::play_sound(Sound::AgentDone, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn show_notification(
|
||||
&mut self,
|
||||
caption: impl Into<SharedString>,
|
||||
icon: IconName,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if window.is_window_active() || !self.notifications.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let title = self.title(cx);
|
||||
|
||||
match AgentSettings::get_global(cx).notify_when_agent_waiting {
|
||||
NotifyWhenAgentWaiting::PrimaryScreen => {
|
||||
if let Some(primary) = cx.primary_display() {
|
||||
self.pop_up(icon, caption.into(), title, window, primary, cx);
|
||||
}
|
||||
}
|
||||
NotifyWhenAgentWaiting::AllScreens => {
|
||||
let caption = caption.into();
|
||||
for screen in cx.displays() {
|
||||
self.pop_up(icon, caption.clone(), title.clone(), window, screen, cx);
|
||||
}
|
||||
}
|
||||
NotifyWhenAgentWaiting::Never => {
|
||||
// Don't show anything
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn pop_up(
|
||||
&mut self,
|
||||
icon: IconName,
|
||||
caption: SharedString,
|
||||
title: SharedString,
|
||||
window: &mut Window,
|
||||
screen: Rc<dyn PlatformDisplay>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let options = AgentNotification::window_options(screen, cx);
|
||||
|
||||
let project_name = self.workspace.upgrade().and_then(|workspace| {
|
||||
workspace
|
||||
.read(cx)
|
||||
.project()
|
||||
.read(cx)
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.map(|worktree| worktree.read(cx).root_name().to_string())
|
||||
});
|
||||
|
||||
if let Some(screen_window) = cx
|
||||
.open_window(options, |_, cx| {
|
||||
cx.new(|_| {
|
||||
AgentNotification::new(title.clone(), caption.clone(), icon, project_name)
|
||||
})
|
||||
})
|
||||
.log_err()
|
||||
{
|
||||
if let Some(pop_up) = screen_window.entity(cx).log_err() {
|
||||
self.notification_subscriptions
|
||||
.entry(screen_window)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(cx.subscribe_in(&pop_up, window, {
|
||||
|this, _, event, window, cx| match event {
|
||||
AgentNotificationEvent::Accepted => {
|
||||
let handle = window.window_handle();
|
||||
cx.activate(true);
|
||||
|
||||
let workspace_handle = this.workspace.clone();
|
||||
|
||||
// If there are multiple Zed windows, activate the correct one.
|
||||
cx.defer(move |cx| {
|
||||
handle
|
||||
.update(cx, |_view, window, _cx| {
|
||||
window.activate_window();
|
||||
|
||||
if let Some(workspace) = workspace_handle.upgrade() {
|
||||
workspace.update(_cx, |workspace, cx| {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
});
|
||||
|
||||
this.dismiss_notifications(cx);
|
||||
}
|
||||
AgentNotificationEvent::Dismissed => {
|
||||
this.dismiss_notifications(cx);
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
self.notifications.push(screen_window);
|
||||
|
||||
// If the user manually refocuses the original window, dismiss the popup.
|
||||
self.notification_subscriptions
|
||||
.entry(screen_window)
|
||||
.or_insert_with(Vec::new)
|
||||
.push({
|
||||
let pop_up_weak = pop_up.downgrade();
|
||||
|
||||
cx.observe_window_activation(window, move |_, window, cx| {
|
||||
if window.is_window_active() {
|
||||
if let Some(pop_up) = pop_up_weak.upgrade() {
|
||||
pop_up.update(cx, |_, cx| {
|
||||
cx.emit(AgentNotificationEvent::Dismissed);
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dismiss_notifications(&mut self, cx: &mut Context<Self>) {
|
||||
for window in self.notifications.drain(..) {
|
||||
window
|
||||
.update(cx, |_, window, _| {
|
||||
window.remove_window();
|
||||
})
|
||||
.ok();
|
||||
|
||||
self.notification_subscriptions.remove(&window);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for AcpThreadView {
|
||||
@@ -2386,26 +2197,22 @@ impl Render for AcpThreadView {
|
||||
.on_action(cx.listener(Self::next_history_message))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.child(match &self.thread_state {
|
||||
ThreadState::Unauthenticated { connection } => v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(h_flex().mt_1p5().justify_center().children(
|
||||
connection.auth_methods().into_iter().map(|method| {
|
||||
Button::new(
|
||||
SharedString::from(method.id.0.clone()),
|
||||
method.label.clone(),
|
||||
)
|
||||
.on_click({
|
||||
let method_id = method.id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.authenticate(method_id.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
}),
|
||||
)),
|
||||
ThreadState::Unauthenticated { .. } => {
|
||||
v_flex()
|
||||
.p_2()
|
||||
.flex_1()
|
||||
.items_center()
|
||||
.justify_center()
|
||||
.child(self.render_pending_auth_state())
|
||||
.child(
|
||||
h_flex().mt_1p5().justify_center().child(
|
||||
Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
this.authenticate(window, cx)
|
||||
})),
|
||||
),
|
||||
)
|
||||
}
|
||||
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
|
||||
ThreadState::LoadError(e) => v_flex()
|
||||
.p_2()
|
||||
@@ -2634,341 +2441,3 @@ fn plan_label_markdown_style(
|
||||
..default_md_style
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use agent_client_protocol::SessionId;
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
|
||||
use rand::Rng;
|
||||
use settings::SettingsStore;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await;
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello", window, cx);
|
||||
});
|
||||
|
||||
cx.deactivate_window();
|
||||
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.chat(&Chat, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert!(
|
||||
cx.windows()
|
||||
.iter()
|
||||
.any(|window| window.downcast::<AgentNotification>().is_some())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_error(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let (thread_view, cx) =
|
||||
setup_thread_view(StubAgentServer::new(SaboteurAgentConnection), cx).await;
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello", window, cx);
|
||||
});
|
||||
|
||||
cx.deactivate_window();
|
||||
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.chat(&Chat, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert!(
|
||||
cx.windows()
|
||||
.iter()
|
||||
.any(|window| window.downcast::<AgentNotification>().is_some())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let tool_call_id = acp::ToolCallId("1".into());
|
||||
let tool_call = acp::ToolCall {
|
||||
id: tool_call_id.clone(),
|
||||
label: "Label".into(),
|
||||
kind: acp::ToolKind::Edit,
|
||||
status: acp::ToolCallStatus::Pending,
|
||||
content: vec!["hi".into()],
|
||||
locations: vec![],
|
||||
raw_input: None,
|
||||
};
|
||||
let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
|
||||
.with_permission_requests(HashMap::from_iter([(
|
||||
tool_call_id,
|
||||
vec![acp::PermissionOption {
|
||||
id: acp::PermissionOptionId("1".into()),
|
||||
label: "Allow".into(),
|
||||
kind: acp::PermissionOptionKind::AllowOnce,
|
||||
}],
|
||||
)]));
|
||||
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
|
||||
|
||||
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
|
||||
message_editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello", window, cx);
|
||||
});
|
||||
|
||||
cx.deactivate_window();
|
||||
|
||||
thread_view.update_in(cx, |thread_view, window, cx| {
|
||||
thread_view.chat(&Chat, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert!(
|
||||
cx.windows()
|
||||
.iter()
|
||||
.any(|window| window.downcast::<AgentNotification>().is_some())
|
||||
);
|
||||
}
|
||||
|
||||
async fn setup_thread_view(
|
||||
agent: impl AgentServer + 'static,
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<AcpThreadView>, &mut VisualTestContext) {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, [], cx).await;
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let thread_view = cx.update(|window, cx| {
|
||||
cx.new(|cx| {
|
||||
AcpThreadView::new(
|
||||
Rc::new(agent),
|
||||
workspace.downgrade(),
|
||||
project,
|
||||
Rc::new(RefCell::new(MessageHistory::default())),
|
||||
1,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
cx.run_until_parked();
|
||||
(thread_view, cx)
|
||||
}
|
||||
|
||||
struct StubAgentServer<C> {
|
||||
connection: C,
|
||||
}
|
||||
|
||||
impl<C> StubAgentServer<C> {
|
||||
fn new(connection: C) -> Self {
|
||||
Self { connection }
|
||||
}
|
||||
}
|
||||
|
||||
impl StubAgentServer<StubAgentConnection> {
|
||||
fn default() -> Self {
|
||||
Self::new(StubAgentConnection::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> AgentServer for StubAgentServer<C>
|
||||
where
|
||||
C: 'static + AgentConnection + Send + Clone,
|
||||
{
|
||||
fn logo(&self) -> ui::IconName {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn empty_state_headline(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn empty_state_message(&self) -> &'static str {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&self,
|
||||
_root_dir: &Path,
|
||||
_project: &Entity<Project>,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
|
||||
Task::ready(Ok(Rc::new(self.connection.clone())))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct StubAgentConnection {
|
||||
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
updates: Vec<acp::SessionUpdate>,
|
||||
}
|
||||
|
||||
impl StubAgentConnection {
|
||||
fn new(updates: Vec<acp::SessionUpdate>) -> Self {
|
||||
Self {
|
||||
updates,
|
||||
permission_requests: HashMap::default(),
|
||||
sessions: Arc::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_permission_requests(
|
||||
mut self,
|
||||
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
|
||||
) -> Self {
|
||||
self.permission_requests = permission_requests;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentConnection for StubAgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
let session_id = SessionId(
|
||||
rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(7)
|
||||
.map(char::from)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
);
|
||||
let thread = cx
|
||||
.new(|cx| {
|
||||
AcpThread::new("New Thread", self.clone(), project, session_id.clone(), cx)
|
||||
})
|
||||
.unwrap();
|
||||
self.sessions.lock().insert(session_id, thread.downgrade());
|
||||
Task::ready(Ok(thread))
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
let sessions = self.sessions.lock();
|
||||
let thread = sessions.get(¶ms.session_id).unwrap();
|
||||
let mut tasks = vec![];
|
||||
for update in &self.updates {
|
||||
let thread = thread.clone();
|
||||
let update = update.clone();
|
||||
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
|
||||
&& let Some(options) = self.permission_requests.get(&tool_call.id)
|
||||
{
|
||||
Some((tool_call.clone(), options.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let task = cx.spawn(async move |cx| {
|
||||
if let Some((tool_call, options)) = permission_request {
|
||||
let permission = thread.update(cx, |thread, cx| {
|
||||
thread.request_tool_call_permission(
|
||||
tool_call.clone(),
|
||||
options.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
permission.await?;
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.handle_session_update(update.clone(), cx).unwrap();
|
||||
})?;
|
||||
anyhow::Ok(())
|
||||
});
|
||||
tasks.push(task);
|
||||
}
|
||||
cx.spawn(async move |_| {
|
||||
try_join_all(tasks).await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct SaboteurAgentConnection;
|
||||
|
||||
impl AgentConnection for SaboteurAgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
_cwd: &Path,
|
||||
cx: &mut gpui::AsyncApp,
|
||||
) -> Task<gpui::Result<Entity<AcpThread>>> {
|
||||
Task::ready(Ok(cx
|
||||
.new(|cx| AcpThread::new("New Thread", self, project, SessionId("test".into()), cx))
|
||||
.unwrap()))
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[agent_client_protocol::AuthMethod] {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn authenticate(
|
||||
&self,
|
||||
_method: acp::AuthMethodId,
|
||||
_cx: &mut App,
|
||||
) -> Task<gpui::Result<()>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
|
||||
Task::ready(Err(anyhow::anyhow!("Error prompting")))
|
||||
}
|
||||
|
||||
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AgentSettings::register(cx);
|
||||
workspace::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
release_channel::init(SemanticVersion::default(), cx);
|
||||
EditorSettings::register(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
|
||||
use anyhow::Context as _;
|
||||
use assistant_tool::ToolUseStatus;
|
||||
use audio::{Audio, Sound};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
use editor::scroll::Autoscroll;
|
||||
@@ -53,6 +52,7 @@ use util::ResultExt as _;
|
||||
use util::markdown::MarkdownCodeBlock;
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
|
||||
const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1;
|
||||
|
||||
@@ -7,7 +7,6 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use agent_settings::AgentSettings;
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use cloud_llm_client::Plan;
|
||||
use collections::HashMap;
|
||||
use context_server::ContextServerId;
|
||||
use extension::ExtensionManifest;
|
||||
@@ -26,6 +25,7 @@ use project::{
|
||||
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
|
||||
project_settings::{ContextServerSettings, ProjectSettings},
|
||||
};
|
||||
use proto::Plan;
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
Chip, ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
|
||||
@@ -180,7 +180,7 @@ impl AgentConfiguration {
|
||||
let current_plan = if is_zed_provider {
|
||||
self.workspace
|
||||
.upgrade()
|
||||
.and_then(|workspace| workspace.read(cx).user_store().read(cx).plan())
|
||||
.and_then(|workspace| workspace.read(cx).user_store().read(cx).current_plan())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -406,9 +406,7 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"always-allow-tool-actions-switch",
|
||||
"Allow running commands without asking for confirmation",
|
||||
Some(
|
||||
"The agent can perform potentially destructive actions without asking for your confirmation.".into(),
|
||||
),
|
||||
"The agent can perform potentially destructive actions without asking for your confirmation.",
|
||||
always_allow_tool_actions,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -426,7 +424,7 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"single-file-review",
|
||||
"Enable single-file agent reviews",
|
||||
Some("Agent edits are also displayed in single-file editors for review.".into()),
|
||||
"Agent edits are also displayed in single-file editors for review.",
|
||||
single_file_review,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -444,9 +442,7 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"sound-notification",
|
||||
"Play sound when finished generating",
|
||||
Some(
|
||||
"Hear a notification sound when the agent is done generating changes or needs your input.".into(),
|
||||
),
|
||||
"Hear a notification sound when the agent is done generating changes or needs your input.",
|
||||
play_sound_when_agent_done,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -464,9 +460,7 @@ impl AgentConfiguration {
|
||||
SwitchField::new(
|
||||
"modifier-send",
|
||||
"Use modifier to submit a message",
|
||||
Some(
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.".into(),
|
||||
),
|
||||
"Make a modifier (cmd-enter on macOS, ctrl-enter on Linux) required to send messages.",
|
||||
use_modifier_to_send,
|
||||
move |state, _window, cx| {
|
||||
let allow = state == &ToggleState::Selected;
|
||||
@@ -508,7 +502,7 @@ impl AgentConfiguration {
|
||||
.blend(cx.theme().colors().text_accent.opacity(0.2));
|
||||
|
||||
let (plan_name, label_color, bg_color) = match plan {
|
||||
Plan::ZedFree => ("Free", Color::Default, free_chip_bg),
|
||||
Plan::Free => ("Free", Color::Default, free_chip_bg),
|
||||
Plan::ZedProTrial => ("Pro Trial", Color::Accent, pro_chip_bg),
|
||||
Plan::ZedPro => ("Pro", Color::Accent, pro_chip_bg),
|
||||
};
|
||||
|
||||
@@ -1521,9 +1521,6 @@ impl AgentDiff {
|
||||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
}
|
||||
AcpThreadEvent::Stopped
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Error => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,6 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use client::{DisableAiSettings, UserStore, zed_urls};
|
||||
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
@@ -60,6 +59,7 @@ use language_model::{
|
||||
};
|
||||
use project::{Project, ProjectPath, Worktree};
|
||||
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
|
||||
use proto::Plan;
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
use search::{BufferSearchBar, buffer_search};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -77,9 +77,10 @@ use workspace::{
|
||||
};
|
||||
use zed_actions::{
|
||||
DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize,
|
||||
agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector},
|
||||
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector},
|
||||
assistant::{OpenRulesLibrary, ToggleFocus},
|
||||
};
|
||||
use zed_llm_client::{CompletionIntent, UsageLimit};
|
||||
|
||||
const AGENT_PANEL_KEY: &str = "agent_panel";
|
||||
|
||||
@@ -104,7 +105,7 @@ pub fn init(cx: &mut App) {
|
||||
panel.update(cx, |panel, cx| panel.open_history(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &OpenSettings, window, cx| {
|
||||
.register_action(|workspace, _: &OpenConfiguration, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
|
||||
@@ -578,6 +579,7 @@ impl AgentPanel {
|
||||
MessageEditor::new(
|
||||
fs.clone(),
|
||||
workspace.clone(),
|
||||
user_store.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
prompt_store.clone(),
|
||||
thread_store.downgrade(),
|
||||
@@ -846,6 +848,7 @@ impl AgentPanel {
|
||||
MessageEditor::new(
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
self.user_store.clone(),
|
||||
context_store.clone(),
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
@@ -1119,6 +1122,7 @@ impl AgentPanel {
|
||||
MessageEditor::new(
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
self.user_store.clone(),
|
||||
context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
@@ -1987,6 +1991,20 @@ impl AgentPanel {
|
||||
);
|
||||
}),
|
||||
)
|
||||
.item(
|
||||
ContextMenuEntry::new("New Codex Thread")
|
||||
.icon(IconName::AiOpenAi)
|
||||
.icon_color(Color::Muted)
|
||||
.handler(move |window, cx| {
|
||||
window.dispatch_action(
|
||||
NewExternalAgentThread {
|
||||
agent: Some(crate::ExternalAgent::Codex),
|
||||
}
|
||||
.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
});
|
||||
menu
|
||||
}))
|
||||
@@ -2070,7 +2088,7 @@ impl AgentPanel {
|
||||
|
||||
menu = menu
|
||||
.action("Rules…", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenSettings))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.action(zoom_in_label, Box::new(ToggleZoom));
|
||||
menu
|
||||
}))
|
||||
@@ -2275,10 +2293,10 @@ impl AgentPanel {
|
||||
| ActiveView::Configuration => return false,
|
||||
}
|
||||
|
||||
let plan = self.user_store.read(cx).plan();
|
||||
let plan = self.user_store.read(cx).current_plan();
|
||||
let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some();
|
||||
|
||||
matches!(plan, Some(Plan::ZedFree)) && has_previous_trial
|
||||
matches!(plan, Some(Plan::Free)) && has_previous_trial
|
||||
}
|
||||
|
||||
fn should_render_onboarding(&self, cx: &mut Context<Self>) -> bool {
|
||||
@@ -2464,14 +2482,14 @@ impl AgentPanel {
|
||||
.icon_color(Color::Muted)
|
||||
.full_width()
|
||||
.key_binding(KeyBinding::for_action_in(
|
||||
&OpenSettings,
|
||||
&OpenConfiguration,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
))
|
||||
.on_click(|_event, window, cx| {
|
||||
window.dispatch_action(
|
||||
OpenSettings.boxed_clone(),
|
||||
OpenConfiguration.boxed_clone(),
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
@@ -2648,6 +2666,25 @@ impl AgentPanel {
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
.child(
|
||||
NewThreadButton::new(
|
||||
"new-codex-thread-btn",
|
||||
"New Codex Thread",
|
||||
IconName::AiOpenAi,
|
||||
)
|
||||
.on_click(
|
||||
|window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(NewExternalAgentThread {
|
||||
agent: Some(
|
||||
crate::ExternalAgent::Codex,
|
||||
),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
}),
|
||||
@@ -2676,11 +2713,16 @@ impl AgentPanel {
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Warning))
|
||||
.label_size(LabelSize::Small)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
KeyBinding::for_action_in(
|
||||
&OpenConfiguration,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(|_event, window, cx| {
|
||||
window.dispatch_action(OpenSettings.boxed_clone(), cx)
|
||||
window.dispatch_action(OpenConfiguration.boxed_clone(), cx)
|
||||
}),
|
||||
),
|
||||
ConfigurationError::ProviderPendingTermsAcceptance(provider) => {
|
||||
@@ -2874,7 +2916,7 @@ impl AgentPanel {
|
||||
) -> AnyElement {
|
||||
let error_message = match plan {
|
||||
Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
|
||||
Plan::ZedProTrial | Plan::ZedFree => "Upgrade to Zed Pro for more prompts.",
|
||||
Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.",
|
||||
};
|
||||
|
||||
let icon = Icon::new(IconName::XCircle)
|
||||
@@ -3184,7 +3226,7 @@ impl Render for AgentPanel {
|
||||
.on_action(cx.listener(|this, _: &OpenHistory, window, cx| {
|
||||
this.open_history(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &OpenSettings, window, cx| {
|
||||
.on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| {
|
||||
this.open_configuration(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(Self::open_active_thread_as_markdown))
|
||||
|
||||
@@ -150,6 +150,7 @@ enum ExternalAgent {
|
||||
#[default]
|
||||
Gemini,
|
||||
ClaudeCode,
|
||||
Codex,
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
@@ -157,6 +158,7 @@ impl ExternalAgent {
|
||||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::Codex => Rc::new(agent_servers::Codex),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -263,8 +265,8 @@ fn update_command_palette_filter(cx: &mut App) {
|
||||
filter.hide_namespace("agent");
|
||||
filter.hide_namespace("assistant");
|
||||
filter.hide_namespace("copilot");
|
||||
filter.hide_namespace("supermaven");
|
||||
filter.hide_namespace("zed_predict_onboarding");
|
||||
|
||||
filter.hide_namespace("edit_prediction");
|
||||
|
||||
use editor::actions::{
|
||||
|
||||
@@ -6,7 +6,6 @@ use agent::{
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{
|
||||
@@ -36,6 +35,7 @@ use std::{
|
||||
};
|
||||
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
pub struct BufferCodegen {
|
||||
alternatives: Vec<Entity<CodegenAlternative>>,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#![allow(unused, dead_code)]
|
||||
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{Plan, UsageLimit};
|
||||
use gpui::Global;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use ui::prelude::*;
|
||||
use zed_llm_client::{Plan, UsageLimit};
|
||||
|
||||
/// Debug only: Used for testing various account states
|
||||
///
|
||||
|
||||
@@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _};
|
||||
use ui::prelude::*;
|
||||
use util::{RangeExt, ResultExt, maybe};
|
||||
use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId};
|
||||
use zed_actions::agent::OpenSettings;
|
||||
use zed_actions::agent::OpenConfiguration;
|
||||
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -345,7 +345,7 @@ impl InlineAssistant {
|
||||
if let Some(answer) = answer {
|
||||
if answer == 0 {
|
||||
cx.update(|window, cx| {
|
||||
window.dispatch_action(Box::new(OpenSettings), cx)
|
||||
window.dispatch_action(Box::new(OpenConfiguration), cx)
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
@@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(|_, window, cx| {
|
||||
window.dispatch_action(
|
||||
zed_actions::agent::OpenSettings.boxed_clone(),
|
||||
zed_actions::agent::OpenConfiguration.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
|
||||
@@ -17,7 +17,7 @@ use agent::{
|
||||
use agent_settings::{AgentSettings, CompletionMode};
|
||||
use ai_onboarding::ApiKeysWithProviders;
|
||||
use buffer_diff::BufferDiff;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use client::UserStore;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
use editor::display_map::CreaseId;
|
||||
@@ -42,6 +42,7 @@ use language_model::{
|
||||
use multi_buffer;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use proto::Plan;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use theme::ThemeSettings;
|
||||
@@ -52,6 +53,7 @@ use util::ResultExt as _;
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::agent::Chat;
|
||||
use zed_actions::agent::ToggleModelSelector;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
@@ -77,6 +79,7 @@ pub struct MessageEditor {
|
||||
editor: Entity<Editor>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
user_store: Entity<UserStore>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
history_store: Option<WeakEntity<HistoryStore>>,
|
||||
@@ -156,6 +159,7 @@ impl MessageEditor {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
user_store: Entity<UserStore>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
@@ -227,6 +231,7 @@ impl MessageEditor {
|
||||
Self {
|
||||
editor: editor.clone(),
|
||||
project: thread.read(cx).project().clone(),
|
||||
user_store,
|
||||
thread,
|
||||
incompatible_tools_state: incompatible_tools.clone(),
|
||||
workspace,
|
||||
@@ -1282,12 +1287,24 @@ impl MessageEditor {
|
||||
return None;
|
||||
}
|
||||
|
||||
let user_store = self.project.read(cx).user_store().read(cx);
|
||||
if user_store.is_usage_based_billing_enabled() {
|
||||
let user_store = self.user_store.read(cx);
|
||||
|
||||
let ubb_enable = user_store
|
||||
.usage_based_billing_enabled()
|
||||
.map_or(false, |enabled| enabled);
|
||||
|
||||
if ubb_enable {
|
||||
return None;
|
||||
}
|
||||
|
||||
let plan = user_store.plan().unwrap_or(cloud_llm_client::Plan::ZedFree);
|
||||
let plan = user_store
|
||||
.current_plan()
|
||||
.map(|plan| match plan {
|
||||
Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
})
|
||||
.unwrap_or(zed_llm_client::Plan::ZedFree);
|
||||
|
||||
let usage = user_store.model_request_usage()?;
|
||||
|
||||
@@ -1752,6 +1769,7 @@ impl AgentPreview for MessageEditor {
|
||||
) -> Option<AnyElement> {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
let fs = workspace.read(cx).app_state().fs.clone();
|
||||
let user_store = workspace.read(cx).app_state().user_store.clone();
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let weak_project = project.downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
|
||||
@@ -1764,6 +1782,7 @@ impl AgentPreview for MessageEditor {
|
||||
MessageEditor::new(
|
||||
fs,
|
||||
workspace.downgrade(),
|
||||
user_store,
|
||||
context_store,
|
||||
None,
|
||||
thread_store.downgrade(),
|
||||
|
||||
@@ -10,7 +10,6 @@ use agent::{
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, VecDeque};
|
||||
use editor::{MultiBuffer, actions::SelectAll};
|
||||
use fs::Fs;
|
||||
@@ -28,6 +27,7 @@ use terminal_view::TerminalView;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
use workspace::{Toast, Workspace, notifications::NotificationId};
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use client::{ModelRequestUsage, RequestUsage, zed_urls};
|
||||
use cloud_llm_client::{Plan, UsageLimit};
|
||||
use component::{empty_example, example_group_with_title, single_example};
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Callout, prelude::*};
|
||||
use zed_llm_client::{Plan, UsageLimit};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct UsageCallout {
|
||||
|
||||
@@ -16,10 +16,10 @@ default = []
|
||||
|
||||
[dependencies]
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
component.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
proto.workspace = true
|
||||
serde.workspace = true
|
||||
smallvec.workspace = true
|
||||
telemetry.workspace = true
|
||||
|
||||
@@ -136,7 +136,10 @@ impl RenderOnce for ApiKeysWithoutProviders {
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.on_click(move |_, window, cx| {
|
||||
window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx);
|
||||
window.dispatch_action(
|
||||
zed_actions::agent::OpenConfiguration.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::{Client, UserStore};
|
||||
use cloud_llm_client::Plan;
|
||||
use gpui::{Entity, IntoElement, ParentElement};
|
||||
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use ui::prelude::*;
|
||||
@@ -57,8 +56,15 @@ impl AgentPanelOnboarding {
|
||||
|
||||
impl Render for AgentPanelOnboarding {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let enrolled_in_trial = self.user_store.read(cx).plan() == Some(Plan::ZedProTrial);
|
||||
let is_pro_user = self.user_store.read(cx).plan() == Some(Plan::ZedPro);
|
||||
let enrolled_in_trial = matches!(
|
||||
self.user_store.read(cx).current_plan(),
|
||||
Some(proto::Plan::ZedProTrial)
|
||||
);
|
||||
|
||||
let is_pro_user = matches!(
|
||||
self.user_store.read(cx).current_plan(),
|
||||
Some(proto::Plan::ZedPro)
|
||||
);
|
||||
|
||||
AgentPanelOnboardingCard::new()
|
||||
.child(
|
||||
|
||||
@@ -9,7 +9,6 @@ pub use agent_api_keys_onboarding::{ApiKeysWithProviders, ApiKeysWithoutProvider
|
||||
pub use agent_panel_onboarding_card::AgentPanelOnboardingCard;
|
||||
pub use agent_panel_onboarding_content::AgentPanelOnboarding;
|
||||
pub use ai_upsell_card::AiUpsellCard;
|
||||
use cloud_llm_client::Plan;
|
||||
pub use edit_prediction_onboarding_content::EditPredictionOnboarding;
|
||||
pub use young_account_banner::YoungAccountBanner;
|
||||
|
||||
@@ -80,7 +79,7 @@ impl From<client::Status> for SignInStatus {
|
||||
pub struct ZedAiOnboarding {
|
||||
pub sign_in_status: SignInStatus,
|
||||
pub has_accepted_terms_of_service: bool,
|
||||
pub plan: Option<Plan>,
|
||||
pub plan: Option<proto::Plan>,
|
||||
pub account_too_young: bool,
|
||||
pub continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
@@ -100,8 +99,8 @@ impl ZedAiOnboarding {
|
||||
|
||||
Self {
|
||||
sign_in_status: status.into(),
|
||||
has_accepted_terms_of_service: store.has_accepted_terms_of_service(),
|
||||
plan: store.plan(),
|
||||
has_accepted_terms_of_service: store.current_user_has_accepted_terms().unwrap_or(false),
|
||||
plan: store.current_plan(),
|
||||
account_too_young: store.account_too_young(),
|
||||
continue_with_zed_ai,
|
||||
accept_terms_of_service: Arc::new({
|
||||
@@ -114,9 +113,11 @@ impl ZedAiOnboarding {
|
||||
sign_in: Arc::new(move |_window, cx| {
|
||||
cx.spawn({
|
||||
let client = client.clone();
|
||||
async move |cx| client.sign_in_with_optional_connect(true, cx).await
|
||||
async move |cx| {
|
||||
client.authenticate_and_connect(true, cx).await;
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
.detach();
|
||||
}),
|
||||
dismiss_onboarding: None,
|
||||
}
|
||||
@@ -410,9 +411,9 @@ impl RenderOnce for ZedAiOnboarding {
|
||||
if matches!(self.sign_in_status, SignInStatus::SignedIn) {
|
||||
if self.has_accepted_terms_of_service {
|
||||
match self.plan {
|
||||
None | Some(Plan::ZedFree) => self.render_free_plan_state(cx),
|
||||
Some(Plan::ZedProTrial) => self.render_trial_state(cx),
|
||||
Some(Plan::ZedPro) => self.render_pro_plan_state(cx),
|
||||
None | Some(proto::Plan::Free) => self.render_free_plan_state(cx),
|
||||
Some(proto::Plan::ZedProTrial) => self.render_trial_state(cx),
|
||||
Some(proto::Plan::ZedPro) => self.render_pro_plan_state(cx),
|
||||
}
|
||||
} else {
|
||||
self.render_accept_terms_of_service()
|
||||
@@ -432,7 +433,7 @@ impl Component for ZedAiOnboarding {
|
||||
fn onboarding(
|
||||
sign_in_status: SignInStatus,
|
||||
has_accepted_terms_of_service: bool,
|
||||
plan: Option<Plan>,
|
||||
plan: Option<proto::Plan>,
|
||||
account_too_young: bool,
|
||||
) -> AnyElement {
|
||||
ZedAiOnboarding {
|
||||
@@ -467,15 +468,25 @@ impl Component for ZedAiOnboarding {
|
||||
),
|
||||
single_example(
|
||||
"Free Plan",
|
||||
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedFree), false),
|
||||
onboarding(SignInStatus::SignedIn, true, Some(proto::Plan::Free), false),
|
||||
),
|
||||
single_example(
|
||||
"Pro Trial",
|
||||
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedProTrial), false),
|
||||
onboarding(
|
||||
SignInStatus::SignedIn,
|
||||
true,
|
||||
Some(proto::Plan::ZedProTrial),
|
||||
false,
|
||||
),
|
||||
),
|
||||
single_example(
|
||||
"Pro Plan",
|
||||
onboarding(SignInStatus::SignedIn, true, Some(Plan::ZedPro), false),
|
||||
onboarding(
|
||||
SignInStatus::SignedIn,
|
||||
true,
|
||||
Some(proto::Plan::ZedPro),
|
||||
false,
|
||||
),
|
||||
),
|
||||
])
|
||||
.into_any_element(),
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::{Client, zed_urls};
|
||||
use cloud_llm_client::Plan;
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Divider, List, Vector, VectorName, prelude::*};
|
||||
|
||||
@@ -11,22 +10,22 @@ use crate::{BulletItem, SignInStatus};
|
||||
pub struct AiUpsellCard {
|
||||
pub sign_in_status: SignInStatus,
|
||||
pub sign_in: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
pub user_plan: Option<Plan>,
|
||||
}
|
||||
|
||||
impl AiUpsellCard {
|
||||
pub fn new(client: Arc<Client>, user_plan: Option<Plan>) -> Self {
|
||||
pub fn new(client: Arc<Client>) -> Self {
|
||||
let status = *client.status().borrow();
|
||||
|
||||
Self {
|
||||
user_plan,
|
||||
sign_in_status: status.into(),
|
||||
sign_in: Arc::new(move |_window, cx| {
|
||||
cx.spawn({
|
||||
let client = client.clone();
|
||||
async move |cx| client.sign_in_with_optional_connect(true, cx).await
|
||||
async move |cx| {
|
||||
client.authenticate_and_connect(true, cx).await;
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
.detach();
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -35,7 +34,6 @@ impl AiUpsellCard {
|
||||
impl RenderOnce for AiUpsellCard {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let pro_section = v_flex()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
@@ -58,7 +56,6 @@ impl RenderOnce for AiUpsellCard {
|
||||
);
|
||||
|
||||
let free_section = v_flex()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.child(
|
||||
@@ -74,7 +71,7 @@ impl RenderOnce for AiUpsellCard {
|
||||
)
|
||||
.child(
|
||||
List::new()
|
||||
.child(BulletItem::new("50 prompts with Claude models"))
|
||||
.child(BulletItem::new("50 prompts with the Claude models"))
|
||||
.child(BulletItem::new("2,000 accepted edit predictions")),
|
||||
);
|
||||
|
||||
@@ -135,28 +132,22 @@ impl RenderOnce for AiUpsellCard {
|
||||
|
||||
v_flex()
|
||||
.relative()
|
||||
.p_4()
|
||||
.pt_3()
|
||||
.p_6()
|
||||
.pt_4()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.rounded_lg()
|
||||
.overflow_hidden()
|
||||
.child(grid_bg)
|
||||
.child(gradient_bg)
|
||||
.child(Label::new("Try Zed AI").size(LabelSize::Large))
|
||||
.child(
|
||||
div()
|
||||
.max_w_3_4()
|
||||
.mb_2()
|
||||
.child(Label::new(DESCRIPTION).color(Color::Muted)),
|
||||
)
|
||||
.child(Headline::new("Try Zed AI"))
|
||||
.child(Label::new(DESCRIPTION).color(Color::Muted).mb_2())
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.mt_1p5()
|
||||
.mb_2p5()
|
||||
.items_start()
|
||||
.gap_6()
|
||||
.gap_12()
|
||||
.child(free_section)
|
||||
.child(pro_section),
|
||||
)
|
||||
@@ -192,7 +183,6 @@ impl Component for AiUpsellCard {
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedOut,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: None,
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
@@ -201,7 +191,6 @@ impl Component for AiUpsellCard {
|
||||
AiUpsellCard {
|
||||
sign_in_status: SignInStatus::SignedIn,
|
||||
sign_in: Arc::new(|_, _| {}),
|
||||
user_plan: None,
|
||||
}
|
||||
.into_any_element(),
|
||||
),
|
||||
|
||||
@@ -19,7 +19,6 @@ assistant_slash_commands.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -49,6 +48,7 @@ util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
indoc.workspace = true
|
||||
|
||||
@@ -11,7 +11,6 @@ use assistant_slash_command::{
|
||||
use assistant_slash_commands::FileCommandMetadata;
|
||||
use client::{self, Client, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::{Fs, RenameOptions};
|
||||
use futures::{FutureExt, StreamExt, future::Shared};
|
||||
@@ -47,6 +46,7 @@ use text::{BufferSnapshot, ToPoint};
|
||||
use ui::IconName;
|
||||
use util::{ResultExt, TryFutureExt, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
pub use crate::context_store::*;
|
||||
|
||||
|
||||
@@ -21,11 +21,9 @@ assistant_tool.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
derive_more.workspace = true
|
||||
diffy = "0.4.2"
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -65,6 +63,8 @@ web_search.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
diffy = "0.4.2"
|
||||
|
||||
[dev-dependencies]
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -7,7 +7,6 @@ mod streaming_fuzzy_matcher;
|
||||
use crate::{Template, Templates};
|
||||
use anyhow::Result;
|
||||
use assistant_tool::ActionLog;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use create_file_parser::{CreateFileParser, CreateFileParserEvent};
|
||||
pub use edit_parser::EditFormat;
|
||||
use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
|
||||
@@ -30,6 +29,7 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::
|
||||
use streaming_diff::{CharOperation, StreamingDiff};
|
||||
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
|
||||
use util::debug_panic;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct CreateFilePromptTemplate {
|
||||
|
||||
@@ -6,7 +6,6 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{
|
||||
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
|
||||
};
|
||||
use cloud_llm_client::{WebSearchResponse, WebSearchResult};
|
||||
use futures::{Future, FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
@@ -18,6 +17,7 @@ use serde::{Deserialize, Serialize};
|
||||
use ui::{IconName, Tooltip, prelude::*};
|
||||
use web_search::WebSearchRegistry;
|
||||
use workspace::Workspace;
|
||||
use zed_llm_client::{WebSearchResponse, WebSearchResult};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct WebSearchToolInput {
|
||||
|
||||
@@ -18,6 +18,6 @@ collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
|
||||
rodio = { version = "0.20.0", default-features = false, features = ["wav"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -3,9 +3,12 @@ use std::{io::Cursor, sync::Arc};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AssetSource, Global};
|
||||
use rodio::{Decoder, Source, source::Buffered};
|
||||
use rodio::{
|
||||
Decoder, Source,
|
||||
source::{Buffered, SamplesConverter},
|
||||
};
|
||||
|
||||
type Sound = Buffered<Decoder<Cursor<Vec<u8>>>>;
|
||||
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
|
||||
|
||||
pub struct SoundRegistry {
|
||||
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
|
||||
@@ -45,7 +48,7 @@ impl SoundRegistry {
|
||||
.with_context(|| format!("No asset available for path {path}"))??
|
||||
.into_owned();
|
||||
let cursor = Cursor::new(bytes);
|
||||
let source = Decoder::new(cursor)?.buffered();
|
||||
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
|
||||
|
||||
self.cache.lock().insert(name.to_string(), source.clone());
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use assets::SoundRegistry;
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use gpui::{App, AssetSource, BorrowAppContext, Global};
|
||||
use rodio::{OutputStream, OutputStreamBuilder};
|
||||
use rodio::{OutputStream, OutputStreamHandle};
|
||||
use util::ResultExt;
|
||||
|
||||
mod assets;
|
||||
@@ -37,7 +37,8 @@ impl Sound {
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Audio {
|
||||
output_handle: Option<OutputStream>,
|
||||
_output_stream: Option<OutputStream>,
|
||||
output_handle: Option<OutputStreamHandle>,
|
||||
}
|
||||
|
||||
#[derive(Deref, DerefMut)]
|
||||
@@ -50,9 +51,11 @@ impl Audio {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn ensure_output_exists(&mut self) -> Option<&OutputStream> {
|
||||
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
|
||||
if self.output_handle.is_none() {
|
||||
self.output_handle = OutputStreamBuilder::open_default_stream().log_err();
|
||||
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
|
||||
self.output_handle = output_handle;
|
||||
self._output_stream = _output_stream;
|
||||
}
|
||||
|
||||
self.output_handle.as_ref()
|
||||
@@ -66,7 +69,7 @@ impl Audio {
|
||||
cx.update_global::<GlobalAudio, _>(|this, cx| {
|
||||
let output_handle = this.ensure_output_exists()?;
|
||||
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
|
||||
output_handle.mixer().add(source);
|
||||
output_handle.play_raw(source).log_err()?;
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
@@ -77,6 +80,7 @@ impl Audio {
|
||||
}
|
||||
|
||||
cx.update_global::<GlobalAudio, _>(|this, _| {
|
||||
this._output_stream.take();
|
||||
this.output_handle.take();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ impl ChannelMembership {
|
||||
proto::channel_member::Kind::Member => 0,
|
||||
proto::channel_member::Kind::Invitee => 1,
|
||||
},
|
||||
username_order: &self.user.github_login,
|
||||
username_order: self.user.github_login.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,6 +259,20 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
assert_channels(&channel_store, &[(0, "the-channel".to_string())], cx);
|
||||
});
|
||||
|
||||
let get_users = server.receive::<proto::GetUsers>().await.unwrap();
|
||||
assert_eq!(get_users.payload.user_ids, vec![5]);
|
||||
server.respond(
|
||||
get_users.receipt(),
|
||||
proto::UsersResponse {
|
||||
users: vec![proto::User {
|
||||
id: 5,
|
||||
github_login: "nathansobo".into(),
|
||||
avatar_url: "http://avatar.com/nathansobo".into(),
|
||||
name: None,
|
||||
}],
|
||||
},
|
||||
);
|
||||
|
||||
// Join a channel and populate its existing messages.
|
||||
let channel = channel_store.update(cx, |store, cx| {
|
||||
let channel_id = store.ordered_channels().next().unwrap().1.id;
|
||||
@@ -320,7 +334,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
&[
|
||||
("user-5".into(), "a".into()),
|
||||
("nathansobo".into(), "a".into()),
|
||||
("maxbrunsfeld".into(), "b".into())
|
||||
]
|
||||
);
|
||||
@@ -423,7 +437,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
|
||||
.map(|message| (message.sender.github_login.clone(), message.body.clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
&[
|
||||
("user-5".into(), "y".into()),
|
||||
("nathansobo".into(), "y".into()),
|
||||
("maxbrunsfeld".into(), "z".into())
|
||||
]
|
||||
);
|
||||
|
||||
@@ -17,12 +17,11 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-recursion = "0.3"
|
||||
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
|
||||
base64.workspace = true
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clock.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
derive_more.workspace = true
|
||||
@@ -34,8 +33,8 @@ http_client.workspace = true
|
||||
http_client_tls.workspace = true
|
||||
httparse = "1.10"
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
@@ -47,18 +46,19 @@ serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
sha2.workspace = true
|
||||
smol.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tiny_http.workspace = true
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
|
||||
tokio.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
worktree.workspace = true
|
||||
telemetry.workspace = true
|
||||
tokio.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -6,21 +6,22 @@ pub mod telemetry;
|
||||
pub mod user;
|
||||
pub mod zed_urls;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use async_recursion::async_recursion;
|
||||
use async_tungstenite::tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
error::Error as WebsocketError,
|
||||
http::{HeaderValue, Request, StatusCode},
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use clock::SystemClock;
|
||||
use cloud_api_client::CloudApiClient;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{
|
||||
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
|
||||
channel::oneshot, future::BoxFuture,
|
||||
};
|
||||
use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{HttpClient, HttpClientWithUrl, http};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use parking_lot::RwLock;
|
||||
use postage::watch;
|
||||
use proxy::connect_proxy_stream;
|
||||
@@ -30,6 +31,7 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use std::pin::Pin;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
convert::TryFrom,
|
||||
@@ -43,7 +45,6 @@ use std::{
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{cmp, pin::Pin};
|
||||
use telemetry::Telemetry;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
@@ -77,7 +78,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
|
||||
LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
|
||||
|
||||
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
|
||||
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
|
||||
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
|
||||
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
|
||||
|
||||
actions!(
|
||||
@@ -161,8 +162,20 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
|
||||
let client = client.clone();
|
||||
move |_: &SignIn, cx| {
|
||||
if let Some(client) = client.upgrade() {
|
||||
cx.spawn(async move |cx| client.sign_in_with_optional_connect(true, &cx).await)
|
||||
.detach_and_log_err(cx);
|
||||
cx.spawn(
|
||||
async move |cx| match client.authenticate_and_connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("Initial authentication timed out");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::error!("Initial authentication connection reset");
|
||||
}
|
||||
ConnectionResult::Result(r) => {
|
||||
r.log_err();
|
||||
}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -200,7 +213,6 @@ pub struct Client {
|
||||
id: AtomicU64,
|
||||
peer: Arc<Peer>,
|
||||
http: Arc<HttpClientWithUrl>,
|
||||
cloud_client: Arc<CloudApiClient>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
credentials_provider: ClientCredentialsProvider,
|
||||
state: RwLock<ClientState>,
|
||||
@@ -271,8 +283,6 @@ pub enum Status {
|
||||
SignedOut,
|
||||
UpgradeRequired,
|
||||
Authenticating,
|
||||
Authenticated,
|
||||
AuthenticationError,
|
||||
Connecting,
|
||||
ConnectionError,
|
||||
Connected {
|
||||
@@ -576,7 +586,6 @@ impl Client {
|
||||
id: AtomicU64::new(0),
|
||||
peer: Peer::new(0),
|
||||
telemetry: Telemetry::new(clock, http.clone(), cx),
|
||||
cloud_client: Arc::new(CloudApiClient::new(http.clone())),
|
||||
http,
|
||||
credentials_provider: ClientCredentialsProvider::new(cx),
|
||||
state: Default::default(),
|
||||
@@ -609,10 +618,6 @@ impl Client {
|
||||
self.http.clone()
|
||||
}
|
||||
|
||||
pub fn cloud_client(&self) -> Arc<CloudApiClient> {
|
||||
self.cloud_client.clone()
|
||||
}
|
||||
|
||||
pub fn set_id(&self, id: u64) -> &Self {
|
||||
self.id.store(id, Ordering::SeqCst);
|
||||
self
|
||||
@@ -699,7 +704,7 @@ impl Client {
|
||||
|
||||
let mut delay = INITIAL_RECONNECTION_DELAY;
|
||||
loop {
|
||||
match client.connect(true, &cx).await {
|
||||
match client.authenticate_and_connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("client connect attempt timed out")
|
||||
}
|
||||
@@ -722,10 +727,11 @@ impl Client {
|
||||
},
|
||||
&cx,
|
||||
);
|
||||
let jitter =
|
||||
Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64));
|
||||
cx.background_executor().timer(delay + jitter).await;
|
||||
delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
|
||||
cx.background_executor().timer(delay).await;
|
||||
delay = delay
|
||||
.mul_f32(rng.gen_range(0.5..=2.5))
|
||||
.max(INITIAL_RECONNECTION_DELAY)
|
||||
.min(MAX_RECONNECTION_DELAY);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
@@ -869,122 +875,17 @@ impl Client {
|
||||
.is_some()
|
||||
}
|
||||
|
||||
pub async fn sign_in(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Credentials> {
|
||||
if self.status().borrow().is_signed_out() {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx);
|
||||
}
|
||||
|
||||
let mut credentials = None;
|
||||
|
||||
let old_credentials = self.state.read().credentials.clone();
|
||||
if let Some(old_credentials) = old_credentials {
|
||||
self.cloud_client.set_credentials(
|
||||
old_credentials.user_id as u32,
|
||||
old_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the old credentials, to ensure they are still valid.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(old_credentials);
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() && try_provider {
|
||||
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
|
||||
self.cloud_client.set_credentials(
|
||||
stored_credentials.user_id as u32,
|
||||
stored_credentials.access_token.clone(),
|
||||
);
|
||||
|
||||
// Fetch the authenticated user with the stored credentials, and
|
||||
// clear them from the credentials provider if that fails.
|
||||
if self.cloud_client.get_authenticated_user().await.is_ok() {
|
||||
credentials = Some(stored_credentials);
|
||||
} else {
|
||||
self.credentials_provider
|
||||
.delete_credentials(cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => {
|
||||
if IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider
|
||||
.write_credentials(creds.user_id, creds.access_token.clone(), cx)
|
||||
.await
|
||||
.log_err();
|
||||
}
|
||||
|
||||
credentials = Some(creds);
|
||||
},
|
||||
Err(err) => {
|
||||
self.set_status(Status::AuthenticationError, cx);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return Err(anyhow!("authentication canceled"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
self.set_status(Status::Authenticated, cx);
|
||||
|
||||
Ok(credentials)
|
||||
}
|
||||
|
||||
/// Performs a sign-in and also connects to Collab.
|
||||
///
|
||||
/// This is called in places where we *don't* need to connect in the future. We will replace these calls with calls
|
||||
/// to `sign_in` when we're ready to remove auto-connection to Collab.
|
||||
pub async fn sign_in_with_optional_connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<()> {
|
||||
let credentials = self.sign_in(try_provider, cx).await?;
|
||||
|
||||
let connect_result = match self.connect_with_credentials(credentials, cx).await {
|
||||
ConnectionResult::Timeout => Err(anyhow!("connection timed out")),
|
||||
ConnectionResult::ConnectionReset => Err(anyhow!("connection reset")),
|
||||
ConnectionResult::Result(result) => result.context("client auth and connect"),
|
||||
};
|
||||
connect_result.log_err();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
#[async_recursion(?Send)]
|
||||
pub async fn authenticate_and_connect(
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let was_disconnected = match *self.status().borrow() {
|
||||
Status::SignedOut | Status::Authenticated => true,
|
||||
Status::SignedOut => true,
|
||||
Status::ConnectionError
|
||||
| Status::ConnectionLost
|
||||
| Status::Authenticating { .. }
|
||||
| Status::AuthenticationError
|
||||
| Status::Reauthenticating { .. }
|
||||
| Status::ReconnectionError { .. } => false,
|
||||
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
|
||||
@@ -997,10 +898,39 @@ impl Client {
|
||||
);
|
||||
}
|
||||
};
|
||||
let credentials = match self.sign_in(try_provider, cx).await {
|
||||
Ok(credentials) => credentials,
|
||||
Err(err) => return ConnectionResult::Result(Err(err)),
|
||||
};
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
} else {
|
||||
self.set_status(Status::Reauthenticating, cx)
|
||||
}
|
||||
|
||||
let mut read_from_provider = false;
|
||||
let mut credentials = self.state.read().credentials.clone();
|
||||
if credentials.is_none() && try_provider {
|
||||
credentials = self.credentials_provider.read_credentials(cx).await;
|
||||
read_from_provider = credentials.is_some();
|
||||
}
|
||||
|
||||
if credentials.is_none() {
|
||||
let mut status_rx = self.status();
|
||||
let _ = status_rx.next().await;
|
||||
futures::select_biased! {
|
||||
authenticate = self.authenticate(cx).fuse() => {
|
||||
match authenticate {
|
||||
Ok(creds) => credentials = Some(creds),
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
return ConnectionResult::Result(Err(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
|
||||
}
|
||||
}
|
||||
}
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Connecting, cx);
|
||||
@@ -1008,20 +938,17 @@ impl Client {
|
||||
self.set_status(Status::Reconnecting, cx);
|
||||
}
|
||||
|
||||
self.connect_with_credentials(credentials, cx).await
|
||||
}
|
||||
|
||||
async fn connect_with_credentials(
|
||||
self: &Arc<Self>,
|
||||
credentials: Credentials,
|
||||
cx: &AsyncApp,
|
||||
) -> ConnectionResult<()> {
|
||||
let mut timeout =
|
||||
futures::FutureExt::fuse(cx.background_executor().timer(CONNECTION_TIMEOUT));
|
||||
futures::select_biased! {
|
||||
connection = self.establish_connection(&credentials, cx).fuse() => {
|
||||
match connection {
|
||||
Ok(conn) => {
|
||||
self.state.write().credentials = Some(credentials.clone());
|
||||
if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
|
||||
self.credentials_provider.write_credentials(credentials.user_id, credentials.access_token, cx).await.log_err();
|
||||
}
|
||||
|
||||
futures::select_biased! {
|
||||
result = self.set_connection(conn, cx).fuse() => {
|
||||
match result.context("client auth and connect") {
|
||||
@@ -1039,8 +966,15 @@ impl Client {
|
||||
}
|
||||
}
|
||||
Err(EstablishConnectionError::Unauthorized) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
self.state.write().credentials.take();
|
||||
if read_from_provider {
|
||||
self.credentials_provider.delete_credentials(cx).await.log_err();
|
||||
self.set_status(Status::SignedOut, cx);
|
||||
self.authenticate_and_connect(false, cx).await
|
||||
} else {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
}
|
||||
Err(EstablishConnectionError::UpgradeRequired) => {
|
||||
self.set_status(Status::UpgradeRequired, cx);
|
||||
@@ -1204,7 +1138,7 @@ impl Client {
|
||||
.to_str()
|
||||
.map_err(EstablishConnectionError::other)?
|
||||
.to_string();
|
||||
Url::parse(&collab_url).with_context(|| format!("parsing collab rpc url {collab_url}"))
|
||||
Url::parse(&collab_url).with_context(|| format!("parsing colab rpc url {collab_url}"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1224,7 +1158,6 @@ impl Client {
|
||||
|
||||
let http = self.http.clone();
|
||||
let proxy = http.proxy().cloned();
|
||||
let user_agent = http.user_agent().cloned();
|
||||
let credentials = credentials.clone();
|
||||
let rpc_url = self.rpc_url(http, release_channel);
|
||||
let system_id = self.telemetry.system_id();
|
||||
@@ -1276,7 +1209,7 @@ impl Client {
|
||||
// We then modify the request to add our desired headers.
|
||||
let request_headers = request.headers_mut();
|
||||
request_headers.insert(
|
||||
http::header::AUTHORIZATION,
|
||||
"Authorization",
|
||||
HeaderValue::from_str(&credentials.authorization_header())?,
|
||||
);
|
||||
request_headers.insert(
|
||||
@@ -1288,9 +1221,6 @@ impl Client {
|
||||
"x-zed-release-channel",
|
||||
HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
|
||||
);
|
||||
if let Some(user_agent) = user_agent {
|
||||
request_headers.insert(http::header::USER_AGENT, user_agent);
|
||||
}
|
||||
if let Some(system_id) = system_id {
|
||||
request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
|
||||
}
|
||||
@@ -1435,31 +1365,96 @@ impl Client {
|
||||
self: &Arc<Self>,
|
||||
http: Arc<HttpClientWithUrl>,
|
||||
login: String,
|
||||
api_token: String,
|
||||
mut api_token: String,
|
||||
) -> Result<Credentials> {
|
||||
#[derive(Serialize)]
|
||||
struct ImpersonateUserBody {
|
||||
github_login: String,
|
||||
#[derive(Deserialize)]
|
||||
struct AuthenticatedUserResponse {
|
||||
user: User,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ImpersonateUserResponse {
|
||||
user_id: u64,
|
||||
access_token: String,
|
||||
struct User {
|
||||
id: u64,
|
||||
}
|
||||
|
||||
let url = self
|
||||
.http
|
||||
.build_zed_cloud_url("/internal/users/impersonate", &[])?;
|
||||
let request = Request::post(url.as_str())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {api_token}"))
|
||||
.body(
|
||||
serde_json::to_string(&ImpersonateUserBody {
|
||||
github_login: login,
|
||||
})?
|
||||
.into(),
|
||||
)?;
|
||||
let github_user = {
|
||||
#[derive(Deserialize)]
|
||||
struct GithubUser {
|
||||
id: i32,
|
||||
login: String,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
let request = {
|
||||
let mut request_builder =
|
||||
Request::get(&format!("https://api.github.com/users/{login}"));
|
||||
if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", github_token));
|
||||
}
|
||||
|
||||
request_builder.body(AsyncBody::empty())?
|
||||
};
|
||||
|
||||
let mut response = http
|
||||
.send(request)
|
||||
.await
|
||||
.context("error fetching GitHub user")?;
|
||||
|
||||
let mut body = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("error reading GitHub user")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let text = String::from_utf8_lossy(body.as_slice());
|
||||
bail!(
|
||||
"status error {}, response: {text:?}",
|
||||
response.status().as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
serde_json::from_slice::<GithubUser>(body.as_slice()).map_err(|err| {
|
||||
log::error!("Error deserializing: {:?}", err);
|
||||
log::error!(
|
||||
"GitHub API response text: {:?}",
|
||||
String::from_utf8_lossy(body.as_slice())
|
||||
);
|
||||
anyhow!("error deserializing GitHub user")
|
||||
})?
|
||||
};
|
||||
|
||||
let query_params = [
|
||||
("github_login", &github_user.login),
|
||||
("github_user_id", &github_user.id.to_string()),
|
||||
(
|
||||
"github_user_created_at",
|
||||
&github_user.created_at.to_rfc3339(),
|
||||
),
|
||||
];
|
||||
|
||||
// Use the collab server's admin API to retrieve the ID
|
||||
// of the impersonated user.
|
||||
let mut url = self.rpc_url(http.clone(), None).await?;
|
||||
url.set_path("/user");
|
||||
url.set_query(Some(
|
||||
&query_params
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
format!(
|
||||
"{}={}",
|
||||
key,
|
||||
url::form_urlencoded::byte_serialize(value.as_bytes()).collect::<String>()
|
||||
)
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("&"),
|
||||
));
|
||||
let request: http_client::Request<AsyncBody> = Request::get(url.as_str())
|
||||
.header("Authorization", format!("token {api_token}"))
|
||||
.body("".into())?;
|
||||
|
||||
let mut response = http.send(request).await?;
|
||||
let mut body = String::new();
|
||||
@@ -1470,17 +1465,18 @@ impl Client {
|
||||
response.status().as_u16(),
|
||||
body,
|
||||
);
|
||||
let response: ImpersonateUserResponse = serde_json::from_str(&body)?;
|
||||
let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
|
||||
|
||||
// Use the admin API token to authenticate as the impersonated user.
|
||||
api_token.insert_str(0, "ADMIN_TOKEN:");
|
||||
Ok(Credentials {
|
||||
user_id: response.user_id,
|
||||
access_token: response.access_token,
|
||||
user_id: response.user.id,
|
||||
access_token: api_token,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
|
||||
self.state.write().credentials = None;
|
||||
self.cloud_client.clear_credentials();
|
||||
self.disconnect(cx);
|
||||
|
||||
if self.has_credentials(cx).await {
|
||||
@@ -1790,7 +1786,7 @@ mod tests {
|
||||
});
|
||||
let auth_and_connect = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert!(matches!(status.next().await, Some(Status::Connecting)));
|
||||
@@ -1867,7 +1863,7 @@ mod tests {
|
||||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
move |cx| async move { client.connect(false, &cx).await }
|
||||
move |cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 1);
|
||||
@@ -1875,7 +1871,7 @@ mod tests {
|
||||
|
||||
let _authenticate = cx.spawn({
|
||||
let client = client.clone();
|
||||
|cx| async move { client.connect(false, &cx).await }
|
||||
|cx| async move { client.authenticate_and_connect(false, &cx).await }
|
||||
});
|
||||
executor.run_until_parked();
|
||||
assert_eq!(*auth_count.lock(), 2);
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::Duration;
|
||||
use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
|
||||
use futures::{StreamExt, stream::BoxStream};
|
||||
use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
|
||||
use http_client::{AsyncBody, Method, Request, http};
|
||||
use parking_lot::Mutex;
|
||||
use rpc::{
|
||||
ConnectionId, Peer, Receipt, TypedEnvelope,
|
||||
@@ -42,44 +39,6 @@ impl FakeServer {
|
||||
executor: cx.executor(),
|
||||
};
|
||||
|
||||
client.http_client().as_fake().replace_handler({
|
||||
let state = server.state.clone();
|
||||
move |old_handler, req| {
|
||||
let state = state.clone();
|
||||
let old_handler = old_handler.clone();
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::GET, "/client/users/me") => {
|
||||
let credentials = parse_authorization_header(&req);
|
||||
if credentials
|
||||
!= Some(Credentials {
|
||||
user_id: client_user_id,
|
||||
access_token: state.lock().access_token.to_string(),
|
||||
})
|
||||
{
|
||||
return Ok(http_client::Response::builder()
|
||||
.status(401)
|
||||
.body("Unauthorized".into())
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
Ok(http_client::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
serde_json::to_string(&make_get_authenticated_user_response(
|
||||
client_user_id as i32,
|
||||
format!("user-{client_user_id}"),
|
||||
))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
_ => old_handler(req).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
client
|
||||
.override_authenticate({
|
||||
let state = Arc::downgrade(&server.state);
|
||||
@@ -146,7 +105,7 @@ impl FakeServer {
|
||||
});
|
||||
|
||||
client
|
||||
.connect(false, &cx.to_async())
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -264,54 +223,3 @@ impl Drop for FakeServer {
|
||||
self.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
|
||||
let mut auth_header = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)?
|
||||
.to_str()
|
||||
.ok()?
|
||||
.split_whitespace();
|
||||
let user_id = auth_header.next()?.parse().ok()?;
|
||||
let access_token = auth_header.next()?;
|
||||
Some(Credentials {
|
||||
user_id,
|
||||
access_token: access_token.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn make_get_authenticated_user_response(
|
||||
user_id: i32,
|
||||
github_login: String,
|
||||
) -> GetAuthenticatedUserResponse {
|
||||
GetAuthenticatedUserResponse {
|
||||
user: AuthenticatedUser {
|
||||
id: user_id,
|
||||
metrics_id: format!("metrics-id-{user_id}"),
|
||||
avatar_url: "".to_string(),
|
||||
github_login,
|
||||
name: None,
|
||||
is_staff: false,
|
||||
accepted_tos_at: None,
|
||||
},
|
||||
feature_flags: vec![],
|
||||
plan: PlanInfo {
|
||||
plan: Plan::ZedPro,
|
||||
subscription_period: None,
|
||||
usage: CurrentUsage {
|
||||
model_requests: UsageData {
|
||||
used: 0,
|
||||
limit: UsageLimit::Limited(500),
|
||||
},
|
||||
edit_predictions: UsageData {
|
||||
used: 250,
|
||||
limit: UsageLimit::Unlimited,
|
||||
},
|
||||
},
|
||||
trial_started_at: None,
|
||||
is_usage_based_billing_enabled: false,
|
||||
is_account_too_young: false,
|
||||
has_overdue_invoices: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
use super::{Client, Status, TypedEnvelope, proto};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_api_client::{GetAuthenticatedUserResponse, PlanInfo};
|
||||
use cloud_llm_client::{
|
||||
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
|
||||
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||
};
|
||||
use collections::{HashMap, HashSet, hash_map::Entry};
|
||||
use derive_more::Deref;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
@@ -21,7 +16,11 @@ use std::{
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
use text::ReplicaId;
|
||||
use util::{ResultExt, TryFutureExt as _};
|
||||
use util::{TryFutureExt as _, maybe};
|
||||
use zed_llm_client::{
|
||||
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
|
||||
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||
};
|
||||
|
||||
pub type UserId = u64;
|
||||
|
||||
@@ -56,7 +55,7 @@ pub struct ParticipantIndex(pub u32);
|
||||
#[derive(Default, Debug)]
|
||||
pub struct User {
|
||||
pub id: UserId,
|
||||
pub github_login: SharedString,
|
||||
pub github_login: String,
|
||||
pub avatar_uri: SharedUri,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
@@ -108,14 +107,19 @@ pub enum ContactRequestStatus {
|
||||
|
||||
pub struct UserStore {
|
||||
users: HashMap<u64, Arc<User>>,
|
||||
by_github_login: HashMap<SharedString, u64>,
|
||||
by_github_login: HashMap<String, u64>,
|
||||
participant_indices: HashMap<u64, ParticipantIndex>,
|
||||
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
|
||||
current_plan: Option<proto::Plan>,
|
||||
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
|
||||
trial_started_at: Option<DateTime<Utc>>,
|
||||
model_request_usage: Option<ModelRequestUsage>,
|
||||
edit_prediction_usage: Option<EditPredictionUsage>,
|
||||
plan_info: Option<PlanInfo>,
|
||||
is_usage_based_billing_enabled: Option<bool>,
|
||||
account_too_young: Option<bool>,
|
||||
has_overdue_invoices: Option<bool>,
|
||||
current_user: watch::Receiver<Option<Arc<User>>>,
|
||||
accepted_tos_at: Option<Option<cloud_api_client::Timestamp>>,
|
||||
accepted_tos_at: Option<Option<DateTime<Utc>>>,
|
||||
contacts: Vec<Arc<Contact>>,
|
||||
incoming_contact_requests: Vec<Arc<User>>,
|
||||
outgoing_contact_requests: Vec<Arc<User>>,
|
||||
@@ -141,7 +145,6 @@ pub enum Event {
|
||||
ShowContacts,
|
||||
ParticipantIndicesChanged,
|
||||
PrivateUserInfoUpdated,
|
||||
PlanUpdated,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@@ -185,9 +188,14 @@ impl UserStore {
|
||||
users: Default::default(),
|
||||
by_github_login: Default::default(),
|
||||
current_user: current_user_rx,
|
||||
plan_info: None,
|
||||
current_plan: None,
|
||||
subscription_period: None,
|
||||
trial_started_at: None,
|
||||
model_request_usage: None,
|
||||
edit_prediction_usage: None,
|
||||
is_usage_based_billing_enabled: None,
|
||||
account_too_young: None,
|
||||
has_overdue_invoices: None,
|
||||
accepted_tos_at: None,
|
||||
contacts: Default::default(),
|
||||
incoming_contact_requests: Default::default(),
|
||||
@@ -217,30 +225,53 @@ impl UserStore {
|
||||
return Ok(());
|
||||
};
|
||||
match status {
|
||||
Status::Authenticated | Status::Connected { .. } => {
|
||||
Status::Connected { .. } => {
|
||||
if let Some(user_id) = client.user_id() {
|
||||
let response = client.cloud_client().get_authenticated_user().await;
|
||||
let mut current_user = None;
|
||||
let fetch_user = if let Ok(fetch_user) =
|
||||
this.update(cx, |this, cx| this.get_user(user_id, cx).log_err())
|
||||
{
|
||||
fetch_user
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
let fetch_private_user_info =
|
||||
client.request(proto::GetPrivateUserInfo {}).log_err();
|
||||
let (user, info) =
|
||||
futures::join!(fetch_user, fetch_private_user_info);
|
||||
|
||||
cx.update(|cx| {
|
||||
if let Some(response) = response.log_err() {
|
||||
let user = Arc::new(User {
|
||||
id: user_id,
|
||||
github_login: response.user.github_login.clone().into(),
|
||||
avatar_uri: response.user.avatar_url.clone().into(),
|
||||
name: response.user.name.clone(),
|
||||
});
|
||||
current_user = Some(user.clone());
|
||||
if let Some(info) = info {
|
||||
let staff =
|
||||
info.staff && !*feature_flags::ZED_DISABLE_STAFF;
|
||||
cx.update_flags(staff, info.flags);
|
||||
client.telemetry.set_authenticated_user_info(
|
||||
Some(info.metrics_id.clone()),
|
||||
staff,
|
||||
);
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.by_github_login
|
||||
.insert(user.github_login.clone(), user_id);
|
||||
this.users.insert(user_id, user);
|
||||
this.update_authenticated_user(response, cx)
|
||||
let accepted_tos_at = {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok()
|
||||
{
|
||||
None
|
||||
} else {
|
||||
info.accepted_tos_at
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
info.accepted_tos_at
|
||||
};
|
||||
|
||||
this.set_current_user_accepted_tos_at(accepted_tos_at);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
})
|
||||
} else {
|
||||
anyhow::Ok(())
|
||||
}
|
||||
})??;
|
||||
current_user_tx.send(current_user).await.ok();
|
||||
|
||||
current_user_tx.send(user).await.ok();
|
||||
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
@@ -321,22 +352,59 @@ impl UserStore {
|
||||
|
||||
async fn handle_update_plan(
|
||||
this: Entity<Self>,
|
||||
_message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
message: TypedEnvelope<proto::UpdateUserPlan>,
|
||||
mut cx: AsyncApp,
|
||||
) -> Result<()> {
|
||||
let client = this
|
||||
.read_with(&cx, |this, _| this.client.upgrade())?
|
||||
.context("client was dropped")?;
|
||||
|
||||
let response = client
|
||||
.cloud_client()
|
||||
.get_authenticated_user()
|
||||
.await
|
||||
.context("failed to fetch authenticated user")?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.update_authenticated_user(response, cx);
|
||||
})
|
||||
this.current_plan = Some(message.payload.plan());
|
||||
this.subscription_period = maybe!({
|
||||
let period = message.payload.subscription_period?;
|
||||
let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?;
|
||||
let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?;
|
||||
|
||||
Some((started_at, ended_at))
|
||||
});
|
||||
this.trial_started_at = message
|
||||
.payload
|
||||
.trial_started_at
|
||||
.and_then(|trial_started_at| DateTime::from_timestamp(trial_started_at as i64, 0));
|
||||
this.is_usage_based_billing_enabled = message.payload.is_usage_based_billing_enabled;
|
||||
this.account_too_young = message.payload.account_too_young;
|
||||
this.has_overdue_invoices = message.payload.has_overdue_invoices;
|
||||
|
||||
if let Some(usage) = message.payload.usage {
|
||||
// limits are always present even though they are wrapped in Option
|
||||
this.model_request_usage = usage
|
||||
.model_requests_usage_limit
|
||||
.and_then(|limit| {
|
||||
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
|
||||
})
|
||||
.map(ModelRequestUsage);
|
||||
this.edit_prediction_usage = usage
|
||||
.edit_predictions_usage_limit
|
||||
.and_then(|limit| {
|
||||
RequestUsage::from_proto(usage.model_requests_usage_amount, limit)
|
||||
})
|
||||
.map(EditPredictionUsage);
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
|
||||
self.model_request_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn update_edit_prediction_usage(
|
||||
&mut self,
|
||||
usage: EditPredictionUsage,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.edit_prediction_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn update_contacts(&mut self, message: UpdateContacts, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
@@ -695,131 +763,59 @@ impl UserStore {
|
||||
self.current_user.borrow().clone()
|
||||
}
|
||||
|
||||
pub fn plan(&self) -> Option<cloud_llm_client::Plan> {
|
||||
pub fn current_plan(&self) -> Option<proto::Plan> {
|
||||
#[cfg(debug_assertions)]
|
||||
if let Ok(plan) = std::env::var("ZED_SIMULATE_PLAN").as_ref() {
|
||||
return match plan.as_str() {
|
||||
"free" => Some(cloud_llm_client::Plan::ZedFree),
|
||||
"trial" => Some(cloud_llm_client::Plan::ZedProTrial),
|
||||
"pro" => Some(cloud_llm_client::Plan::ZedPro),
|
||||
"free" => Some(proto::Plan::Free),
|
||||
"trial" => Some(proto::Plan::ZedProTrial),
|
||||
"pro" => Some(proto::Plan::ZedPro),
|
||||
_ => {
|
||||
panic!("ZED_SIMULATE_PLAN must be one of 'free', 'trial', or 'pro'");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
self.plan_info.as_ref().map(|info| info.plan)
|
||||
self.current_plan
|
||||
}
|
||||
|
||||
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.subscription_period)
|
||||
.map(|subscription_period| {
|
||||
(
|
||||
subscription_period.started_at.0,
|
||||
subscription_period.ended_at.0,
|
||||
)
|
||||
})
|
||||
self.subscription_period
|
||||
}
|
||||
|
||||
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.and_then(|plan| plan.trial_started_at)
|
||||
.map(|trial_started_at| trial_started_at.0)
|
||||
self.trial_started_at
|
||||
}
|
||||
|
||||
/// Returns whether the user's account is too new to use the service.
|
||||
pub fn account_too_young(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_account_too_young)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Returns whether the current user has overdue invoices and usage should be blocked.
|
||||
pub fn has_overdue_invoices(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.has_overdue_invoices)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn is_usage_based_billing_enabled(&self) -> bool {
|
||||
self.plan_info
|
||||
.as_ref()
|
||||
.map(|plan| plan.is_usage_based_billing_enabled)
|
||||
.unwrap_or_default()
|
||||
pub fn usage_based_billing_enabled(&self) -> Option<bool> {
|
||||
self.is_usage_based_billing_enabled
|
||||
}
|
||||
|
||||
pub fn model_request_usage(&self) -> Option<ModelRequestUsage> {
|
||||
self.model_request_usage
|
||||
}
|
||||
|
||||
pub fn update_model_request_usage(&mut self, usage: ModelRequestUsage, cx: &mut Context<Self>) {
|
||||
self.model_request_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn edit_prediction_usage(&self) -> Option<EditPredictionUsage> {
|
||||
self.edit_prediction_usage
|
||||
}
|
||||
|
||||
pub fn update_edit_prediction_usage(
|
||||
&mut self,
|
||||
usage: EditPredictionUsage,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.edit_prediction_usage = Some(usage);
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn update_authenticated_user(
|
||||
&mut self,
|
||||
response: GetAuthenticatedUserResponse,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let staff = response.user.is_staff && !*feature_flags::ZED_DISABLE_STAFF;
|
||||
cx.update_flags(staff, response.feature_flags);
|
||||
if let Some(client) = self.client.upgrade() {
|
||||
client
|
||||
.telemetry
|
||||
.set_authenticated_user_info(Some(response.user.metrics_id.clone()), staff);
|
||||
}
|
||||
|
||||
let accepted_tos_at = {
|
||||
#[cfg(debug_assertions)]
|
||||
if std::env::var("ZED_IGNORE_ACCEPTED_TOS").is_ok() {
|
||||
None
|
||||
} else {
|
||||
response.user.accepted_tos_at
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
response.user.accepted_tos_at
|
||||
};
|
||||
|
||||
self.accepted_tos_at = Some(accepted_tos_at);
|
||||
self.model_request_usage = Some(ModelRequestUsage(RequestUsage {
|
||||
limit: response.plan.usage.model_requests.limit,
|
||||
amount: response.plan.usage.model_requests.used as i32,
|
||||
}));
|
||||
self.edit_prediction_usage = Some(EditPredictionUsage(RequestUsage {
|
||||
limit: response.plan.usage.edit_predictions.limit,
|
||||
amount: response.plan.usage.edit_predictions.used as i32,
|
||||
}));
|
||||
self.plan_info = Some(response.plan);
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
}
|
||||
|
||||
pub fn watch_current_user(&self) -> watch::Receiver<Option<Arc<User>>> {
|
||||
self.current_user.clone()
|
||||
}
|
||||
|
||||
pub fn has_accepted_terms_of_service(&self) -> bool {
|
||||
/// Returns whether the user's account is too new to use the service.
|
||||
pub fn account_too_young(&self) -> bool {
|
||||
self.account_too_young.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Returns whether the current user has overdue invoices and usage should be blocked.
|
||||
pub fn has_overdue_invoices(&self) -> bool {
|
||||
self.has_overdue_invoices.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
|
||||
self.accepted_tos_at
|
||||
.map_or(false, |accepted_tos_at| accepted_tos_at.is_some())
|
||||
.map(|accepted_tos_at| accepted_tos_at.is_some())
|
||||
}
|
||||
|
||||
pub fn accept_terms_of_service(&self, cx: &Context<Self>) -> Task<Result<()>> {
|
||||
@@ -831,18 +827,23 @@ impl UserStore {
|
||||
cx.spawn(async move |this, cx| -> anyhow::Result<()> {
|
||||
let client = client.upgrade().context("client not found")?;
|
||||
let response = client
|
||||
.cloud_client()
|
||||
.accept_terms_of_service()
|
||||
.request(proto::AcceptTermsOfService {})
|
||||
.await
|
||||
.context("error accepting tos")?;
|
||||
this.update(cx, |this, cx| {
|
||||
this.accepted_tos_at = Some(response.user.accepted_tos_at);
|
||||
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at));
|
||||
cx.emit(Event::PrivateUserInfoUpdated);
|
||||
})?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
|
||||
self.accepted_tos_at = Some(
|
||||
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
|
||||
);
|
||||
}
|
||||
|
||||
fn load_users(
|
||||
&self,
|
||||
request: impl RequestMessage<Response = UsersResponse>,
|
||||
@@ -901,7 +902,7 @@ impl UserStore {
|
||||
let mut missing_user_ids = Vec::new();
|
||||
for id in user_ids {
|
||||
if let Some(github_login) = self.get_cached_user(id).map(|u| u.github_login.clone()) {
|
||||
ret.insert(id, github_login);
|
||||
ret.insert(id, github_login.into());
|
||||
} else {
|
||||
missing_user_ids.push(id)
|
||||
}
|
||||
@@ -922,7 +923,7 @@ impl User {
|
||||
fn new(message: proto::User) -> Arc<Self> {
|
||||
Arc::new(User {
|
||||
id: message.id,
|
||||
github_login: message.github_login.into(),
|
||||
github_login: message.github_login,
|
||||
avatar_uri: message.avatar_url.into(),
|
||||
name: message.name,
|
||||
})
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_api_client"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "Apache-2.0"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_api_client.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde_json.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1,155 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
pub use cloud_api_types::*;
|
||||
use futures::AsyncReadExt as _;
|
||||
use http_client::http::request;
|
||||
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
struct Credentials {
|
||||
user_id: u32,
|
||||
access_token: String,
|
||||
}
|
||||
|
||||
pub struct CloudApiClient {
|
||||
credentials: RwLock<Option<Credentials>>,
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
}
|
||||
|
||||
impl CloudApiClient {
|
||||
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
|
||||
Self {
|
||||
credentials: RwLock::new(None),
|
||||
http_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_credentials(&self) -> bool {
|
||||
self.credentials.read().is_some()
|
||||
}
|
||||
|
||||
pub fn set_credentials(&self, user_id: u32, access_token: String) {
|
||||
*self.credentials.write() = Some(Credentials {
|
||||
user_id,
|
||||
access_token,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn clear_credentials(&self) {
|
||||
*self.credentials.write() = None;
|
||||
}
|
||||
|
||||
fn authorization_header(&self) -> Result<String> {
|
||||
let guard = self.credentials.read();
|
||||
let credentials = guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("No credentials provided"))?;
|
||||
|
||||
Ok(format!(
|
||||
"{} {}",
|
||||
credentials.user_id, credentials.access_token
|
||||
))
|
||||
}
|
||||
|
||||
fn build_request(
|
||||
&self,
|
||||
req: request::Builder,
|
||||
body: impl Into<AsyncBody>,
|
||||
) -> Result<Request<AsyncBody>> {
|
||||
Ok(req
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", self.authorization_header()?)
|
||||
.body(body.into())?)
|
||||
}
|
||||
|
||||
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
|
||||
let request = self.build_request(
|
||||
Request::builder().method(Method::GET).uri(
|
||||
self.http_client
|
||||
.build_zed_cloud_url("/client/users/me", &[])?
|
||||
.as_ref(),
|
||||
),
|
||||
AsyncBody::default(),
|
||||
)?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
|
||||
response.status()
|
||||
)
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
|
||||
pub async fn accept_terms_of_service(&self) -> Result<AcceptTermsOfServiceResponse> {
|
||||
let request = self.build_request(
|
||||
Request::builder().method(Method::POST).uri(
|
||||
self.http_client
|
||||
.build_zed_cloud_url("/client/terms_of_service/accept", &[])?
|
||||
.as_ref(),
|
||||
),
|
||||
AsyncBody::default(),
|
||||
)?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to accept terms of service.\nStatus: {:?}\nBody: {body}",
|
||||
response.status()
|
||||
)
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
|
||||
pub async fn create_llm_token(
|
||||
&self,
|
||||
system_id: Option<String>,
|
||||
) -> Result<CreateLlmTokenResponse> {
|
||||
let mut request_builder = Request::builder().method(Method::POST).uri(
|
||||
self.http_client
|
||||
.build_zed_cloud_url("/client/llm_tokens", &[])?
|
||||
.as_ref(),
|
||||
);
|
||||
|
||||
if let Some(system_id) = system_id {
|
||||
request_builder = request_builder.header(ZED_SYSTEM_ID_HEADER_NAME, system_id);
|
||||
}
|
||||
|
||||
let request = self.build_request(request_builder, AsyncBody::default())?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to create LLM token.\nStatus: {:?}\nBody: {body}",
|
||||
response.status()
|
||||
)
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_api_types"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "Apache-2.0"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_api_types.rs"
|
||||
|
||||
[dependencies]
|
||||
chrono.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
serde.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions.workspace = true
|
||||
serde_json.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1,55 +0,0 @@
|
||||
mod timestamp;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use crate::timestamp::Timestamp;
|
||||
|
||||
pub const ZED_SYSTEM_ID_HEADER_NAME: &str = "x-zed-system-id";
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct GetAuthenticatedUserResponse {
|
||||
pub user: AuthenticatedUser,
|
||||
pub feature_flags: Vec<String>,
|
||||
pub plan: PlanInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AuthenticatedUser {
|
||||
pub id: i32,
|
||||
pub metrics_id: String,
|
||||
pub avatar_url: String,
|
||||
pub github_login: String,
|
||||
pub name: Option<String>,
|
||||
pub is_staff: bool,
|
||||
pub accepted_tos_at: Option<Timestamp>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct PlanInfo {
|
||||
pub plan: cloud_llm_client::Plan,
|
||||
pub subscription_period: Option<SubscriptionPeriod>,
|
||||
pub usage: cloud_llm_client::CurrentUsage,
|
||||
pub trial_started_at: Option<Timestamp>,
|
||||
pub is_usage_based_billing_enabled: bool,
|
||||
pub is_account_too_young: bool,
|
||||
pub has_overdue_invoices: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct SubscriptionPeriod {
|
||||
pub started_at: Timestamp,
|
||||
pub ended_at: Timestamp,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AcceptTermsOfServiceResponse {
|
||||
pub user: AuthenticatedUser,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToken(pub String);
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateLlmTokenResponse {
|
||||
pub token: LlmToken,
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
use chrono::{DateTime, NaiveDateTime, SecondsFormat, Utc};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
/// A timestamp with a serialized representation in RFC 3339 format.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
||||
pub struct Timestamp(pub DateTime<Utc>);
|
||||
|
||||
impl Timestamp {
|
||||
pub fn new(datetime: DateTime<Utc>) -> Self {
|
||||
Self(datetime)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DateTime<Utc>> for Timestamp {
|
||||
fn from(value: DateTime<Utc>) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NaiveDateTime> for Timestamp {
|
||||
fn from(value: NaiveDateTime) -> Self {
|
||||
Self(value.and_utc())
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Timestamp {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let rfc3339_string = self.0.to_rfc3339_opts(SecondsFormat::Millis, true);
|
||||
serializer.serialize_str(&rfc3339_string)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Timestamp {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = String::deserialize(deserializer)?;
|
||||
let datetime = DateTime::parse_from_rfc3339(&value)
|
||||
.map_err(serde::de::Error::custom)?
|
||||
.to_utc();
|
||||
Ok(Self(datetime))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::NaiveDate;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_serialization() {
|
||||
let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
let timestamp = Timestamp::new(datetime);
|
||||
|
||||
let json = serde_json::to_string(×tamp).unwrap();
|
||||
assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization() {
|
||||
let json = "\"2023-12-25T14:30:45.123Z\"";
|
||||
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
|
||||
|
||||
let expected = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_roundtrip() {
|
||||
let original = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
|
||||
let timestamp = Timestamp::new(original);
|
||||
let json = serde_json::to_string(×tamp).unwrap();
|
||||
let deserialized: Timestamp = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.0, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_from_datetime_utc() {
|
||||
let datetime = DateTime::parse_from_rfc3339("2023-12-25T14:30:45.123Z")
|
||||
.unwrap()
|
||||
.to_utc();
|
||||
|
||||
let timestamp = Timestamp::from(datetime);
|
||||
assert_eq!(timestamp.0, datetime);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_from_naive_datetime() {
|
||||
let naive_dt = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_milli_opt(14, 30, 45, 123)
|
||||
.unwrap();
|
||||
|
||||
let timestamp = Timestamp::from(naive_dt);
|
||||
let expected = naive_dt.and_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_serialization_with_microseconds() {
|
||||
// Test that microseconds are truncated to milliseconds
|
||||
let datetime = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_micro_opt(14, 30, 45, 123456)
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
|
||||
let timestamp = Timestamp::new(datetime);
|
||||
let json = serde_json::to_string(×tamp).unwrap();
|
||||
|
||||
// Should be truncated to milliseconds
|
||||
assert_eq!(json, "\"2023-12-25T14:30:45.123Z\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization_without_milliseconds() {
|
||||
let json = "\"2023-12-25T14:30:45Z\"";
|
||||
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
|
||||
|
||||
let expected = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_opt(14, 30, 45)
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization_with_timezone() {
|
||||
let json = "\"2023-12-25T14:30:45.123+05:30\"";
|
||||
let timestamp: Timestamp = serde_json::from_str(json).unwrap();
|
||||
|
||||
// Should be converted to UTC
|
||||
let expected = NaiveDate::from_ymd_opt(2023, 12, 25)
|
||||
.unwrap()
|
||||
.and_hms_milli_opt(9, 0, 45, 123) // 14:30:45 + 5:30 = 20:00:45, but we want UTC so subtract 5:30
|
||||
.unwrap()
|
||||
.and_utc();
|
||||
|
||||
assert_eq!(timestamp.0, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_deserialization_with_invalid_format() {
|
||||
let json = "\"invalid-date\"";
|
||||
let result: Result<Timestamp, _> = serde_json::from_str(json);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_llm_client"
|
||||
version = "0.1.0"
|
||||
publish.workspace = true
|
||||
edition.workspace = true
|
||||
license = "Apache-2.0"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_llm_client.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
serde = { workspace = true, features = ["derive", "rc"] }
|
||||
serde_json.workspace = true
|
||||
strum = { workspace = true, features = ["derive"] }
|
||||
uuid = { workspace = true, features = ["serde"] }
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1,370 +0,0 @@
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::{Display, EnumIter, EnumString};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// The name of the header used to indicate which version of Zed the client is running.
|
||||
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
|
||||
|
||||
/// The name of the header used to indicate when a request failed due to an
|
||||
/// expired LLM token.
|
||||
///
|
||||
/// The client may use this as a signal to refresh the token.
|
||||
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
||||
|
||||
/// The name of the header used to indicate what plan the user is currently on.
|
||||
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
|
||||
|
||||
/// The name of the header used to indicate the usage limit for model requests.
|
||||
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
|
||||
|
||||
/// The name of the header used to indicate the usage amount for model requests.
|
||||
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
|
||||
|
||||
/// The name of the header used to indicate the usage limit for edit predictions.
|
||||
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
|
||||
|
||||
/// The name of the header used to indicate the usage amount for edit predictions.
|
||||
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
|
||||
|
||||
/// The name of the header used to indicate the resource for which the subscription limit has been reached.
|
||||
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
|
||||
|
||||
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
|
||||
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
|
||||
|
||||
/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
|
||||
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
|
||||
|
||||
/// The name of the header used to indicate the the minimum required Zed version.
|
||||
///
|
||||
/// This can be used to force a Zed upgrade in order to continue communicating
|
||||
/// with the LLM service.
|
||||
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
|
||||
|
||||
/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
|
||||
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
|
||||
"x-zed-client-supports-status-messages";
|
||||
|
||||
/// The name of the header used by the server to indicate to the client that it supports sending status messages.
|
||||
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
|
||||
"x-zed-server-supports-status-messages";
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum UsageLimit {
|
||||
Limited(i32),
|
||||
Unlimited,
|
||||
}
|
||||
|
||||
impl FromStr for UsageLimit {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value {
|
||||
"unlimited" => Ok(Self::Unlimited),
|
||||
limit => limit
|
||||
.parse::<i32>()
|
||||
.map(Self::Limited)
|
||||
.context("failed to parse limit"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Plan {
|
||||
#[default]
|
||||
#[serde(alias = "Free")]
|
||||
ZedFree,
|
||||
#[serde(alias = "ZedPro")]
|
||||
ZedPro,
|
||||
#[serde(alias = "ZedProTrial")]
|
||||
ZedProTrial,
|
||||
}
|
||||
|
||||
impl Plan {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Plan::ZedFree => "zed_free",
|
||||
Plan::ZedPro => "zed_pro",
|
||||
Plan::ZedProTrial => "zed_pro_trial",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn model_requests_limit(&self) -> UsageLimit {
|
||||
match self {
|
||||
Plan::ZedPro => UsageLimit::Limited(500),
|
||||
Plan::ZedProTrial => UsageLimit::Limited(150),
|
||||
Plan::ZedFree => UsageLimit::Limited(50),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edit_predictions_limit(&self) -> UsageLimit {
|
||||
match self {
|
||||
Plan::ZedPro => UsageLimit::Unlimited,
|
||||
Plan::ZedProTrial => UsageLimit::Unlimited,
|
||||
Plan::ZedFree => UsageLimit::Limited(2_000),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Plan {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value {
|
||||
"zed_free" => Ok(Plan::ZedFree),
|
||||
"zed_pro" => Ok(Plan::ZedPro),
|
||||
"zed_pro_trial" => Ok(Plan::ZedProTrial),
|
||||
plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum LanguageModelProvider {
|
||||
Anthropic,
|
||||
OpenAi,
|
||||
Google,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictEditsBody {
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub outline: Option<String>,
|
||||
pub input_events: String,
|
||||
pub input_excerpt: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub speculated_output: Option<String>,
|
||||
/// Whether the user provided consent for sampling this interaction.
|
||||
#[serde(default, alias = "data_collection_permission")]
|
||||
pub can_collect_data: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictEditsResponse {
|
||||
pub request_id: Uuid,
|
||||
pub output_excerpt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AcceptEditPredictionBody {
|
||||
pub request_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionMode {
|
||||
Normal,
|
||||
Max,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionIntent {
|
||||
UserPrompt,
|
||||
ToolResults,
|
||||
ThreadSummarization,
|
||||
ThreadContextSummarization,
|
||||
CreateFile,
|
||||
EditFile,
|
||||
InlineAssist,
|
||||
TerminalInlineAssist,
|
||||
GenerateGitCommitMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CompletionBody {
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub thread_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub prompt_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub intent: Option<CompletionIntent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub mode: Option<CompletionMode>,
|
||||
pub provider: LanguageModelProvider,
|
||||
pub model: String,
|
||||
pub provider_request: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionRequestStatus {
|
||||
Queued {
|
||||
position: usize,
|
||||
},
|
||||
Started,
|
||||
Failed {
|
||||
code: String,
|
||||
message: String,
|
||||
request_id: Uuid,
|
||||
/// Retry duration in seconds.
|
||||
retry_after: Option<f64>,
|
||||
},
|
||||
UsageUpdated {
|
||||
amount: usize,
|
||||
limit: UsageLimit,
|
||||
},
|
||||
ToolUseLimitReached,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionEvent<T> {
|
||||
Status(CompletionRequestStatus),
|
||||
Event(T),
|
||||
}
|
||||
|
||||
impl<T> CompletionEvent<T> {
|
||||
pub fn into_status(self) -> Option<CompletionRequestStatus> {
|
||||
match self {
|
||||
Self::Status(status) => Some(status),
|
||||
Self::Event(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_event(self) -> Option<T> {
|
||||
match self {
|
||||
Self::Event(event) => Some(event),
|
||||
Self::Status(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct WebSearchBody {
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct WebSearchResponse {
|
||||
pub results: Vec<WebSearchResult>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct WebSearchResult {
|
||||
pub title: String,
|
||||
pub url: String,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CountTokensBody {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub model: String,
|
||||
pub provider_request: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CountTokensResponse {
|
||||
pub tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelId(pub Arc<str>);
|
||||
|
||||
impl std::fmt::Display for LanguageModelId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct LanguageModel {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub id: LanguageModelId,
|
||||
pub display_name: String,
|
||||
pub max_token_count: usize,
|
||||
pub max_token_count_in_max_mode: Option<usize>,
|
||||
pub max_output_tokens: usize,
|
||||
pub supports_tools: bool,
|
||||
pub supports_images: bool,
|
||||
pub supports_thinking: bool,
|
||||
pub supports_max_mode: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ListModelsResponse {
|
||||
pub models: Vec<LanguageModel>,
|
||||
pub default_model: LanguageModelId,
|
||||
pub default_fast_model: LanguageModelId,
|
||||
pub recommended_models: Vec<LanguageModelId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GetSubscriptionResponse {
|
||||
pub plan: Plan,
|
||||
pub usage: Option<CurrentUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CurrentUsage {
|
||||
pub model_requests: UsageData,
|
||||
pub edit_predictions: UsageData,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct UsageData {
|
||||
pub used: u32,
|
||||
pub limit: UsageLimit,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_plan_deserialize_snake_case() {
|
||||
let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedFree);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedPro);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedProTrial);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_deserialize_aliases() {
|
||||
let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedFree);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedPro);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedProTrial);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_usage_limit_from_str() {
|
||||
let limit = UsageLimit::from_str("unlimited").unwrap();
|
||||
assert!(matches!(limit, UsageLimit::Unlimited));
|
||||
|
||||
let limit = UsageLimit::from_str(&0.to_string()).unwrap();
|
||||
assert!(matches!(limit, UsageLimit::Limited(0)));
|
||||
|
||||
let limit = UsageLimit::from_str(&50.to_string()).unwrap();
|
||||
assert!(matches!(limit, UsageLimit::Limited(50)));
|
||||
|
||||
for value in ["not_a_number", "50xyz"] {
|
||||
let limit = UsageLimit::from_str(value);
|
||||
assert!(limit.is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,14 +23,13 @@ async-stripe.workspace = true
|
||||
async-trait.workspace = true
|
||||
async-tungstenite.workspace = true
|
||||
aws-config = { version = "1.1.5" }
|
||||
aws-sdk-kinesis = "1.51.0"
|
||||
aws-sdk-s3 = { version = "1.15.0" }
|
||||
aws-sdk-kinesis = "1.51.0"
|
||||
axum = { version = "0.6", features = ["json", "headers", "ws"] }
|
||||
axum-extra = { version = "0.4", features = ["erased-json"] }
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
@@ -76,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
agent_settings.workspace = true
|
||||
|
||||
@@ -100,6 +100,7 @@ impl std::fmt::Display for SystemIdHeader {
|
||||
|
||||
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/user", get(update_or_create_authenticated_user))
|
||||
.route("/users/look_up", get(look_up_user))
|
||||
.route("/users/:id/access_tokens", post(create_access_token))
|
||||
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
|
||||
@@ -144,6 +145,48 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
|
||||
Ok::<_, Error>(next.run(req).await)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AuthenticatedUserParams {
|
||||
github_user_id: i32,
|
||||
github_login: String,
|
||||
github_email: Option<String>,
|
||||
github_name: Option<String>,
|
||||
github_user_created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AuthenticatedUserResponse {
|
||||
user: User,
|
||||
metrics_id: String,
|
||||
feature_flags: Vec<String>,
|
||||
}
|
||||
|
||||
async fn update_or_create_authenticated_user(
|
||||
Query(params): Query<AuthenticatedUserParams>,
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
) -> Result<Json<AuthenticatedUserResponse>> {
|
||||
let initial_channel_id = app.config.auto_join_channel_id;
|
||||
|
||||
let user = app
|
||||
.db
|
||||
.update_or_create_user_by_github_account(
|
||||
¶ms.github_login,
|
||||
params.github_user_id,
|
||||
params.github_email.as_deref(),
|
||||
params.github_name.as_deref(),
|
||||
params.github_user_created_at,
|
||||
initial_channel_id,
|
||||
)
|
||||
.await?;
|
||||
let metrics_id = app.db.get_user_metrics_id(user.id).await?;
|
||||
let feature_flags = app.db.get_user_flags(user.id).await?;
|
||||
Ok(Json(AuthenticatedUserResponse {
|
||||
user,
|
||||
metrics_id,
|
||||
feature_flags,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LookUpUserParams {
|
||||
identifier: String,
|
||||
@@ -310,9 +353,9 @@ async fn refresh_llm_tokens(
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UpdatePlanBody {
|
||||
pub plan: cloud_llm_client::Plan,
|
||||
pub plan: zed_llm_client::Plan,
|
||||
pub subscription_period: SubscriptionPeriod,
|
||||
pub usage: cloud_llm_client::CurrentUsage,
|
||||
pub usage: zed_llm_client::CurrentUsage,
|
||||
pub trial_started_at: Option<DateTime<Utc>>,
|
||||
pub is_usage_based_billing_enabled: bool,
|
||||
pub is_account_too_young: bool,
|
||||
@@ -334,9 +377,9 @@ async fn update_plan(
|
||||
extract::Json(body): extract::Json<UpdatePlanBody>,
|
||||
) -> Result<Json<UpdatePlanResponse>> {
|
||||
let plan = match body.plan {
|
||||
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
|
||||
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
|
||||
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
zed_llm_client::Plan::ZedFree => proto::Plan::Free,
|
||||
zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
|
||||
zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
let update_user_plan = proto::UpdateUserPlan {
|
||||
@@ -368,15 +411,15 @@ async fn update_plan(
|
||||
Ok(Json(UpdatePlanResponse {}))
|
||||
}
|
||||
|
||||
fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
|
||||
fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
|
||||
proto::UsageLimit {
|
||||
variant: Some(match limit {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use anyhow::{Context as _, bail};
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use collections::{HashMap, HashSet};
|
||||
use sea_orm::ActiveValue;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
|
||||
use util::{ResultExt, maybe};
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::db::billing_subscription::{
|
||||
@@ -87,14 +87,6 @@ async fn poll_stripe_events(
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
real_stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
let sync_events_using_cloud = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
|
||||
if sync_events_using_cloud {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn event_type_to_string(event_type: EventType) -> String {
|
||||
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
|
||||
// so we need to unquote it.
|
||||
@@ -577,14 +569,6 @@ async fn sync_model_request_usage_with_stripe(
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
let sync_model_request_usage_using_cloud = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all);
|
||||
if sync_model_request_usage_using_cloud {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::info!("Stripe usage sync: Starting");
|
||||
let started_at = Utc::now();
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use axum::{
|
||||
use chrono::{NaiveDateTime, SecondsFormat};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::AuthenticatedUserParams;
|
||||
use crate::db::ContributorSelector;
|
||||
use crate::{AppState, Result};
|
||||
|
||||
@@ -103,18 +104,9 @@ impl RenovateBot {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AddContributorBody {
|
||||
github_user_id: i32,
|
||||
github_login: String,
|
||||
github_email: Option<String>,
|
||||
github_name: Option<String>,
|
||||
github_user_created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
async fn add_contributor(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(params): extract::Json<AddContributorBody>,
|
||||
extract::Json(params): extract::Json<AuthenticatedUserParams>,
|
||||
) -> Result<()> {
|
||||
let initial_channel_id = app.config.auto_join_channel_id;
|
||||
app.db
|
||||
|
||||
@@ -95,7 +95,7 @@ pub enum SubscriptionKind {
|
||||
ZedFree,
|
||||
}
|
||||
|
||||
impl From<SubscriptionKind> for cloud_llm_client::Plan {
|
||||
impl From<SubscriptionKind> for zed_llm_client::Plan {
|
||||
fn from(value: SubscriptionKind) -> Self {
|
||||
match value {
|
||||
SubscriptionKind::ZedPro => Self::ZedPro,
|
||||
|
||||
@@ -6,11 +6,11 @@ mod tables;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use collections::HashMap;
|
||||
pub use ids::*;
|
||||
pub use seed::*;
|
||||
pub use tables::*;
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use tests::TestLlmDb;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use pretty_assertions::assert_eq;
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::test_llm_db;
|
||||
|
||||
@@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA
|
||||
use crate::{Config, db::billing_preference};
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use cloud_llm_client::Plan;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::Plan;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
||||
@@ -23,7 +23,6 @@ use anyhow::{Context as _, anyhow, bail};
|
||||
use async_tungstenite::tungstenite::{
|
||||
Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
|
||||
};
|
||||
use axum::headers::UserAgent;
|
||||
use axum::{
|
||||
Extension, Router, TypedHeader,
|
||||
body::Body,
|
||||
@@ -42,7 +41,7 @@ use collections::{HashMap, HashSet};
|
||||
pub use connection_pool::{ConnectionPool, ZedVersion};
|
||||
use core::fmt::{self, Debug, Formatter};
|
||||
use reqwest_client::ReqwestClient;
|
||||
use rpc::proto::{MultiLspQuery, split_repository_update};
|
||||
use rpc::proto::split_repository_update;
|
||||
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
|
||||
|
||||
use futures::{
|
||||
@@ -374,7 +373,7 @@ impl Server {
|
||||
.add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
|
||||
.add_request_handler(multi_lsp_query)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
|
||||
.add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
|
||||
@@ -751,7 +750,6 @@ impl Server {
|
||||
address: String,
|
||||
principal: Principal,
|
||||
zed_version: ZedVersion,
|
||||
user_agent: Option<String>,
|
||||
geoip_country_code: Option<String>,
|
||||
system_id: Option<String>,
|
||||
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
|
||||
@@ -764,14 +762,9 @@ impl Server {
|
||||
user_id=field::Empty,
|
||||
login=field::Empty,
|
||||
impersonator=field::Empty,
|
||||
user_agent=field::Empty,
|
||||
geoip_country_code=field::Empty
|
||||
);
|
||||
principal.update_span(&span);
|
||||
if let Some(user_agent) = user_agent {
|
||||
span.record("user_agent", user_agent);
|
||||
}
|
||||
|
||||
if let Some(country_code) = geoip_country_code.as_ref() {
|
||||
span.record("geoip_country_code", country_code);
|
||||
}
|
||||
@@ -838,7 +831,7 @@ impl Server {
|
||||
// This arrangement ensures we will attempt to process earlier messages first, but fall
|
||||
// back to processing messages arrived later in the spirit of making progress.
|
||||
let mut foreground_message_handlers = FuturesUnordered::new();
|
||||
let concurrent_handlers = Arc::new(Semaphore::new(256));
|
||||
let concurrent_handlers = Arc::new(Semaphore::new(512));
|
||||
loop {
|
||||
let next_message = async {
|
||||
let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
|
||||
@@ -865,7 +858,6 @@ impl Server {
|
||||
user_id=field::Empty,
|
||||
login=field::Empty,
|
||||
impersonator=field::Empty,
|
||||
multi_lsp_query_request=field::Empty,
|
||||
);
|
||||
principal.update_span(&span);
|
||||
let span_enter = span.enter();
|
||||
@@ -1180,7 +1172,6 @@ pub async fn handle_websocket_request(
|
||||
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
|
||||
Extension(server): Extension<Arc<Server>>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
user_agent: Option<TypedHeader<UserAgent>>,
|
||||
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
|
||||
system_id_header: Option<TypedHeader<SystemIdHeader>>,
|
||||
ws: WebSocketUpgrade,
|
||||
@@ -1236,7 +1227,6 @@ pub async fn handle_websocket_request(
|
||||
socket_address,
|
||||
principal,
|
||||
version,
|
||||
user_agent.map(|header| header.to_string()),
|
||||
country_code_header.map(|header| header.to_string()),
|
||||
system_id_header.map(|header| header.to_string()),
|
||||
None,
|
||||
@@ -2330,15 +2320,6 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn multi_lsp_query(
|
||||
request: MultiLspQuery,
|
||||
response: Response<MultiLspQuery>,
|
||||
session: Session,
|
||||
) -> Result<()> {
|
||||
tracing::Span::current().record("multi_lsp_query_request", request.request_str());
|
||||
forward_mutating_project_request(request, response, session).await
|
||||
}
|
||||
|
||||
/// Notify other participants that a new buffer has been created
|
||||
async fn create_buffer_for_peer(
|
||||
request: proto::CreateBufferForPeer,
|
||||
@@ -2878,12 +2859,12 @@ async fn make_update_user_plan_message(
|
||||
}
|
||||
|
||||
fn model_requests_limit(
|
||||
plan: cloud_llm_client::Plan,
|
||||
plan: zed_llm_client::Plan,
|
||||
feature_flags: &Vec<String>,
|
||||
) -> cloud_llm_client::UsageLimit {
|
||||
) -> zed_llm_client::UsageLimit {
|
||||
match plan.model_requests_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == cloud_llm_client::Plan::ZedProTrial
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == zed_llm_client::Plan::ZedProTrial
|
||||
&& feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
|
||||
@@ -2893,9 +2874,9 @@ fn model_requests_limit(
|
||||
limit
|
||||
};
|
||||
|
||||
cloud_llm_client::UsageLimit::Limited(limit)
|
||||
zed_llm_client::UsageLimit::Limited(limit)
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
|
||||
zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2905,21 +2886,21 @@ fn subscription_usage_to_proto(
|
||||
feature_flags: &Vec<String>,
|
||||
) -> proto::SubscriptionUsage {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: usage.model_requests as u32,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit(plan, feature_flags) {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
@@ -2927,12 +2908,12 @@ fn subscription_usage_to_proto(
|
||||
edit_predictions_usage_amount: usage.edit_predictions as u32,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
@@ -2945,21 +2926,21 @@ fn make_default_subscription_usage(
|
||||
feature_flags: &Vec<String>,
|
||||
) -> proto::SubscriptionUsage {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: 0,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit(plan, feature_flags) {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
@@ -2967,12 +2948,12 @@ fn make_default_subscription_usage(
|
||||
edit_predictions_usage_amount: 0,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -38,12 +38,12 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
|
||||
let mut remote = room
|
||||
.remote_participants()
|
||||
.values()
|
||||
.map(|participant| participant.user.github_login.clone().to_string())
|
||||
.map(|participant| participant.user.github_login.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let mut pending = room
|
||||
.pending_participants()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone().to_string())
|
||||
.map(|user| user.github_login.clone())
|
||||
.collect::<Vec<_>>();
|
||||
remote.sort();
|
||||
pending.sort();
|
||||
|
||||
@@ -842,7 +842,7 @@ async fn test_client_disconnecting_from_room(
|
||||
|
||||
// Allow user A to reconnect to the server.
|
||||
server.allow_connections();
|
||||
executor.advance_clock(RECONNECT_TIMEOUT);
|
||||
executor.advance_clock(RECEIVE_TIMEOUT);
|
||||
|
||||
// Call user B again from client A.
|
||||
active_call_a
|
||||
@@ -1286,7 +1286,7 @@ async fn test_calls_on_multiple_connections(
|
||||
client_b1.disconnect(&cx_b1.to_async());
|
||||
executor.advance_clock(RECEIVE_TIMEOUT);
|
||||
client_b1
|
||||
.connect(false, &cx_b1.to_async())
|
||||
.authenticate_and_connect(false, &cx_b1.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -1358,7 +1358,7 @@ async fn test_calls_on_multiple_connections(
|
||||
|
||||
// User A reconnects automatically, then calls user B again.
|
||||
server.allow_connections();
|
||||
executor.advance_clock(RECONNECT_TIMEOUT);
|
||||
executor.advance_clock(RECEIVE_TIMEOUT);
|
||||
active_call_a
|
||||
.update(cx_a, |call, cx| {
|
||||
call.invite(client_b1.user_id().unwrap(), None, cx)
|
||||
@@ -1667,7 +1667,7 @@ async fn test_project_reconnect(
|
||||
// Client A reconnects. Their project is re-shared, and client B re-joins it.
|
||||
server.allow_connections();
|
||||
client_a
|
||||
.connect(false, &cx_a.to_async())
|
||||
.authenticate_and_connect(false, &cx_a.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -1796,7 +1796,7 @@ async fn test_project_reconnect(
|
||||
// Client B reconnects. They re-join the room and the remaining shared project.
|
||||
server.allow_connections();
|
||||
client_b
|
||||
.connect(false, &cx_b.to_async())
|
||||
.authenticate_and_connect(false, &cx_b.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -1881,7 +1881,7 @@ async fn test_active_call_events(
|
||||
vec![room::Event::RemoteProjectShared {
|
||||
owner: Arc::new(User {
|
||||
id: client_a.user_id().unwrap(),
|
||||
github_login: "user_a".into(),
|
||||
github_login: "user_a".to_string(),
|
||||
avatar_uri: "avatar_a".into(),
|
||||
name: None,
|
||||
}),
|
||||
@@ -1900,7 +1900,7 @@ async fn test_active_call_events(
|
||||
vec![room::Event::RemoteProjectShared {
|
||||
owner: Arc::new(User {
|
||||
id: client_b.user_id().unwrap(),
|
||||
github_login: "user_b".into(),
|
||||
github_login: "user_b".to_string(),
|
||||
avatar_uri: "avatar_b".into(),
|
||||
name: None,
|
||||
}),
|
||||
@@ -5738,7 +5738,7 @@ async fn test_contacts(
|
||||
|
||||
server.allow_connections();
|
||||
client_c
|
||||
.connect(false, &cx_c.to_async())
|
||||
.authenticate_and_connect(false, &cx_c.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -6079,7 +6079,7 @@ async fn test_contacts(
|
||||
.iter()
|
||||
.map(|contact| {
|
||||
(
|
||||
contact.user.github_login.clone().to_string(),
|
||||
contact.user.github_login.clone(),
|
||||
if contact.online { "online" } else { "offline" },
|
||||
if contact.busy { "busy" } else { "free" },
|
||||
)
|
||||
@@ -6269,7 +6269,7 @@ async fn test_contact_requests(
|
||||
client.disconnect(&cx.to_async());
|
||||
client.clear_contacts(cx).await;
|
||||
client
|
||||
.connect(false, &cx.to_async())
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
|
||||
@@ -3,7 +3,6 @@ use std::sync::Arc;
|
||||
use gpui::{BackgroundExecutor, TestAppContext};
|
||||
use notifications::NotificationEvent;
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rpc::{Notification, proto};
|
||||
|
||||
use crate::tests::TestServer;
|
||||
@@ -18,9 +17,6 @@ async fn test_notifications(
|
||||
let client_a = server.create_client(cx_a, "user_a").await;
|
||||
let client_b = server.create_client(cx_b, "user_b").await;
|
||||
|
||||
// Wait for authentication/connection to Collab to be established.
|
||||
executor.run_until_parked();
|
||||
|
||||
let notification_events_a = Arc::new(Mutex::new(Vec::new()));
|
||||
let notification_events_b = Arc::new(Mutex::new(Vec::new()));
|
||||
client_a.notification_store().update(cx_a, |_, cx| {
|
||||
|
||||
@@ -8,7 +8,6 @@ use crate::{
|
||||
use anyhow::anyhow;
|
||||
use call::ActiveCall;
|
||||
use channel::{ChannelBuffer, ChannelStore};
|
||||
use client::test::{make_get_authenticated_user_response, parse_authorization_header};
|
||||
use client::{
|
||||
self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
|
||||
proto::PeerId,
|
||||
@@ -21,7 +20,7 @@ use fs::FakeFs;
|
||||
use futures::{StreamExt as _, channel::oneshot};
|
||||
use git::GitHostingProviderRegistry;
|
||||
use gpui::{AppContext as _, BackgroundExecutor, Entity, Task, TestAppContext, VisualTestContext};
|
||||
use http_client::{FakeHttpClient, Method};
|
||||
use http_client::FakeHttpClient;
|
||||
use language::LanguageRegistry;
|
||||
use node_runtime::NodeRuntime;
|
||||
use notifications::NotificationStore;
|
||||
@@ -162,8 +161,6 @@ impl TestServer {
|
||||
}
|
||||
|
||||
pub async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> TestClient {
|
||||
const ACCESS_TOKEN: &str = "the-token";
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
|
||||
cx.update(|cx| {
|
||||
@@ -178,7 +175,7 @@ impl TestServer {
|
||||
});
|
||||
|
||||
let clock = Arc::new(FakeSystemClock::new());
|
||||
|
||||
let http = FakeHttpClient::with_404_response();
|
||||
let user_id = if let Ok(Some(user)) = self.app_state.db.get_user_by_github_login(name).await
|
||||
{
|
||||
user.id
|
||||
@@ -200,47 +197,6 @@ impl TestServer {
|
||||
.expect("creating user failed")
|
||||
.user_id
|
||||
};
|
||||
|
||||
let http = FakeHttpClient::create({
|
||||
let name = name.to_string();
|
||||
move |req| {
|
||||
let name = name.clone();
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::GET, "/client/users/me") => {
|
||||
let credentials = parse_authorization_header(&req);
|
||||
if credentials
|
||||
!= Some(Credentials {
|
||||
user_id: user_id.to_proto(),
|
||||
access_token: ACCESS_TOKEN.into(),
|
||||
})
|
||||
{
|
||||
return Ok(http_client::Response::builder()
|
||||
.status(401)
|
||||
.body("Unauthorized".into())
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
Ok(http_client::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
serde_json::to_string(&make_get_authenticated_user_response(
|
||||
user_id.0, name,
|
||||
))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
_ => Ok(http_client::Response::builder()
|
||||
.status(404)
|
||||
.body("Not Found".into())
|
||||
.unwrap()),
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let client_name = name.to_string();
|
||||
let mut client = cx.update(|cx| Client::new(clock, http.clone(), cx));
|
||||
let server = self.server.clone();
|
||||
@@ -252,10 +208,11 @@ impl TestServer {
|
||||
.unwrap()
|
||||
.set_id(user_id.to_proto())
|
||||
.override_authenticate(move |cx| {
|
||||
let access_token = "the-token".to_string();
|
||||
cx.spawn(async move |_| {
|
||||
Ok(Credentials {
|
||||
user_id: user_id.to_proto(),
|
||||
access_token: ACCESS_TOKEN.into(),
|
||||
access_token,
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -264,7 +221,7 @@ impl TestServer {
|
||||
credentials,
|
||||
&Credentials {
|
||||
user_id: user_id.0 as u64,
|
||||
access_token: ACCESS_TOKEN.into(),
|
||||
access_token: "the-token".into()
|
||||
}
|
||||
);
|
||||
|
||||
@@ -299,7 +256,6 @@ impl TestServer {
|
||||
ZedVersion(SemanticVersion::new(1, 0, 0)),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(connection_id_tx),
|
||||
Executor::Deterministic(cx.background_executor().clone()),
|
||||
None,
|
||||
@@ -362,7 +318,7 @@ impl TestServer {
|
||||
});
|
||||
|
||||
client
|
||||
.connect(false, &cx.to_async())
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
@@ -735,17 +691,17 @@ impl TestClient {
|
||||
current: store
|
||||
.contacts()
|
||||
.iter()
|
||||
.map(|contact| contact.user.github_login.clone().to_string())
|
||||
.map(|contact| contact.user.github_login.clone())
|
||||
.collect(),
|
||||
outgoing_requests: store
|
||||
.outgoing_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone().to_string())
|
||||
.map(|user| user.github_login.clone())
|
||||
.collect(),
|
||||
incoming_requests: store
|
||||
.incoming_contact_requests()
|
||||
.iter()
|
||||
.map(|user| user.github_login.clone().to_string())
|
||||
.map(|user| user.github_login.clone())
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -940,7 +940,7 @@ impl CollabPanel {
|
||||
room.read(cx).local_participant().role == proto::ChannelRole::Admin
|
||||
});
|
||||
|
||||
ListItem::new(user.github_login.clone())
|
||||
ListItem::new(SharedString::from(user.github_login.clone()))
|
||||
.start_slot(Avatar::new(user.avatar_uri.clone()))
|
||||
.child(Label::new(user.github_login.clone()))
|
||||
.toggle_state(is_selected)
|
||||
@@ -2331,7 +2331,7 @@ impl CollabPanel {
|
||||
let client = this.client.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
client
|
||||
.connect(true, &cx)
|
||||
.authenticate_and_connect(true, &cx)
|
||||
.await
|
||||
.into_response()
|
||||
.notify_async_err(cx);
|
||||
@@ -2583,7 +2583,7 @@ impl CollabPanel {
|
||||
) -> impl IntoElement {
|
||||
let online = contact.online;
|
||||
let busy = contact.busy || calling;
|
||||
let github_login = contact.user.github_login.clone();
|
||||
let github_login = SharedString::from(contact.user.github_login.clone());
|
||||
let item = ListItem::new(github_login.clone())
|
||||
.indent_level(1)
|
||||
.indent_step_size(px(20.))
|
||||
@@ -2662,7 +2662,7 @@ impl CollabPanel {
|
||||
is_selected: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let github_login = user.github_login.clone();
|
||||
let github_login = SharedString::from(user.github_login.clone());
|
||||
let user_id = user.id;
|
||||
let is_response_pending = self.user_store.read(cx).is_contact_request_pending(user);
|
||||
let color = if is_response_pending {
|
||||
|
||||
@@ -634,13 +634,13 @@ impl Render for NotificationPanel {
|
||||
.child(Icon::new(IconName::Envelope)),
|
||||
)
|
||||
.map(|this| {
|
||||
if !self.client.status().borrow().is_connected() {
|
||||
if self.client.user_id().is_none() {
|
||||
this.child(
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.p_4()
|
||||
.child(
|
||||
Button::new("connect_prompt_button", "Connect")
|
||||
Button::new("sign_in_prompt_button", "Sign in")
|
||||
.icon_color(Color::Muted)
|
||||
.icon(IconName::Github)
|
||||
.icon_position(IconPosition::Start)
|
||||
@@ -652,7 +652,10 @@ impl Render for NotificationPanel {
|
||||
let client = client.clone();
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
match client.connect(true, &cx).await {
|
||||
match client
|
||||
.authenticate_and_connect(true, &cx)
|
||||
.await
|
||||
{
|
||||
util::ConnectionResult::Timeout => {
|
||||
log::error!("Connection timeout");
|
||||
}
|
||||
@@ -670,7 +673,7 @@ impl Render for NotificationPanel {
|
||||
)
|
||||
.child(
|
||||
div().flex().w_full().items_center().child(
|
||||
Label::new("Connect to view notifications.")
|
||||
Label::new("Sign in to view notifications.")
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
|
||||
@@ -158,7 +158,6 @@ impl Client {
|
||||
pub fn stdio(
|
||||
server_id: ContextServerId,
|
||||
binary: ModelContextServerBinary,
|
||||
working_directory: &Option<PathBuf>,
|
||||
cx: AsyncApp,
|
||||
) -> Result<Self> {
|
||||
log::info!(
|
||||
@@ -173,7 +172,7 @@ impl Client {
|
||||
.map(|name| name.to_string_lossy().to_string())
|
||||
.unwrap_or_else(String::new);
|
||||
|
||||
let transport = Arc::new(StdioTransport::new(binary, working_directory, &cx)?);
|
||||
let transport = Arc::new(StdioTransport::new(binary, &cx)?);
|
||||
Self::new(server_id, server_name.into(), transport, cx)
|
||||
}
|
||||
|
||||
@@ -441,12 +440,14 @@ impl Client {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn on_notification(
|
||||
&self,
|
||||
method: &'static str,
|
||||
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||
) {
|
||||
self.notification_handlers.lock().insert(method, f);
|
||||
#[allow(unused)]
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
self.notification_handlers
|
||||
.lock()
|
||||
.insert(method, Box::new(f));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ impl std::fmt::Debug for ContextServerCommand {
|
||||
}
|
||||
|
||||
enum ContextServerTransport {
|
||||
Stdio(ContextServerCommand, Option<PathBuf>),
|
||||
Stdio(ContextServerCommand),
|
||||
Custom(Arc<dyn crate::transport::Transport>),
|
||||
}
|
||||
|
||||
@@ -64,18 +64,11 @@ pub struct ContextServer {
|
||||
}
|
||||
|
||||
impl ContextServer {
|
||||
pub fn stdio(
|
||||
id: ContextServerId,
|
||||
command: ContextServerCommand,
|
||||
working_directory: Option<Arc<Path>>,
|
||||
) -> Self {
|
||||
pub fn stdio(id: ContextServerId, command: ContextServerCommand) -> Self {
|
||||
Self {
|
||||
id,
|
||||
client: RwLock::new(None),
|
||||
configuration: ContextServerTransport::Stdio(
|
||||
command,
|
||||
working_directory.map(|directory| directory.to_path_buf()),
|
||||
),
|
||||
configuration: ContextServerTransport::Stdio(command),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,36 +88,15 @@ impl ContextServer {
|
||||
self.client.read().clone()
|
||||
}
|
||||
|
||||
pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
|
||||
self.initialize(self.new_client(cx)?).await
|
||||
}
|
||||
|
||||
/// Starts the context server, making sure handlers are registered before initialization happens
|
||||
pub async fn start_with_handlers(
|
||||
&self,
|
||||
notification_handlers: Vec<(
|
||||
&'static str,
|
||||
Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
|
||||
)>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<()> {
|
||||
let client = self.new_client(cx)?;
|
||||
for (method, handler) in notification_handlers {
|
||||
client.on_notification(method, handler);
|
||||
}
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
|
||||
Ok(match &self.configuration {
|
||||
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
|
||||
pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
|
||||
let client = match &self.configuration {
|
||||
ContextServerTransport::Stdio(command) => Client::stdio(
|
||||
client::ContextServerId(self.id.0.clone()),
|
||||
client::ModelContextServerBinary {
|
||||
executable: Path::new(&command.path).to_path_buf(),
|
||||
args: command.args.clone(),
|
||||
env: command.env.clone(),
|
||||
},
|
||||
working_directory,
|
||||
cx.clone(),
|
||||
)?,
|
||||
ContextServerTransport::Custom(transport) => Client::new(
|
||||
@@ -133,7 +105,8 @@ impl ContextServer {
|
||||
transport.clone(),
|
||||
cx.clone(),
|
||||
)?,
|
||||
})
|
||||
};
|
||||
self.initialize(client).await
|
||||
}
|
||||
|
||||
async fn initialize(&self, client: Client) -> Result<()> {
|
||||
|
||||
@@ -83,18 +83,14 @@ impl McpServer {
|
||||
}
|
||||
|
||||
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
|
||||
let mut settings = schemars::generate::SchemaSettings::draft07();
|
||||
settings.inline_subschemas = true;
|
||||
let mut generator = settings.into_generator();
|
||||
|
||||
let output_schema = generator.root_schema_for::<T::Output>();
|
||||
let unit_schema = generator.root_schema_for::<T::Output>();
|
||||
let output_schema = schemars::schema_for!(T::Output);
|
||||
let unit_schema = schemars::schema_for!(());
|
||||
|
||||
let registered_tool = RegisteredTool {
|
||||
tool: Tool {
|
||||
name: T::NAME.into(),
|
||||
description: Some(tool.description().into()),
|
||||
input_schema: generator.root_schema_for::<T::Input>().into(),
|
||||
input_schema: schemars::schema_for!(T::Input).into(),
|
||||
output_schema: if output_schema == unit_schema {
|
||||
None
|
||||
} else {
|
||||
|
||||
@@ -115,11 +115,10 @@ impl InitializedContextServerProtocol {
|
||||
self.inner.notify(T::METHOD, params)
|
||||
}
|
||||
|
||||
pub fn on_notification(
|
||||
&self,
|
||||
method: &'static str,
|
||||
f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
|
||||
) {
|
||||
pub fn on_notification<F>(&self, method: &'static str, f: F)
|
||||
where
|
||||
F: 'static + Send + FnMut(Value, AsyncApp),
|
||||
{
|
||||
self.inner.on_notification(method, f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::path::PathBuf;
|
||||
use std::pin::Pin;
|
||||
|
||||
use anyhow::{Context as _, Result};
|
||||
@@ -23,11 +22,7 @@ pub struct StdioTransport {
|
||||
}
|
||||
|
||||
impl StdioTransport {
|
||||
pub fn new(
|
||||
binary: ModelContextServerBinary,
|
||||
working_directory: &Option<PathBuf>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
pub fn new(binary: ModelContextServerBinary, cx: &AsyncApp) -> Result<Self> {
|
||||
let mut command = util::command::new_smol_command(&binary.executable);
|
||||
command
|
||||
.args(&binary.args)
|
||||
@@ -37,10 +32,6 @@ impl StdioTransport {
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
|
||||
if let Some(working_directory) = working_directory {
|
||||
command.current_dir(working_directory);
|
||||
}
|
||||
|
||||
let mut server = command.spawn().with_context(|| {
|
||||
format!(
|
||||
"failed to spawn command. (path={:?}, args={:?})",
|
||||
|
||||
@@ -295,7 +295,7 @@ mod tests {
|
||||
request: dap_types::StartDebuggingRequestArgumentsRequest::Launch,
|
||||
},
|
||||
},
|
||||
Box::new(|_| {}),
|
||||
Box::new(|_| panic!("Did not expect to hit this code path")),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
|
||||