Compare commits
194 Commits
close-agen
...
gpui-butto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1711f91f6 | ||
|
|
8fd2316deb | ||
|
|
7357796969 | ||
|
|
e26620d1cf | ||
|
|
9dabf491f0 | ||
|
|
f2dcc98216 | ||
|
|
23bbfc4b94 | ||
|
|
98aefcca83 | ||
|
|
9be1e9aab1 | ||
|
|
33b60bc16d | ||
|
|
0355b9dfab | ||
|
|
6bec76cd5d | ||
|
|
d4f47aa653 | ||
|
|
5112fcebeb | ||
|
|
dcf7f714f7 | ||
|
|
16f668b8e3 | ||
|
|
0f4e52bde8 | ||
|
|
dfe37b0a07 | ||
|
|
2da37988b5 | ||
|
|
05955e4faa | ||
|
|
1d043b37fb | ||
|
|
18d39e3f81 | ||
|
|
cc3a28a8e8 | ||
|
|
0f17e82154 | ||
|
|
a316428686 | ||
|
|
355266988d | ||
|
|
72007c9a62 | ||
|
|
c2feffac9d | ||
|
|
4b7b5db58c | ||
|
|
58ba833792 | ||
|
|
f021b401f4 | ||
|
|
47f6d4e5a7 | ||
|
|
e60f029525 | ||
|
|
d7b5c61ec8 | ||
|
|
23d42e3eaf | ||
|
|
b2fc4064c0 | ||
|
|
bba3db9378 | ||
|
|
5078f0b5ef | ||
|
|
607bfd3b1c | ||
|
|
87cb498a41 | ||
|
|
6420df3975 | ||
|
|
83498ebf2b | ||
|
|
1fb1fecb0a | ||
|
|
bc99a86bb7 | ||
|
|
fcfe4e2c14 | ||
|
|
ef511976be | ||
|
|
c80aaca0c5 | ||
|
|
234d6ce5f5 | ||
|
|
96a0568fb7 | ||
|
|
b6828e5ce8 | ||
|
|
78d3ce4090 | ||
|
|
d01559f9bc | ||
|
|
645f662853 | ||
|
|
d42cb111f4 | ||
|
|
dce6e96c16 | ||
|
|
4280bff10a | ||
|
|
ea5b289459 | ||
|
|
09503333af | ||
|
|
775370fd7d | ||
|
|
1077f2771e | ||
|
|
f4eea0db2e | ||
|
|
ed361ff6a2 | ||
|
|
7f9a365d8f | ||
|
|
255d8f7cf8 | ||
|
|
22f76ac1a7 | ||
|
|
25cc05b45c | ||
|
|
a4766e296f | ||
|
|
2f26a860a9 | ||
|
|
f1fe505649 | ||
|
|
9826b7b5c1 | ||
|
|
6fc9036063 | ||
|
|
2b74163a48 | ||
|
|
71ea7aee3b | ||
|
|
48b376fdc9 | ||
|
|
f98c6fb2cf | ||
|
|
1ace5a27bc | ||
|
|
dd6594621f | ||
|
|
68afe4fdda | ||
|
|
6f297132b4 | ||
|
|
8fe134e361 | ||
|
|
7aabbb0426 | ||
|
|
85c6a3dd0c | ||
|
|
81dcc12c62 | ||
|
|
1fd8fbe6d1 | ||
|
|
7eb226b3fc | ||
|
|
9426caa061 | ||
|
|
7cad943fde | ||
|
|
29da105dd5 | ||
|
|
8fdf309a4a | ||
|
|
f01af006e1 | ||
|
|
01488c4f91 | ||
|
|
18e911002f | ||
|
|
54c6d482b6 | ||
|
|
32c7fcd78c | ||
|
|
fff349a644 | ||
|
|
90c2d17042 | ||
|
|
c6e69fae17 | ||
|
|
e5d497ee08 | ||
|
|
229f3dab22 | ||
|
|
67f9da0846 | ||
|
|
ab455e1c43 | ||
|
|
986d271ea7 | ||
|
|
98a18e04f7 | ||
|
|
3ea86da16f | ||
|
|
3173f87dc3 | ||
|
|
6592314984 | ||
|
|
93b6fdb8e5 | ||
|
|
e79d1b27b1 | ||
|
|
1a0eedb787 | ||
|
|
8db0333b04 | ||
|
|
a13c8b70dd | ||
|
|
ddc649bdb8 | ||
|
|
33c896c23d | ||
|
|
19b6c4444e | ||
|
|
8e39281699 | ||
|
|
8294981ab5 | ||
|
|
a3105c92a4 | ||
|
|
a6c3d49bb9 | ||
|
|
5a38bbbd22 | ||
|
|
196586e352 | ||
|
|
a1d8e50ec1 | ||
|
|
24bc9fd0a0 | ||
|
|
03f02804e5 | ||
|
|
41b0a5cf10 | ||
|
|
739236e968 | ||
|
|
f14e48d202 | ||
|
|
634b275931 | ||
|
|
8000151aa9 | ||
|
|
f0f0a52793 | ||
|
|
907b2f0521 | ||
|
|
0ad582eec4 | ||
|
|
58ed81b698 | ||
|
|
83319c8a6d | ||
|
|
4deb8cce8d | ||
|
|
8d79226445 | ||
|
|
5abca0f867 | ||
|
|
68945ac53e | ||
|
|
49887d6934 | ||
|
|
d867897746 | ||
|
|
1f58ce80f2 | ||
|
|
ed772e6baf | ||
|
|
559725d8f5 | ||
|
|
f0da3b74f8 | ||
|
|
cee9f4b013 | ||
|
|
ae31aa2759 | ||
|
|
82a7aca5a6 | ||
|
|
b34f19a46f | ||
|
|
09ace088ac | ||
|
|
49ba4ed49c | ||
|
|
06af0310f7 | ||
|
|
1fa19c69a6 | ||
|
|
5ba1d3edec | ||
|
|
e4525b80f8 | ||
|
|
18a2a50227 | ||
|
|
172a475515 | ||
|
|
471e02d48f | ||
|
|
39da72161f | ||
|
|
daa777440d | ||
|
|
79ba22673b | ||
|
|
074e78301a | ||
|
|
fbeee1f832 | ||
|
|
bff259731f | ||
|
|
6c5b9b43c1 | ||
|
|
f29c6e5661 | ||
|
|
000077facf | ||
|
|
2b249f9e68 | ||
|
|
e13ecc07bc | ||
|
|
bef25c7290 | ||
|
|
65b13968a2 | ||
|
|
9afc6f6f5c | ||
|
|
82d271cb5b | ||
|
|
77ad6d7fbb | ||
|
|
d6ab416168 | ||
|
|
8f07135201 | ||
|
|
1dfddf0a29 | ||
|
|
cf8f003916 | ||
|
|
00292450e0 | ||
|
|
49c01c60b7 | ||
|
|
863d7ccb6d | ||
|
|
d270f6b953 | ||
|
|
08f516ce9a | ||
|
|
9cff5cfe3a | ||
|
|
0abee5668a | ||
|
|
c58b6903b8 | ||
|
|
11b6ce46e2 | ||
|
|
8c8357387e | ||
|
|
25ced2e3c2 | ||
|
|
f248da5921 | ||
|
|
89ce49d5b7 | ||
|
|
30f3efe697 | ||
|
|
023a60806a | ||
|
|
2c602bb0e5 | ||
|
|
857134d6dc | ||
|
|
d8980c25d2 |
4
.github/ISSUE_TEMPLATE/01_bug_agent.yml
vendored
4
.github/ISSUE_TEMPLATE/01_bug_agent.yml
vendored
@@ -29,8 +29,8 @@ body:
|
||||
id: environment
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
|
||||
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
|
||||
placeholder: |
|
||||
Output of "zed: Copy System Specs Into Clipboard"
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -29,8 +29,8 @@ body:
|
||||
id: environment
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
|
||||
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
|
||||
placeholder: |
|
||||
Output of "zed: Copy System Specs Into Clipboard"
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/03_bug_git.yml
vendored
4
.github/ISSUE_TEMPLATE/03_bug_git.yml
vendored
@@ -28,8 +28,8 @@ body:
|
||||
id: environment
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
|
||||
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
|
||||
placeholder: |
|
||||
Output of "zed: Copy System Specs Into Clipboard"
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
35
.github/ISSUE_TEMPLATE/04_bug_debugger.yml
vendored
Normal file
35
.github/ISSUE_TEMPLATE/04_bug_debugger.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Bug Report (Debugger)
|
||||
description: Zed Debugger-Related Bugs
|
||||
type: "Bug"
|
||||
labels: ["debugger"]
|
||||
title: "Debugger: <a short description of the Debugger bug>"
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
description: Describe the bug with a one line summary, and provide detailed reproduction steps
|
||||
value: |
|
||||
<!-- Please insert a one line summary of the issue below -->
|
||||
SUMMARY_SENTENCE_HERE
|
||||
|
||||
### Description
|
||||
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
|
||||
Steps to trigger the problem:
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
Actual Behavior:
|
||||
Expected Behavior:
|
||||
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: environment
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
|
||||
placeholder: |
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
4
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
4
.github/ISSUE_TEMPLATE/10_bug_report.yml
vendored
@@ -49,8 +49,8 @@ body:
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: |
|
||||
Open Zed, from the command palette select "zed: Copy System Specs Into Clipboard"
|
||||
Open Zed, from the command palette select "zed: copy system specs into clipboard"
|
||||
placeholder: |
|
||||
Output of "zed: Copy System Specs Into Clipboard"
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/11_crash_report.yml
vendored
4
.github/ISSUE_TEMPLATE/11_crash_report.yml
vendored
@@ -26,9 +26,9 @@ body:
|
||||
id: environment
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
|
||||
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
|
||||
placeholder: |
|
||||
Output of "zed: Copy System Specs Into Clipboard"
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
**/cargo-target
|
||||
**/target
|
||||
**/venv
|
||||
**/.direnv
|
||||
*.wasm
|
||||
*.xcodeproj
|
||||
.DS_Store
|
||||
|
||||
@@ -2,16 +2,14 @@
|
||||
{
|
||||
"label": "Debug Zed (CodeLLDB)",
|
||||
"adapter": "CodeLLDB",
|
||||
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
|
||||
"request": "launch",
|
||||
"cwd": "$ZED_WORKTREE_ROOT"
|
||||
"program": "target/debug/zed",
|
||||
"request": "launch"
|
||||
},
|
||||
{
|
||||
"label": "Debug Zed (GDB)",
|
||||
"adapter": "GDB",
|
||||
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
|
||||
"program": "target/debug/zed",
|
||||
"request": "launch",
|
||||
"cwd": "$ZED_WORKTREE_ROOT",
|
||||
"initialize_args": {
|
||||
"stopAtBeginningOfMainSubprogram": true
|
||||
}
|
||||
|
||||
119
Cargo.lock
generated
119
Cargo.lock
generated
@@ -81,12 +81,12 @@ dependencies = [
|
||||
"http_client",
|
||||
"indexed_docs",
|
||||
"indoc",
|
||||
"inventory",
|
||||
"itertools 0.14.0",
|
||||
"jsonschema",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_model_selector",
|
||||
"linkme",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown",
|
||||
@@ -674,7 +674,6 @@ dependencies = [
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"linkme",
|
||||
"log",
|
||||
"markdown",
|
||||
"open",
|
||||
@@ -2793,6 +2792,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-recursion 0.3.2",
|
||||
"async-tungstenite",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"clock",
|
||||
"cocoa 0.26.0",
|
||||
@@ -2804,6 +2804,7 @@ dependencies = [
|
||||
"gpui_tokio",
|
||||
"http_client",
|
||||
"http_client_tls",
|
||||
"httparse",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
@@ -2824,6 +2825,7 @@ dependencies = [
|
||||
"time",
|
||||
"tiny_http",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-socks",
|
||||
"url",
|
||||
"util",
|
||||
@@ -3068,6 +3070,7 @@ dependencies = [
|
||||
"gpui",
|
||||
"http_client",
|
||||
"language",
|
||||
"log",
|
||||
"menu",
|
||||
"notifications",
|
||||
"picker",
|
||||
@@ -3176,39 +3179,13 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"collections",
|
||||
"gpui",
|
||||
"linkme",
|
||||
"inventory",
|
||||
"parking_lot",
|
||||
"strum 0.27.1",
|
||||
"theme",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "component_preview"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"agent",
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"client",
|
||||
"collections",
|
||||
"component",
|
||||
"db",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"languages",
|
||||
"log",
|
||||
"notifications",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"serde",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "concurrent-queue"
|
||||
version = "2.5.0"
|
||||
@@ -3334,6 +3311,7 @@ dependencies = [
|
||||
"http_client",
|
||||
"indoc",
|
||||
"inline_completion",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"log",
|
||||
"lsp",
|
||||
@@ -3343,11 +3321,9 @@ dependencies = [
|
||||
"paths",
|
||||
"project",
|
||||
"rpc",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"strum 0.27.1",
|
||||
"task",
|
||||
"theme",
|
||||
"ui",
|
||||
@@ -3533,9 +3509,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cosmic-text"
|
||||
version = "0.13.2"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e418dd4f5128c3e93eab12246391c54a20c496811131f85754dc8152ee207892"
|
||||
checksum = "3e1ecbb5db9a4c2ee642df67bcfa8f044dd867dbbaa21bfab139cbc204ffbf67"
|
||||
dependencies = [
|
||||
"bitflags 2.9.0",
|
||||
"fontdb 0.16.2",
|
||||
@@ -4160,6 +4136,18 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debug_adapter_extension"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"dap",
|
||||
"extension",
|
||||
"gpui",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "debugger_tools"
|
||||
version = "0.1.0"
|
||||
@@ -4193,6 +4181,7 @@ dependencies = [
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"feature_flags",
|
||||
"file_icons",
|
||||
"futures 0.3.31",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
@@ -4208,6 +4197,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"shlex",
|
||||
"sysinfo",
|
||||
"task",
|
||||
"tasks_ui",
|
||||
@@ -4339,7 +4329,6 @@ dependencies = [
|
||||
"gpui",
|
||||
"indoc",
|
||||
"language",
|
||||
"linkme",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown",
|
||||
@@ -5063,6 +5052,7 @@ dependencies = [
|
||||
"async-tar",
|
||||
"async-trait",
|
||||
"collections",
|
||||
"dap",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
@@ -5075,8 +5065,10 @@ dependencies = [
|
||||
"semantic_version",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"task",
|
||||
"toml 0.8.20",
|
||||
"util",
|
||||
"wasi-preview1-component-adapter-provider",
|
||||
"wasm-encoder 0.221.3",
|
||||
"wasmparser 0.221.3",
|
||||
"wit-component 0.221.3",
|
||||
@@ -5118,6 +5110,7 @@ dependencies = [
|
||||
"client",
|
||||
"collections",
|
||||
"ctor",
|
||||
"dap",
|
||||
"env_logger 0.11.8",
|
||||
"extension",
|
||||
"fs",
|
||||
@@ -5170,6 +5163,7 @@ dependencies = [
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"language",
|
||||
"log",
|
||||
"num-format",
|
||||
"picker",
|
||||
"project",
|
||||
@@ -6034,7 +6028,6 @@ dependencies = [
|
||||
"language",
|
||||
"language_model",
|
||||
"linkify",
|
||||
"linkme",
|
||||
"log",
|
||||
"markdown",
|
||||
"menu",
|
||||
@@ -7250,7 +7243,6 @@ dependencies = [
|
||||
"lsp",
|
||||
"paths",
|
||||
"project",
|
||||
"proto",
|
||||
"regex",
|
||||
"serde_json",
|
||||
"settings",
|
||||
@@ -7258,7 +7250,6 @@ dependencies = [
|
||||
"telemetry",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
@@ -7813,9 +7804,12 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"collections",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"language_model",
|
||||
"log",
|
||||
"ordered-float 2.10.1",
|
||||
"picker",
|
||||
"proto",
|
||||
"ui",
|
||||
@@ -8170,26 +8164,6 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linkme"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22d227772b5999ddc0690e733f734f95ca05387e329c4084fe65678c51198ffe"
|
||||
dependencies = [
|
||||
"linkme-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linkme-impl"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71a98813fa0073a317ed6a8055dcd4722a49d9b862af828ee68449adb799b6be"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.15"
|
||||
@@ -9112,7 +9086,6 @@ dependencies = [
|
||||
"component",
|
||||
"db",
|
||||
"gpui",
|
||||
"linkme",
|
||||
"rpc",
|
||||
"settings",
|
||||
"sum_tree",
|
||||
@@ -10061,7 +10034,7 @@ name = "perplexity"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"zed_extension_api 0.5.0",
|
||||
"zed_extension_api 0.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -15693,7 +15666,6 @@ dependencies = [
|
||||
"gpui",
|
||||
"icons",
|
||||
"itertools 0.14.0",
|
||||
"linkme",
|
||||
"menu",
|
||||
"serde",
|
||||
"settings",
|
||||
@@ -15714,7 +15686,6 @@ dependencies = [
|
||||
"component",
|
||||
"editor",
|
||||
"gpui",
|
||||
"linkme",
|
||||
"settings",
|
||||
"theme",
|
||||
"ui",
|
||||
@@ -15726,7 +15697,6 @@ name = "ui_macros"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"convert_case 0.8.0",
|
||||
"linkme",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
@@ -16245,6 +16215,12 @@ dependencies = [
|
||||
"wit-bindgen-rt 0.39.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi-preview1-component-adapter-provider"
|
||||
version = "29.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcd9f21bbde82ba59e415a8725e6ad0d0d7e9e460b1a3ccbca5bdee952c1a324"
|
||||
|
||||
[[package]]
|
||||
name = "wasite"
|
||||
version = "0.1.0"
|
||||
@@ -16947,7 +16923,6 @@ dependencies = [
|
||||
"gpui",
|
||||
"install_cli",
|
||||
"language",
|
||||
"linkme",
|
||||
"picker",
|
||||
"project",
|
||||
"schemars",
|
||||
@@ -18041,7 +18016,6 @@ dependencies = [
|
||||
"aho-corasick",
|
||||
"anstream",
|
||||
"arrayvec",
|
||||
"async-compression",
|
||||
"async-std",
|
||||
"async-tungstenite",
|
||||
"aws-config",
|
||||
@@ -18549,7 +18523,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.187.0"
|
||||
version = "0.188.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -18559,6 +18533,7 @@ dependencies = [
|
||||
"assets",
|
||||
"assistant_context_editor",
|
||||
"assistant_settings",
|
||||
"assistant_tool",
|
||||
"assistant_tools",
|
||||
"async-watch",
|
||||
"audio",
|
||||
@@ -18575,7 +18550,7 @@ dependencies = [
|
||||
"collab_ui",
|
||||
"collections",
|
||||
"command_palette",
|
||||
"component_preview",
|
||||
"component",
|
||||
"copilot",
|
||||
"dap",
|
||||
"dap_adapters",
|
||||
@@ -18601,6 +18576,7 @@ dependencies = [
|
||||
"gpui_tokio",
|
||||
"http_client",
|
||||
"image_viewer",
|
||||
"indoc",
|
||||
"inline_completion_button",
|
||||
"install_cli",
|
||||
"journal",
|
||||
@@ -18613,6 +18589,7 @@ dependencies = [
|
||||
"languages",
|
||||
"libc",
|
||||
"log",
|
||||
"markdown",
|
||||
"markdown_preview",
|
||||
"menu",
|
||||
"migrator",
|
||||
@@ -18664,6 +18641,7 @@ dependencies = [
|
||||
"tree-sitter-md",
|
||||
"tree-sitter-rust",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"ui_prompt",
|
||||
"url",
|
||||
"urlencoding",
|
||||
@@ -18715,7 +18693,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_extension_api"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -18738,9 +18716,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.8.0"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a23b2fd00776b0c55072f389654910ceb501eb0083d7f78905ab0e5cc86949ec"
|
||||
checksum = "16d993fc42f9ec43ab76fa46c6eb579a66e116bb08cd2bc9a67f3afcaa05d39d"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"serde",
|
||||
@@ -18775,7 +18753,7 @@ dependencies = [
|
||||
name = "zed_test_extension"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"zed_extension_api 0.5.0",
|
||||
"zed_extension_api 0.6.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -18948,6 +18926,7 @@ dependencies = [
|
||||
"paths",
|
||||
"postage",
|
||||
"project",
|
||||
"proto",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
|
||||
12
Cargo.toml
12
Cargo.toml
@@ -31,13 +31,13 @@ members = [
|
||||
"crates/command_palette",
|
||||
"crates/command_palette_hooks",
|
||||
"crates/component",
|
||||
"crates/component_preview",
|
||||
"crates/context_server",
|
||||
"crates/copilot",
|
||||
"crates/credentials_provider",
|
||||
"crates/dap",
|
||||
"crates/dap_adapters",
|
||||
"crates/db",
|
||||
"crates/debug_adapter_extension",
|
||||
"crates/debugger_tools",
|
||||
"crates/debugger_ui",
|
||||
"crates/deepseek",
|
||||
@@ -238,13 +238,13 @@ collections = { path = "crates/collections" }
|
||||
command_palette = { path = "crates/command_palette" }
|
||||
command_palette_hooks = { path = "crates/command_palette_hooks" }
|
||||
component = { path = "crates/component" }
|
||||
component_preview = { path = "crates/component_preview" }
|
||||
context_server = { path = "crates/context_server" }
|
||||
copilot = { path = "crates/copilot" }
|
||||
credentials_provider = { path = "crates/credentials_provider" }
|
||||
dap = { path = "crates/dap" }
|
||||
dap_adapters = { path = "crates/dap_adapters" }
|
||||
db = { path = "crates/db" }
|
||||
debug_adapter_extension = { path = "crates/debug_adapter_extension" }
|
||||
debugger_tools = { path = "crates/debugger_tools" }
|
||||
debugger_ui = { path = "crates/debugger_ui" }
|
||||
deepseek = { path = "crates/deepseek" }
|
||||
@@ -465,7 +465,6 @@ jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,r
|
||||
libc = "0.2"
|
||||
libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
|
||||
linkify = "0.10.0"
|
||||
linkme = "0.3.31"
|
||||
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
|
||||
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" }
|
||||
markup5ever_rcdom = "0.3.0"
|
||||
@@ -595,6 +594,7 @@ url = "2.2"
|
||||
urlencoding = "2.1.2"
|
||||
uuid = { version = "1.1.2", features = ["v4", "v5", "v7", "serde"] }
|
||||
walkdir = "2.3"
|
||||
wasi-preview1-component-adapter-provider = "29"
|
||||
wasm-encoder = "0.221"
|
||||
wasmparser = "0.221"
|
||||
wasmtime = { version = "29", default-features = false, features = [
|
||||
@@ -608,7 +608,7 @@ wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.8.0"
|
||||
zed_llm_client = "0.8.1"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
@@ -788,6 +788,9 @@ let_underscore_future = "allow"
|
||||
# running afoul of the borrow checker.
|
||||
too_many_arguments = "allow"
|
||||
|
||||
# We often have large enum variants yet we rarely actually bother with splitting them up.
|
||||
large_enum_variant = "allow"
|
||||
|
||||
[workspace.metadata.cargo-machete]
|
||||
ignored = [
|
||||
"bindgen",
|
||||
@@ -795,7 +798,6 @@ ignored = [
|
||||
"prost_build",
|
||||
"serde",
|
||||
"component",
|
||||
"linkme",
|
||||
"documented",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# syntax = docker/dockerfile:1.2
|
||||
|
||||
FROM rust:1.86-bookworm as builder
|
||||
FROM rust:1.87-bookworm as builder
|
||||
WORKDIR app
|
||||
COPY . .
|
||||
|
||||
|
||||
1
assets/icons/load_circle.svg
Normal file
1
assets/icons/load_circle.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-loader-circle-icon lucide-loader-circle"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>
|
||||
|
After Width: | Height: | Size: 289 B |
@@ -242,10 +242,9 @@
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "agent::ToggleModelSelector",
|
||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||
"ctrl-w": "agent::Close",
|
||||
"ctrl-shift-o": "agent::ToggleNavigationMenu",
|
||||
"ctrl-shift-i": "agent::ToggleOptionsMenu",
|
||||
"shift-escape": "agent::ExpandMessageEditor",
|
||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||
"ctrl-alt-e": "agent::RemoveAllContext",
|
||||
"ctrl-shift-e": "project_panel::ToggleFocus"
|
||||
}
|
||||
@@ -767,7 +766,7 @@
|
||||
"alt-ctrl-r": "project_panel::RevealInFileManager",
|
||||
"ctrl-shift-enter": "project_panel::OpenWithSystem",
|
||||
"shift-find": "project_panel::NewSearchInDirectory",
|
||||
"ctrl-shift-f": "project_panel::NewSearchInDirectory",
|
||||
"ctrl-alt-shift-f": "project_panel::NewSearchInDirectory",
|
||||
"shift-down": "menu::SelectNext",
|
||||
"shift-up": "menu::SelectPrevious",
|
||||
"escape": "menu::Cancel"
|
||||
@@ -979,5 +978,12 @@
|
||||
"bindings": {
|
||||
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "DebugConsole > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "menu::Confirm"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -288,10 +288,9 @@
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
"cmd-w": "agent::Close",
|
||||
"cmd-shift-o": "agent::ToggleNavigationMenu",
|
||||
"cmd-shift-i": "agent::ToggleOptionsMenu",
|
||||
"shift-escape": "agent::ExpandMessageEditor",
|
||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||
"cmd-alt-e": "agent::RemoveAllContext",
|
||||
"cmd-shift-e": "project_panel::ToggleFocus"
|
||||
}
|
||||
@@ -826,7 +825,7 @@
|
||||
"alt-cmd-r": "project_panel::RevealInFileManager",
|
||||
"ctrl-shift-enter": "project_panel::OpenWithSystem",
|
||||
"cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }],
|
||||
"cmd-shift-f": "project_panel::NewSearchInDirectory",
|
||||
"cmd-alt-shift-f": "project_panel::NewSearchInDirectory",
|
||||
"shift-down": "menu::SelectNext",
|
||||
"shift-up": "menu::SelectPrevious",
|
||||
"escape": "menu::Cancel"
|
||||
@@ -1085,5 +1084,12 @@
|
||||
"bindings": {
|
||||
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "DebugConsole > Editor",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "menu::Confirm"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -49,10 +49,9 @@ And here's the section to rewrite based on that prompt again for reference:
|
||||
</rewrite_this>
|
||||
|
||||
{{#if diagnostic_errors}}
|
||||
{{#each diagnostic_errors}}
|
||||
|
||||
Below are the diagnostic errors visible to the user. If the user requests problems to be fixed, use this information, but do not try to fix these errors if the user hasn't asked you to.
|
||||
|
||||
{{#each diagnostic_errors}}
|
||||
<diagnostic_error>
|
||||
<line_number>{{line_number}}</line_number>
|
||||
<error_message>{{error_message}}</error_message>
|
||||
|
||||
@@ -113,8 +113,8 @@
|
||||
// Whether to show the informational hover box when moving the mouse
|
||||
// over symbols in the editor.
|
||||
"hover_popover_enabled": true,
|
||||
// Time to wait before showing the informational hover box
|
||||
"hover_popover_delay": 350,
|
||||
// Time to wait in milliseconds before showing the informational hover box.
|
||||
"hover_popover_delay": 300,
|
||||
// Whether to confirm before quitting Zed.
|
||||
"confirm_quit": false,
|
||||
// Whether to restore last closed project when fresh Zed instance is opened.
|
||||
@@ -328,10 +328,16 @@
|
||||
"title_bar": {
|
||||
// Whether to show the branch icon beside branch switcher in the titlebar.
|
||||
"show_branch_icon": false,
|
||||
// Whether to show the branch name button in the titlebar.
|
||||
"show_branch_name": true,
|
||||
// Whether to show the project host and name in the titlebar.
|
||||
"show_project_items": true,
|
||||
// Whether to show onboarding banners in the titlebar.
|
||||
"show_onboarding_banner": true,
|
||||
// Whether to show user picture in the titlebar.
|
||||
"show_user_picture": true
|
||||
"show_user_picture": true,
|
||||
// Whether to show the sign in button in the titlebar.
|
||||
"show_sign_in": true
|
||||
},
|
||||
// Scrollbar related settings
|
||||
"scrollbar": {
|
||||
@@ -470,6 +476,8 @@
|
||||
"search_wrap": true,
|
||||
// Search options to enable by default when opening new project and buffer searches.
|
||||
"search": {
|
||||
// Whether to show the project search button in the status bar.
|
||||
"button": true,
|
||||
"whole_word": false,
|
||||
"case_sensitive": false,
|
||||
"include_ignored": false,
|
||||
@@ -750,6 +758,8 @@
|
||||
"stream_edits": false,
|
||||
// When enabled, agent edits will be displayed in single-file editors for review
|
||||
"single_file_review": true,
|
||||
// When enabled, show voting thumbs for feedback on agent edits.
|
||||
"enable_feedback": true,
|
||||
"default_profile": "write",
|
||||
"profiles": {
|
||||
"write": {
|
||||
@@ -1002,6 +1012,8 @@
|
||||
"auto_update": true,
|
||||
// Diagnostics configuration.
|
||||
"diagnostics": {
|
||||
// Whether to show the project diagnostics button in the status bar.
|
||||
"button": true,
|
||||
// Whether to show warnings or not by default.
|
||||
"include_warnings": true,
|
||||
// Settings for inline diagnostics
|
||||
@@ -1297,21 +1309,22 @@
|
||||
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json"],
|
||||
"Shell Script": [".env.*"]
|
||||
},
|
||||
// By default use a recent system version of node, or install our own.
|
||||
// You can override this to use a version of node that is not in $PATH with:
|
||||
// {
|
||||
// "node": {
|
||||
// "path": "/path/to/node"
|
||||
// "npm_path": "/path/to/npm" (defaults to node_path/../npm)
|
||||
// }
|
||||
// }
|
||||
// or to ensure Zed always downloads and installs an isolated version of node:
|
||||
// {
|
||||
// "node": {
|
||||
// "ignore_system_version": true,
|
||||
// }
|
||||
// NOTE: changing this setting currently requires restarting Zed.
|
||||
"node": {},
|
||||
// Settings for which version of Node.js and NPM to use when installing
|
||||
// language servers and Copilot.
|
||||
//
|
||||
// Note: changing this setting currently requires restarting Zed.
|
||||
"node": {
|
||||
// By default, Zed will look for `node` and `npm` on your `$PATH`, and use the
|
||||
// existing executables if their version is recent enough. Set this to `true`
|
||||
// to prevent this, and force Zed to always download and install its own
|
||||
// version of Node.
|
||||
"ignore_system_version": false,
|
||||
// You can also specify alternative paths to Node and NPM. If you specify
|
||||
// `path`, but not `npm_path`, Zed will assume that `npm` is located at
|
||||
// `${path}/../npm`.
|
||||
"path": null,
|
||||
"npm_path": null
|
||||
},
|
||||
// The extensions that Zed should automatically install on startup.
|
||||
//
|
||||
// If you don't want any of these extensions, add this field to your settings
|
||||
|
||||
@@ -47,12 +47,12 @@ heed.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indexed_docs.workspace = true
|
||||
inventory.workspace = true
|
||||
itertools.workspace = true
|
||||
jsonschema.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_model_selector.workspace = true
|
||||
linkme.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
markdown.workspace = true
|
||||
|
||||
@@ -3,9 +3,10 @@ use crate::context::{AgentContextHandle, RULES_ICON};
|
||||
use crate::context_picker::{ContextPicker, MentionLink};
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::message_editor::insert_message_creases;
|
||||
use crate::thread::{
|
||||
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
|
||||
ThreadFeedback,
|
||||
LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
|
||||
ThreadEvent, ThreadFeedback, ThreadSummary,
|
||||
};
|
||||
use crate::thread_store::{RulesLoadingError, TextThreadStore, ThreadStore};
|
||||
use crate::tool_use::{PendingToolUseStatus, ToolUse};
|
||||
@@ -32,7 +33,9 @@ use language_model::{
|
||||
LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason,
|
||||
};
|
||||
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
|
||||
use markdown::{
|
||||
HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, PathWithRange,
|
||||
};
|
||||
use project::{ProjectEntryId, ProjectItem as _};
|
||||
use rope::Point;
|
||||
use settings::{Settings as _, SettingsStore, update_settings_file};
|
||||
@@ -182,12 +185,14 @@ pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle
|
||||
let ui_font_size = TextSize::Default.rems(cx);
|
||||
let buffer_font_size = TextSize::Small.rems(cx);
|
||||
let mut text_style = window.text_style();
|
||||
let line_height = buffer_font_size * 1.75;
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.ui_font.family.clone()),
|
||||
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.ui_font.features.clone()),
|
||||
font_size: Some(ui_font_size.into()),
|
||||
line_height: Some(line_height.into()),
|
||||
color: Some(cx.theme().colors().text),
|
||||
..Default::default()
|
||||
});
|
||||
@@ -327,6 +332,7 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
|
||||
}
|
||||
}
|
||||
|
||||
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
|
||||
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 10;
|
||||
|
||||
fn render_markdown_code_block(
|
||||
@@ -379,18 +385,25 @@ fn render_markdown_code_block(
|
||||
)
|
||||
} else {
|
||||
let content = if let Some(parent) = path_range.path.parent() {
|
||||
let file_name = file_name.to_string_lossy().to_string();
|
||||
let path = parent.to_string_lossy().to_string();
|
||||
let path_and_file = format!("{}/{}", path, file_name);
|
||||
|
||||
h_flex()
|
||||
.id(("code-block-header-label", ix))
|
||||
.ml_1()
|
||||
.gap_1()
|
||||
.child(
|
||||
Label::new(file_name.to_string_lossy().to_string())
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
Label::new(parent.to_string_lossy().to_string())
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(Label::new(file_name).size(LabelSize::Small))
|
||||
.child(Label::new(path).color(Color::Muted).size(LabelSize::Small))
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::with_meta(
|
||||
"Jump to File",
|
||||
None,
|
||||
path_and_file.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.into_any_element()
|
||||
} else {
|
||||
Label::new(path_range.path.to_string_lossy().to_string())
|
||||
@@ -400,7 +413,7 @@ fn render_markdown_code_block(
|
||||
};
|
||||
|
||||
h_flex()
|
||||
.id(("code-block-header-label", ix))
|
||||
.id(("code-block-header-button", ix))
|
||||
.w_full()
|
||||
.max_w_full()
|
||||
.px_1()
|
||||
@@ -408,7 +421,6 @@ fn render_markdown_code_block(
|
||||
.cursor_pointer()
|
||||
.rounded_sm()
|
||||
.hover(|item| item.bg(cx.theme().colors().element_hover.opacity(0.5)))
|
||||
.tooltip(Tooltip::text("Jump to File"))
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
@@ -428,49 +440,8 @@ fn render_markdown_code_block(
|
||||
let path_range = path_range.clone();
|
||||
move |_, window, cx| {
|
||||
workspace
|
||||
.update(cx, {
|
||||
|workspace, cx| {
|
||||
let Some(project_path) = workspace
|
||||
.project()
|
||||
.read(cx)
|
||||
.find_project_path(&path_range.path, cx)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let Some(target) = path_range.range.as_ref().map(|range| {
|
||||
Point::new(
|
||||
// Line number is 1-based
|
||||
range.start.line.saturating_sub(1),
|
||||
range.start.col.unwrap_or(0),
|
||||
)
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
let open_task = workspace.open_path(
|
||||
project_path,
|
||||
None,
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
let item = open_task.await?;
|
||||
if let Some(active_editor) =
|
||||
item.downcast::<Editor>()
|
||||
{
|
||||
active_editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.go_to_singleton_buffer_point(
|
||||
target, window, cx,
|
||||
);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
.update(cx, |workspace, cx| {
|
||||
open_path(&path_range, window, workspace, cx)
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -485,12 +456,13 @@ fn render_markdown_code_block(
|
||||
.copied_code_block_ids
|
||||
.contains(&(message_id, ix));
|
||||
|
||||
let is_expanded = active_thread
|
||||
.read(cx)
|
||||
.expanded_code_blocks
|
||||
.get(&(message_id, ix))
|
||||
.copied()
|
||||
.unwrap_or(true);
|
||||
let can_expand = metadata.line_count >= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK;
|
||||
|
||||
let is_expanded = if can_expand {
|
||||
active_thread.read(cx).is_codeblock_expanded(message_id, ix)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let codeblock_header_bg = cx
|
||||
.theme()
|
||||
@@ -498,10 +470,87 @@ fn render_markdown_code_block(
|
||||
.element_background
|
||||
.blend(cx.theme().colors().editor_foreground.opacity(0.01));
|
||||
|
||||
let control_buttons = h_flex()
|
||||
.visible_on_hover(CODEBLOCK_CONTAINER_GROUP)
|
||||
.absolute()
|
||||
.top_0()
|
||||
.right_0()
|
||||
.h_full()
|
||||
.bg(codeblock_header_bg)
|
||||
.rounded_tr_md()
|
||||
.px_1()
|
||||
.gap_1()
|
||||
.child(
|
||||
IconButton::new(
|
||||
("copy-markdown-code", ix),
|
||||
if codeblock_was_copied {
|
||||
IconName::Check
|
||||
} else {
|
||||
IconName::Copy
|
||||
},
|
||||
)
|
||||
.icon_color(Color::Muted)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.tooltip(Tooltip::text("Copy Code"))
|
||||
.on_click({
|
||||
let active_thread = active_thread.clone();
|
||||
let parsed_markdown = parsed_markdown.clone();
|
||||
let code_block_range = metadata.content_range.clone();
|
||||
move |_event, _window, cx| {
|
||||
active_thread.update(cx, |this, cx| {
|
||||
this.copied_code_block_ids.insert((message_id, ix));
|
||||
|
||||
let code = parsed_markdown.source()[code_block_range.clone()].to_string();
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(code));
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
cx.background_executor().timer(Duration::from_secs(2)).await;
|
||||
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.copied_code_block_ids.remove(&(message_id, ix));
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
}
|
||||
}),
|
||||
)
|
||||
.when(can_expand, |header| {
|
||||
header.child(
|
||||
IconButton::new(
|
||||
("expand-collapse-code", ix),
|
||||
if is_expanded {
|
||||
IconName::ChevronUp
|
||||
} else {
|
||||
IconName::ChevronDown
|
||||
},
|
||||
)
|
||||
.icon_color(Color::Muted)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.tooltip(Tooltip::text(if is_expanded {
|
||||
"Collapse Code"
|
||||
} else {
|
||||
"Expand Code"
|
||||
}))
|
||||
.on_click({
|
||||
let active_thread = active_thread.clone();
|
||||
move |_event, _window, cx| {
|
||||
active_thread.update(cx, |this, cx| {
|
||||
this.toggle_codeblock_expanded(message_id, ix);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}),
|
||||
)
|
||||
});
|
||||
|
||||
let codeblock_header = h_flex()
|
||||
.py_1()
|
||||
.pl_1p5()
|
||||
.pr_1()
|
||||
.relative()
|
||||
.p_1()
|
||||
.gap_1()
|
||||
.justify_between()
|
||||
.border_b_1()
|
||||
@@ -509,89 +558,10 @@ fn render_markdown_code_block(
|
||||
.bg(codeblock_header_bg)
|
||||
.rounded_t_md()
|
||||
.children(label)
|
||||
.child(
|
||||
h_flex()
|
||||
.visible_on_hover("codeblock_container")
|
||||
.gap_1()
|
||||
.child(
|
||||
IconButton::new(
|
||||
("copy-markdown-code", ix),
|
||||
if codeblock_was_copied {
|
||||
IconName::Check
|
||||
} else {
|
||||
IconName::Copy
|
||||
},
|
||||
)
|
||||
.icon_color(Color::Muted)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.tooltip(Tooltip::text("Copy Code"))
|
||||
.on_click({
|
||||
let active_thread = active_thread.clone();
|
||||
let parsed_markdown = parsed_markdown.clone();
|
||||
let code_block_range = metadata.content_range.clone();
|
||||
move |_event, _window, cx| {
|
||||
active_thread.update(cx, |this, cx| {
|
||||
this.copied_code_block_ids.insert((message_id, ix));
|
||||
|
||||
let code =
|
||||
parsed_markdown.source()[code_block_range.clone()].to_string();
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(code));
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
cx.background_executor().timer(Duration::from_secs(2)).await;
|
||||
|
||||
cx.update(|cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
this.copied_code_block_ids.remove(&(message_id, ix));
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.detach();
|
||||
});
|
||||
}
|
||||
}),
|
||||
)
|
||||
.when(
|
||||
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|
||||
|header| {
|
||||
header.child(
|
||||
IconButton::new(
|
||||
("expand-collapse-code", ix),
|
||||
if is_expanded {
|
||||
IconName::ChevronUp
|
||||
} else {
|
||||
IconName::ChevronDown
|
||||
},
|
||||
)
|
||||
.icon_color(Color::Muted)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.tooltip(Tooltip::text(if is_expanded {
|
||||
"Collapse Code"
|
||||
} else {
|
||||
"Expand Code"
|
||||
}))
|
||||
.on_click({
|
||||
let active_thread = active_thread.clone();
|
||||
move |_event, _window, cx| {
|
||||
active_thread.update(cx, |this, cx| {
|
||||
let is_expanded = this
|
||||
.expanded_code_blocks
|
||||
.entry((message_id, ix))
|
||||
.or_insert(true);
|
||||
*is_expanded = !*is_expanded;
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}),
|
||||
)
|
||||
},
|
||||
),
|
||||
);
|
||||
.child(control_buttons);
|
||||
|
||||
v_flex()
|
||||
.group("codeblock_container")
|
||||
.group(CODEBLOCK_CONTAINER_GROUP)
|
||||
.my_2()
|
||||
.overflow_hidden()
|
||||
.rounded_lg()
|
||||
@@ -599,16 +569,46 @@ fn render_markdown_code_block(
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.child(codeblock_header)
|
||||
.when(
|
||||
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|
||||
|this| {
|
||||
if is_expanded {
|
||||
this.h_full()
|
||||
} else {
|
||||
this.max_h_80()
|
||||
}
|
||||
},
|
||||
.when(can_expand && !is_expanded, |this| this.max_h_80())
|
||||
}
|
||||
|
||||
fn open_path(
|
||||
path_range: &PathWithRange,
|
||||
window: &mut Window,
|
||||
workspace: &mut Workspace,
|
||||
cx: &mut Context<'_, Workspace>,
|
||||
) {
|
||||
let Some(project_path) = workspace
|
||||
.project()
|
||||
.read(cx)
|
||||
.find_project_path(&path_range.path, cx)
|
||||
else {
|
||||
return; // TODO instead of just bailing out, open that path in a buffer.
|
||||
};
|
||||
|
||||
let Some(target) = path_range.range.as_ref().map(|range| {
|
||||
Point::new(
|
||||
// Line number is 1-based
|
||||
range.start.line.saturating_sub(1),
|
||||
range.start.col.unwrap_or(0),
|
||||
)
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
let open_task = workspace.open_path(project_path, None, true, window, cx);
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
let item = open_task.await?;
|
||||
if let Some(active_editor) = item.downcast::<Editor>() {
|
||||
active_editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.go_to_singleton_buffer_point(target, window, cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn render_code_language(
|
||||
@@ -827,12 +827,12 @@ impl ActiveThread {
|
||||
self.messages.is_empty()
|
||||
}
|
||||
|
||||
pub fn summary(&self, cx: &App) -> Option<SharedString> {
|
||||
pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary {
|
||||
self.thread.read(cx).summary()
|
||||
}
|
||||
|
||||
pub fn summary_or_default(&self, cx: &App) -> SharedString {
|
||||
self.thread.read(cx).summary_or_default()
|
||||
pub fn regenerate_summary(&self, cx: &mut App) {
|
||||
self.thread.update(cx, |thread, cx| thread.summarize(cx))
|
||||
}
|
||||
|
||||
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
|
||||
@@ -1138,11 +1138,7 @@ impl ActiveThread {
|
||||
return;
|
||||
}
|
||||
|
||||
let title = self
|
||||
.thread
|
||||
.read(cx)
|
||||
.summary()
|
||||
.unwrap_or("Agent Panel".into());
|
||||
let title = self.thread.read(cx).summary().unwrap_or("Agent Panel");
|
||||
|
||||
match AssistantSettings::get_global(cx).notify_when_agent_waiting {
|
||||
NotifyWhenAgentWaiting::PrimaryScreen => {
|
||||
@@ -1272,6 +1268,7 @@ impl ActiveThread {
|
||||
&mut self,
|
||||
message_id: MessageId,
|
||||
message_segments: &[MessageSegment],
|
||||
message_creases: &[MessageCrease],
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
@@ -1291,6 +1288,7 @@ impl ActiveThread {
|
||||
);
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(message_text.clone(), window, cx);
|
||||
insert_message_creases(editor, message_creases, &self.context_store, window, cx);
|
||||
editor.focus_handle(cx).focus(window);
|
||||
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
|
||||
});
|
||||
@@ -1724,10 +1722,11 @@ impl ActiveThread {
|
||||
.on_action(cx.listener(Self::confirm_editing_message))
|
||||
.capture_action(cx.listener(Self::paste))
|
||||
.min_h_6()
|
||||
.flex_grow()
|
||||
.w_full()
|
||||
.flex_grow()
|
||||
.gap_2()
|
||||
.child(EditorElement::new(
|
||||
.child(state.context_strip.clone())
|
||||
.child(div().pt(px(-3.)).px_neg_0p5().child(EditorElement::new(
|
||||
&state.editor,
|
||||
EditorStyle {
|
||||
background: colors.editor_background,
|
||||
@@ -1736,8 +1735,7 @@ impl ActiveThread {
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
.child(state.context_strip.clone())
|
||||
)))
|
||||
}
|
||||
|
||||
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
|
||||
@@ -1745,6 +1743,7 @@ impl ActiveThread {
|
||||
let Some(message) = self.thread.read(cx).message(message_id) else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
let message_creases = message.creases.clone();
|
||||
|
||||
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
|
||||
return Empty.into_any();
|
||||
@@ -1864,7 +1863,8 @@ impl ActiveThread {
|
||||
.child(open_as_markdown),
|
||||
)
|
||||
.into_any_element(),
|
||||
None => feedback_container
|
||||
None if AssistantSettings::get_global(cx).enable_feedback =>
|
||||
feedback_container
|
||||
.child(
|
||||
div().visible_on_hover("feedback_container").child(
|
||||
Label::new(
|
||||
@@ -1907,6 +1907,9 @@ impl ActiveThread {
|
||||
.child(open_as_markdown),
|
||||
)
|
||||
.into_any_element(),
|
||||
None => feedback_container
|
||||
.child(h_flex().child(open_as_markdown))
|
||||
.into_any_element(),
|
||||
};
|
||||
|
||||
let message_is_empty = message.should_display_content();
|
||||
@@ -1920,16 +1923,6 @@ impl ActiveThread {
|
||||
v_flex()
|
||||
.w_full()
|
||||
.gap_1()
|
||||
.when(!message_is_empty, |parent| {
|
||||
parent.child(div().min_h_6().child(self.render_message_content(
|
||||
message_id,
|
||||
rendered_message,
|
||||
has_tool_uses,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)))
|
||||
})
|
||||
.when(!added_context.is_empty(), |parent| {
|
||||
parent.child(h_flex().flex_wrap().gap_1().children(
|
||||
added_context.into_iter().map(|added_context| {
|
||||
@@ -1948,6 +1941,16 @@ impl ActiveThread {
|
||||
}),
|
||||
))
|
||||
})
|
||||
.when(!message_is_empty, |parent| {
|
||||
parent.child(div().pt_0p5().min_h_6().child(self.render_message_content(
|
||||
message_id,
|
||||
rendered_message,
|
||||
has_tool_uses,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
)))
|
||||
})
|
||||
.into_any_element()
|
||||
}
|
||||
});
|
||||
@@ -1973,6 +1976,7 @@ impl ActiveThread {
|
||||
h_flex()
|
||||
.p_2p5()
|
||||
.gap_1()
|
||||
.items_end()
|
||||
.children(message_content)
|
||||
.when_some(editing_message_state, |this, state| {
|
||||
let focus_handle = state.editor.focus_handle(cx).clone();
|
||||
@@ -1986,6 +1990,7 @@ impl ActiveThread {
|
||||
)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_color(Color::Error)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
@@ -2003,11 +2008,12 @@ impl ActiveThread {
|
||||
.child(
|
||||
IconButton::new(
|
||||
"confirm-edit-message",
|
||||
IconName::Check,
|
||||
IconName::Return,
|
||||
)
|
||||
.disabled(state.editor.read(cx).is_empty(cx))
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_color(Color::Success)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_size(IconSize::Small)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
move |window, cx| {
|
||||
@@ -2027,15 +2033,13 @@ impl ActiveThread {
|
||||
)
|
||||
}),
|
||||
)
|
||||
.when(editing_message_state.is_none(), |this| {
|
||||
this.tooltip(Tooltip::text("Click To Edit"))
|
||||
})
|
||||
.on_click(cx.listener({
|
||||
let message_segments = message.segments.clone();
|
||||
move |this, _, window, cx| {
|
||||
this.start_editing_message(
|
||||
message_id,
|
||||
&message_segments,
|
||||
&message_creases,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -2069,6 +2073,16 @@ impl ActiveThread {
|
||||
|
||||
let panel_background = cx.theme().colors().panel_background;
|
||||
|
||||
let backdrop = div()
|
||||
.id("backdrop")
|
||||
.stop_mouse_events_except_scroll()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.size_full()
|
||||
.bg(panel_background)
|
||||
.opacity(0.8)
|
||||
.on_click(cx.listener(Self::handle_cancel_click));
|
||||
|
||||
v_flex()
|
||||
.w_full()
|
||||
.map(|parent| {
|
||||
@@ -2238,15 +2252,7 @@ impl ActiveThread {
|
||||
})
|
||||
.when(after_editing_message, |parent| {
|
||||
// Backdrop to dim out the whole thread below the editing user message
|
||||
parent.relative().child(
|
||||
div()
|
||||
.occlude()
|
||||
.absolute()
|
||||
.inset_0()
|
||||
.size_full()
|
||||
.bg(panel_background)
|
||||
.opacity(0.8),
|
||||
)
|
||||
parent.relative().child(backdrop)
|
||||
})
|
||||
.into_any()
|
||||
}
|
||||
@@ -2359,19 +2365,21 @@ impl ActiveThread {
|
||||
let editor_bg = cx.theme().colors().editor_background;
|
||||
|
||||
move |el, range, metadata, _, cx| {
|
||||
let is_expanded = active_thread
|
||||
.read(cx)
|
||||
.expanded_code_blocks
|
||||
.get(&(message_id, range.start))
|
||||
.copied()
|
||||
.unwrap_or(true);
|
||||
let can_expand = metadata.line_count
|
||||
>= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK;
|
||||
|
||||
if is_expanded
|
||||
|| metadata.line_count
|
||||
<= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK
|
||||
{
|
||||
if !can_expand {
|
||||
return el;
|
||||
}
|
||||
|
||||
let is_expanded = active_thread
|
||||
.read(cx)
|
||||
.is_codeblock_expanded(message_id, range.start);
|
||||
|
||||
if is_expanded {
|
||||
return el;
|
||||
}
|
||||
|
||||
el.child(
|
||||
div()
|
||||
.absolute()
|
||||
@@ -2397,6 +2405,7 @@ impl ActiveThread {
|
||||
markdown_element.code_block_renderer(
|
||||
markdown::CodeBlockRenderer::Default {
|
||||
copy_button: false,
|
||||
copy_button_on_hover: false,
|
||||
border: true,
|
||||
},
|
||||
)
|
||||
@@ -2716,6 +2725,7 @@ impl ActiveThread {
|
||||
)
|
||||
.code_block_renderer(markdown::CodeBlockRenderer::Default {
|
||||
copy_button: false,
|
||||
copy_button_on_hover: false,
|
||||
border: false,
|
||||
})
|
||||
.on_url_click({
|
||||
@@ -2746,6 +2756,7 @@ impl ActiveThread {
|
||||
)
|
||||
.code_block_renderer(markdown::CodeBlockRenderer::Default {
|
||||
copy_button: false,
|
||||
copy_button_on_hover: false,
|
||||
border: false,
|
||||
})
|
||||
.on_url_click({
|
||||
@@ -3382,6 +3393,21 @@ impl ActiveThread {
|
||||
.log_err();
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool {
|
||||
self.expanded_code_blocks
|
||||
.get(&(message_id, ix))
|
||||
.copied()
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
pub fn toggle_codeblock_expanded(&mut self, message_id: MessageId, ix: usize) {
|
||||
let is_expanded = self
|
||||
.expanded_code_blocks
|
||||
.entry((message_id, ix))
|
||||
.or_insert(true);
|
||||
*is_expanded = !*is_expanded;
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ActiveThreadEvent {
|
||||
@@ -3395,6 +3421,7 @@ impl Render for ActiveThread {
|
||||
v_flex()
|
||||
.size_full()
|
||||
.relative()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.on_mouse_move(cx.listener(|this, _, _, cx| {
|
||||
this.show_scrollbar = true;
|
||||
this.hide_scrollbar_later(cx);
|
||||
@@ -3436,10 +3463,7 @@ pub(crate) fn open_active_thread_as_markdown(
|
||||
workspace.update_in(cx, |workspace, window, cx| {
|
||||
let thread = thread.read(cx);
|
||||
let markdown = thread.to_markdown(cx)?;
|
||||
let thread_summary = thread
|
||||
.summary()
|
||||
.map(|summary| summary.to_string())
|
||||
.unwrap_or_else(|| "Thread".to_string());
|
||||
let thread_summary = thread.summary().or_default().to_string();
|
||||
|
||||
let project = workspace.project().clone();
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ pub use crate::context::{ContextLoadResult, LoadedContext};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
use crate::slash_command_settings::SlashCommandSettings;
|
||||
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::{TextThreadStore, ThreadStore};
|
||||
pub use crate::thread_store::{SerializedThread, TextThreadStore, ThreadStore};
|
||||
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
|
||||
pub use context_store::ContextStore;
|
||||
pub use ui::preview::{all_agent_previews, get_agent_preview};
|
||||
@@ -85,7 +85,6 @@ actions!(
|
||||
KeepAll,
|
||||
Follow,
|
||||
ResetTrialUpsell,
|
||||
Close,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageMod
|
||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
|
||||
Switch, SwitchColor, Tooltip, prelude::*,
|
||||
Disclosure, ElevationIndex, Indicator, Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip,
|
||||
prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use zed_actions::ExtensionCategoryFilter;
|
||||
@@ -36,6 +36,7 @@ pub struct AgentConfiguration {
|
||||
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
expanded_context_server_tools: HashMap<ContextServerId, bool>,
|
||||
expanded_provider_configurations: HashMap<LanguageModelProviderId, bool>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
_registry_subscription: Subscription,
|
||||
scroll_handle: ScrollHandle,
|
||||
@@ -78,6 +79,7 @@ impl AgentConfiguration {
|
||||
configuration_views_by_provider: HashMap::default(),
|
||||
context_server_store,
|
||||
expanded_context_server_tools: HashMap::default(),
|
||||
expanded_provider_configurations: HashMap::default(),
|
||||
tools,
|
||||
_registry_subscription: registry_subscription,
|
||||
scroll_handle,
|
||||
@@ -96,6 +98,7 @@ impl AgentConfiguration {
|
||||
|
||||
fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
|
||||
self.configuration_views_by_provider.remove(provider_id);
|
||||
self.expanded_provider_configurations.remove(provider_id);
|
||||
}
|
||||
|
||||
fn add_provider_configuration_view(
|
||||
@@ -135,9 +138,14 @@ impl AgentConfiguration {
|
||||
.get(&provider.id())
|
||||
.cloned();
|
||||
|
||||
let is_expanded = self
|
||||
.expanded_provider_configurations
|
||||
.get(&provider.id())
|
||||
.copied()
|
||||
.unwrap_or(false);
|
||||
|
||||
v_flex()
|
||||
.pt_3()
|
||||
.pb_1()
|
||||
.gap_1p5()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border.opacity(0.6))
|
||||
@@ -152,36 +160,63 @@ impl AgentConfiguration {
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Label::new(provider_name.clone()).size(LabelSize::Large)),
|
||||
.child(Label::new(provider_name.clone()).size(LabelSize::Large))
|
||||
.when(provider.is_authenticated(cx) && !is_expanded, |parent| {
|
||||
parent.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
}),
|
||||
)
|
||||
.when(provider.is_authenticated(cx), |parent| {
|
||||
parent.child(
|
||||
Button::new(
|
||||
SharedString::from(format!("new-thread-{provider_id}")),
|
||||
"Start New Thread",
|
||||
)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon(IconName::Plus)
|
||||
.icon_size(IconSize::Small)
|
||||
.style(ButtonStyle::Filled)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let provider = provider.clone();
|
||||
move |_this, _event, _window, cx| {
|
||||
cx.emit(AssistantConfigurationEvent::NewThread(
|
||||
provider.clone(),
|
||||
))
|
||||
}
|
||||
})),
|
||||
)
|
||||
}),
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.when(provider.is_authenticated(cx), |parent| {
|
||||
parent.child(
|
||||
Button::new(
|
||||
SharedString::from(format!("new-thread-{provider_id}")),
|
||||
"Start New Thread",
|
||||
)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon(IconName::Plus)
|
||||
.icon_size(IconSize::Small)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.label_size(LabelSize::Small)
|
||||
.on_click(cx.listener({
|
||||
let provider = provider.clone();
|
||||
move |_this, _event, _window, cx| {
|
||||
cx.emit(AssistantConfigurationEvent::NewThread(
|
||||
provider.clone(),
|
||||
))
|
||||
}
|
||||
})),
|
||||
)
|
||||
})
|
||||
.child(
|
||||
Disclosure::new(
|
||||
SharedString::from(format!(
|
||||
"provider-disclosure-{provider_id}"
|
||||
)),
|
||||
is_expanded,
|
||||
)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronDown)
|
||||
.on_click(cx.listener({
|
||||
let provider_id = provider.id().clone();
|
||||
move |this, _event, _window, _cx| {
|
||||
let is_expanded = this
|
||||
.expanded_provider_configurations
|
||||
.entry(provider_id.clone())
|
||||
.or_insert(false);
|
||||
|
||||
*is_expanded = !*is_expanded;
|
||||
}
|
||||
})),
|
||||
),
|
||||
),
|
||||
)
|
||||
.map(|parent| match configuration_view {
|
||||
.when(is_expanded, |parent| match configuration_view {
|
||||
Some(configuration_view) => parent.child(configuration_view),
|
||||
None => parent.child(div().child(Label::new(format!(
|
||||
None => parent.child(Label::new(format!(
|
||||
"No configuration view for {provider_name}",
|
||||
)))),
|
||||
))),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -195,7 +230,8 @@ impl AgentConfiguration {
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.gap_4()
|
||||
.flex_1()
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
@@ -296,7 +332,8 @@ impl AgentConfiguration {
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.gap_2p5()
|
||||
.flex_1()
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(Headline::new("General Settings"))
|
||||
.child(self.render_command_permission(cx))
|
||||
.child(self.render_single_file_review(cx))
|
||||
@@ -309,18 +346,17 @@ impl AgentConfiguration {
|
||||
) -> impl IntoElement {
|
||||
let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
|
||||
|
||||
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
|
||||
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("Model Context Protocol (MCP) Servers"))
|
||||
.child(Label::new(SUBHEADING).color(Color::Muted)),
|
||||
.child(Label::new("Connect to context servers via the Model Context Protocol either via Zed extensions or directly.").color(Color::Muted)),
|
||||
)
|
||||
.children(
|
||||
context_server_ids.into_iter().map(|context_server_id| {
|
||||
@@ -387,6 +423,7 @@ impl AgentConfiguration {
|
||||
.unwrap_or(ContextServerStatus::Stopped);
|
||||
|
||||
let is_running = matches!(server_status, ContextServerStatus::Running);
|
||||
let item_id = SharedString::from(context_server_id.0.clone());
|
||||
|
||||
let error = if let ContextServerStatus::Error(error) = server_status.clone() {
|
||||
Some(error)
|
||||
@@ -408,9 +445,38 @@ impl AgentConfiguration {
|
||||
let tool_count = tools.len();
|
||||
|
||||
let border_color = cx.theme().colors().border.opacity(0.6);
|
||||
let success_color = Color::Success.color(cx);
|
||||
|
||||
let (status_indicator, tooltip_text) = match server_status {
|
||||
ContextServerStatus::Starting => (
|
||||
Indicator::dot()
|
||||
.color(Color::Success)
|
||||
.with_animation(
|
||||
SharedString::from(format!("{}-starting", context_server_id.0.clone(),)),
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 1.)),
|
||||
move |this, delta| this.color(success_color.alpha(delta).into()),
|
||||
)
|
||||
.into_any_element(),
|
||||
"Server is starting.",
|
||||
),
|
||||
ContextServerStatus::Running => (
|
||||
Indicator::dot().color(Color::Success).into_any_element(),
|
||||
"Server is running.",
|
||||
),
|
||||
ContextServerStatus::Error(_) => (
|
||||
Indicator::dot().color(Color::Error).into_any_element(),
|
||||
"Server has an error.",
|
||||
),
|
||||
ContextServerStatus::Stopped => (
|
||||
Indicator::dot().color(Color::Muted).into_any_element(),
|
||||
"Server is stopped.",
|
||||
),
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.id(SharedString::from(context_server_id.0.clone()))
|
||||
.id(item_id.clone())
|
||||
.border_1()
|
||||
.rounded_md()
|
||||
.border_color(border_color)
|
||||
@@ -445,35 +511,12 @@ impl AgentConfiguration {
|
||||
}
|
||||
})),
|
||||
)
|
||||
.child(match server_status {
|
||||
ContextServerStatus::Starting => {
|
||||
let color = Color::Success.color(cx);
|
||||
Indicator::dot()
|
||||
.color(Color::Success)
|
||||
.with_animation(
|
||||
SharedString::from(format!(
|
||||
"{}-starting",
|
||||
context_server_id.0.clone(),
|
||||
)),
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.4, 1.)),
|
||||
move |this, delta| {
|
||||
this.color(color.alpha(delta).into())
|
||||
},
|
||||
)
|
||||
.into_any_element()
|
||||
}
|
||||
ContextServerStatus::Running => {
|
||||
Indicator::dot().color(Color::Success).into_any_element()
|
||||
}
|
||||
ContextServerStatus::Error(_) => {
|
||||
Indicator::dot().color(Color::Error).into_any_element()
|
||||
}
|
||||
ContextServerStatus::Stopped => {
|
||||
Indicator::dot().color(Color::Muted).into_any_element()
|
||||
}
|
||||
})
|
||||
.child(
|
||||
div()
|
||||
.id(item_id.clone())
|
||||
.tooltip(Tooltip::text(tooltip_text))
|
||||
.child(status_indicator),
|
||||
)
|
||||
.child(Label::new(context_server_id.0.clone()).ml_0p5())
|
||||
.when(is_running, |this| {
|
||||
this.child(
|
||||
@@ -588,9 +631,7 @@ impl Render for AgentConfiguration {
|
||||
.size_full()
|
||||
.overflow_y_scroll()
|
||||
.child(self.render_general_settings_section(cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_context_servers_section(window, cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_provider_configuration_section(cx)),
|
||||
)
|
||||
.child(
|
||||
|
||||
@@ -30,7 +30,6 @@ pub(crate) struct ConfigureContextServerModal {
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
enum Configuration {
|
||||
NotAvailable,
|
||||
Required(ConfigurationRequiredState),
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
use crate::{
|
||||
Keep, KeepAll, OpenAgentDiff, Reject, RejectAll, Thread, ThreadEvent, ui::AnimatedLabel,
|
||||
};
|
||||
use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll, Thread, ThreadEvent};
|
||||
use anyhow::Result;
|
||||
use assistant_settings::AssistantSettings;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
@@ -11,8 +9,9 @@ use editor::{
|
||||
scroll::Autoscroll,
|
||||
};
|
||||
use gpui::{
|
||||
Action, AnyElement, AnyView, App, AppContext, Empty, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, Global, SharedString, Subscription, Task, WeakEntity, Window, prelude::*,
|
||||
Action, Animation, AnimationExt, AnyElement, AnyView, App, AppContext, Empty, Entity,
|
||||
EventEmitter, FocusHandle, Focusable, Global, SharedString, Subscription, Task, Transformation,
|
||||
WeakEntity, Window, percentage, prelude::*,
|
||||
};
|
||||
|
||||
use language::{Buffer, Capability, DiskState, OffsetRangeExt, Point};
|
||||
@@ -25,6 +24,7 @@ use std::{
|
||||
collections::hash_map::Entry,
|
||||
ops::Range,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use ui::{IconButtonShape, KeyBinding, Tooltip, prelude::*, vertical_divider};
|
||||
use util::ResultExt;
|
||||
@@ -215,11 +215,7 @@ impl AgentDiffPane {
|
||||
}
|
||||
|
||||
fn update_title(&mut self, cx: &mut Context<Self>) {
|
||||
let new_title = self
|
||||
.thread
|
||||
.read(cx)
|
||||
.summary()
|
||||
.unwrap_or("Agent Changes".into());
|
||||
let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes");
|
||||
if new_title != self.title {
|
||||
self.title = new_title;
|
||||
cx.emit(EditorEvent::TitleChanged);
|
||||
@@ -469,11 +465,7 @@ impl Item for AgentDiffPane {
|
||||
}
|
||||
|
||||
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
|
||||
let summary = self
|
||||
.thread
|
||||
.read(cx)
|
||||
.summary()
|
||||
.unwrap_or("Agent Changes".into());
|
||||
let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes");
|
||||
Label::new(format!("Review: {}", summary))
|
||||
.color(if params.selected {
|
||||
Color::Default
|
||||
@@ -978,9 +970,20 @@ impl ToolbarItemView for AgentDiffToolbar {
|
||||
|
||||
impl Render for AgentDiffToolbar {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let generating_label = div()
|
||||
.w(rems_from_px(110.)) // Arbitrary size so the label doesn't dance around
|
||||
.child(AnimatedLabel::new("Generating"))
|
||||
let spinner_icon = div()
|
||||
.px_0p5()
|
||||
.id("generating")
|
||||
.tooltip(Tooltip::text("Generating Changes…"))
|
||||
.child(
|
||||
Icon::new(IconName::LoadCircle)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Accent)
|
||||
.with_animation(
|
||||
"load_circle",
|
||||
Animation::new(Duration::from_secs(3)).repeat(),
|
||||
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
|
||||
),
|
||||
)
|
||||
.into_any();
|
||||
|
||||
let Some(active_item) = self.active_item.as_ref() else {
|
||||
@@ -997,7 +1000,7 @@ impl Render for AgentDiffToolbar {
|
||||
|
||||
let content = match state {
|
||||
EditorState::Idle => return Empty.into_any(),
|
||||
EditorState::Generating => vec![generating_label],
|
||||
EditorState::Generating => vec![spinner_icon],
|
||||
EditorState::Reviewing => vec![
|
||||
h_flex()
|
||||
.child(
|
||||
@@ -1115,7 +1118,7 @@ impl Render for AgentDiffToolbar {
|
||||
|
||||
let is_generating = agent_diff.read(cx).thread.read(cx).is_generating();
|
||||
if is_generating {
|
||||
return div().px_2().child(generating_label).into_any();
|
||||
return div().px_2().child(spinner_icon).into_any();
|
||||
}
|
||||
|
||||
let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
|
||||
|
||||
@@ -10,8 +10,8 @@ use serde::{Deserialize, Serialize};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_context_editor::{
|
||||
AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent,
|
||||
SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate,
|
||||
render_remaining_tokens,
|
||||
ContextSummary, SlashCommandCompletionProvider, humanize_token_count,
|
||||
make_lsp_adapter_delegate, render_remaining_tokens,
|
||||
};
|
||||
use assistant_settings::{AssistantDockPosition, AssistantSettings};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
@@ -46,7 +46,9 @@ use ui::{
|
||||
};
|
||||
use util::{ResultExt as _, maybe};
|
||||
use workspace::dock::{DockPosition, Panel, PanelEvent};
|
||||
use workspace::{CollaboratorId, DraggedSelection, DraggedTab, ToolbarItemView, Workspace};
|
||||
use workspace::{
|
||||
CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace,
|
||||
};
|
||||
use zed_actions::agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding};
|
||||
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
|
||||
use zed_actions::{DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize};
|
||||
@@ -55,17 +57,17 @@ use zed_llm_client::UsageLimit;
|
||||
use crate::active_thread::{self, ActiveThread, ActiveThreadEvent};
|
||||
use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent};
|
||||
use crate::agent_diff::AgentDiff;
|
||||
use crate::history_store::{HistoryEntry, HistoryStore, RecentEntry};
|
||||
use crate::history_store::{HistoryStore, RecentEntry};
|
||||
use crate::message_editor::{MessageEditor, MessageEditorEvent};
|
||||
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
|
||||
use crate::thread_history::{EntryTimeFormat, PastContext, PastThread, ThreadHistory};
|
||||
use crate::thread::{Thread, ThreadError, ThreadId, ThreadSummary, TokenUsageRatio};
|
||||
use crate::thread_history::{HistoryEntryElement, ThreadHistory};
|
||||
use crate::thread_store::ThreadStore;
|
||||
use crate::ui::AgentOnboardingModal;
|
||||
use crate::{
|
||||
AddContextServer, AgentDiffPane, Close, ContextStore, DeleteRecentlyOpenThread,
|
||||
ExpandMessageEditor, Follow, InlineAssistant, NewTextThread, NewThread,
|
||||
OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ResetTrialUpsell, TextThreadStore,
|
||||
ThreadEvent, ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu,
|
||||
AddContextServer, AgentDiffPane, ContextStore, DeleteRecentlyOpenThread, ExpandMessageEditor,
|
||||
Follow, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff,
|
||||
OpenHistory, ResetTrialUpsell, TextThreadStore, ThreadEvent, ToggleContextPicker,
|
||||
ToggleNavigationMenu, ToggleOptionsMenu,
|
||||
};
|
||||
|
||||
const AGENT_PANEL_KEY: &str = "agent_panel";
|
||||
@@ -156,11 +158,6 @@ pub fn init(cx: &mut App) {
|
||||
})
|
||||
.register_action(|_workspace, _: &ResetTrialUpsell, _window, cx| {
|
||||
set_trial_upsell_dismissed(false, cx);
|
||||
})
|
||||
.register_action(|workspace, _: &Close, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| panel.close_panel(window, cx));
|
||||
}
|
||||
});
|
||||
},
|
||||
)
|
||||
@@ -199,7 +196,7 @@ impl ActiveView {
|
||||
}
|
||||
|
||||
pub fn thread(thread: Entity<Thread>, window: &mut Window, cx: &mut App) -> Self {
|
||||
let summary = thread.read(cx).summary_or_default();
|
||||
let summary = thread.read(cx).summary().or_default();
|
||||
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::single_line(window, cx);
|
||||
@@ -221,7 +218,7 @@ impl ActiveView {
|
||||
}
|
||||
EditorEvent::Blurred => {
|
||||
if editor.read(cx).text(cx).is_empty() {
|
||||
let summary = thread.read(cx).summary_or_default();
|
||||
let summary = thread.read(cx).summary().or_default();
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(summary, window, cx);
|
||||
@@ -236,7 +233,7 @@ impl ActiveView {
|
||||
let editor = editor.clone();
|
||||
move |thread, event, window, cx| match event {
|
||||
ThreadEvent::SummaryGenerated => {
|
||||
let summary = thread.read(cx).summary_or_default();
|
||||
let summary = thread.read(cx).summary().or_default();
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(summary, window, cx);
|
||||
@@ -299,7 +296,8 @@ impl ActiveView {
|
||||
.read(cx)
|
||||
.context()
|
||||
.read(cx)
|
||||
.summary_or_default();
|
||||
.summary()
|
||||
.or_default();
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(summary, window, cx);
|
||||
@@ -314,7 +312,7 @@ impl ActiveView {
|
||||
let editor = editor.clone();
|
||||
move |assistant_context, event, window, cx| match event {
|
||||
ContextEvent::SummaryGenerated => {
|
||||
let summary = assistant_context.read(cx).summary_or_default();
|
||||
let summary = assistant_context.read(cx).summary().or_default();
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.set_text(summary, window, cx);
|
||||
@@ -361,21 +359,19 @@ pub struct AgentPanel {
|
||||
previous_view: Option<ActiveView>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
history: Entity<ThreadHistory>,
|
||||
hovered_recent_history_item: Option<usize>,
|
||||
assistant_dropdown_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
assistant_navigation_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
assistant_navigation_menu: Option<Entity<ContextMenu>>,
|
||||
width: Option<Pixels>,
|
||||
height: Option<Pixels>,
|
||||
zoomed: bool,
|
||||
pending_serialization: Option<Task<Result<()>>>,
|
||||
hide_trial_upsell: bool,
|
||||
_trial_markdown: Entity<Markdown>,
|
||||
}
|
||||
|
||||
impl AgentPanel {
|
||||
fn close_panel(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(PanelEvent::Close);
|
||||
}
|
||||
|
||||
fn serialize(&mut self, cx: &mut Context<Self>) {
|
||||
let width = self.width;
|
||||
self.pending_serialization = Some(cx.background_spawn(async move {
|
||||
@@ -705,11 +701,13 @@ impl AgentPanel {
|
||||
previous_view: None,
|
||||
history_store: history_store.clone(),
|
||||
history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)),
|
||||
hovered_recent_history_item: None,
|
||||
assistant_dropdown_menu_handle: PopoverMenuHandle::default(),
|
||||
assistant_navigation_menu_handle: PopoverMenuHandle::default(),
|
||||
assistant_navigation_menu: None,
|
||||
width: None,
|
||||
height: None,
|
||||
zoomed: false,
|
||||
pending_serialization: None,
|
||||
hide_trial_upsell: false,
|
||||
_trial_markdown: trial_markdown,
|
||||
@@ -1151,6 +1149,17 @@ impl AgentPanel {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn toggle_zoom(&mut self, _: &ToggleZoom, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.zoomed {
|
||||
cx.emit(PanelEvent::ZoomOut);
|
||||
} else {
|
||||
if !self.focus_handle(cx).contains_focused(window, cx) {
|
||||
cx.focus_self(window);
|
||||
}
|
||||
cx.emit(PanelEvent::ZoomIn);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open_agent_diff(
|
||||
&mut self,
|
||||
_: &OpenAgentDiff,
|
||||
@@ -1423,6 +1432,15 @@ impl Panel for AgentPanel {
|
||||
fn enabled(&self, cx: &App) -> bool {
|
||||
AssistantSettings::get_global(cx).enabled
|
||||
}
|
||||
|
||||
fn is_zoomed(&self, _window: &Window, _cx: &App) -> bool {
|
||||
self.zoomed
|
||||
}
|
||||
|
||||
fn set_zoomed(&mut self, zoomed: bool, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.zoomed = zoomed;
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentPanel {
|
||||
@@ -1435,23 +1453,45 @@ impl AgentPanel {
|
||||
..
|
||||
} => {
|
||||
let active_thread = self.thread.read(cx);
|
||||
let is_empty = active_thread.is_empty();
|
||||
|
||||
let summary = active_thread.summary(cx);
|
||||
|
||||
if is_empty {
|
||||
Label::new(Thread::DEFAULT_SUMMARY.clone())
|
||||
.truncate()
|
||||
.into_any_element()
|
||||
} else if summary.is_none() {
|
||||
Label::new(LOADING_SUMMARY_PLACEHOLDER)
|
||||
.truncate()
|
||||
.into_any_element()
|
||||
let state = if active_thread.is_empty() {
|
||||
&ThreadSummary::Pending
|
||||
} else {
|
||||
div()
|
||||
active_thread.summary(cx)
|
||||
};
|
||||
|
||||
match state {
|
||||
ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone())
|
||||
.truncate()
|
||||
.into_any_element(),
|
||||
ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER)
|
||||
.truncate()
|
||||
.into_any_element(),
|
||||
ThreadSummary::Ready(_) => div()
|
||||
.w_full()
|
||||
.child(change_title_editor.clone())
|
||||
.into_any_element()
|
||||
.into_any_element(),
|
||||
ThreadSummary::Error => h_flex()
|
||||
.w_full()
|
||||
.child(change_title_editor.clone())
|
||||
.child(
|
||||
ui::IconButton::new("retry-summary-generation", IconName::RotateCcw)
|
||||
.on_click({
|
||||
let active_thread = self.thread.clone();
|
||||
move |_, _window, cx| {
|
||||
active_thread.update(cx, |thread, cx| {
|
||||
thread.regenerate_summary(cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.tooltip(move |_window, cx| {
|
||||
cx.new(|_| {
|
||||
Tooltip::new("Failed to generate title")
|
||||
.meta("Click to try again")
|
||||
})
|
||||
.into()
|
||||
}),
|
||||
)
|
||||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
ActiveView::PromptEditor {
|
||||
@@ -1459,14 +1499,13 @@ impl AgentPanel {
|
||||
context_editor,
|
||||
..
|
||||
} => {
|
||||
let context_editor = context_editor.read(cx);
|
||||
let summary = context_editor.context().read(cx).summary();
|
||||
let summary = context_editor.read(cx).context().read(cx).summary();
|
||||
|
||||
match summary {
|
||||
None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone())
|
||||
ContextSummary::Pending => Label::new(ContextSummary::DEFAULT)
|
||||
.truncate()
|
||||
.into_any_element(),
|
||||
Some(summary) => {
|
||||
ContextSummary::Content(summary) => {
|
||||
if summary.done {
|
||||
div()
|
||||
.w_full()
|
||||
@@ -1478,6 +1517,28 @@ impl AgentPanel {
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
ContextSummary::Error => h_flex()
|
||||
.w_full()
|
||||
.child(title_editor.clone())
|
||||
.child(
|
||||
ui::IconButton::new("retry-summary-generation", IconName::RotateCcw)
|
||||
.on_click({
|
||||
let context_editor = context_editor.clone();
|
||||
move |_, _window, cx| {
|
||||
context_editor.update(cx, |context_editor, cx| {
|
||||
context_editor.regenerate_summary(cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.tooltip(move |_window, cx| {
|
||||
cx.new(|_| {
|
||||
Tooltip::new("Failed to generate title")
|
||||
.meta("Click to try again")
|
||||
})
|
||||
.into()
|
||||
}),
|
||||
)
|
||||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
ActiveView::History => Label::new("History").truncate().into_any_element(),
|
||||
@@ -1587,6 +1648,12 @@ impl AgentPanel {
|
||||
}),
|
||||
);
|
||||
|
||||
let zoom_in_label = if self.is_zoomed(window, cx) {
|
||||
"Zoom Out"
|
||||
} else {
|
||||
"Zoom In"
|
||||
};
|
||||
|
||||
let agent_extra_menu = PopoverMenu::new("agent-options-menu")
|
||||
.trigger_with_tooltip(
|
||||
IconButton::new("agent-options-menu", IconName::Ellipsis)
|
||||
@@ -1673,7 +1740,8 @@ impl AgentPanel {
|
||||
|
||||
menu = menu
|
||||
.action("Rules…", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenConfiguration));
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.action(zoom_in_label, Box::new(ToggleZoom));
|
||||
menu
|
||||
}))
|
||||
});
|
||||
@@ -2067,6 +2135,7 @@ impl AgentPanel {
|
||||
|
||||
v_flex()
|
||||
.size_full()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.when(recent_history.is_empty(), |this| {
|
||||
let configuration_error_ref = &configuration_error;
|
||||
this.child(
|
||||
@@ -2221,7 +2290,7 @@ impl AgentPanel {
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.child(
|
||||
Label::new("Past Interactions")
|
||||
Label::new("Recent")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
@@ -2246,18 +2315,20 @@ impl AgentPanel {
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.children(
|
||||
recent_history.into_iter().map(|entry| {
|
||||
recent_history.into_iter().enumerate().map(|(index, entry)| {
|
||||
// TODO: Add keyboard navigation.
|
||||
match entry {
|
||||
HistoryEntry::Thread(thread) => {
|
||||
PastThread::new(thread, cx.entity().downgrade(), false, vec![], EntryTimeFormat::DateAndTime)
|
||||
.into_any_element()
|
||||
}
|
||||
HistoryEntry::Context(context) => {
|
||||
PastContext::new(context, cx.entity().downgrade(), false, vec![], EntryTimeFormat::DateAndTime)
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
let is_hovered = self.hovered_recent_history_item == Some(index);
|
||||
HistoryEntryElement::new(entry.clone(), cx.entity().downgrade())
|
||||
.hovered(is_hovered)
|
||||
.on_hover(cx.listener(move |this, is_hovered, _window, cx| {
|
||||
if *is_hovered {
|
||||
this.hovered_recent_history_item = Some(index);
|
||||
} else if this.hovered_recent_history_item == Some(index) {
|
||||
this.hovered_recent_history_item = None;
|
||||
}
|
||||
cx.notify();
|
||||
}))
|
||||
.into_any_element()
|
||||
}),
|
||||
)
|
||||
)
|
||||
@@ -2369,9 +2440,6 @@ impl AgentPanel {
|
||||
.occlude()
|
||||
.child(match last_error {
|
||||
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
|
||||
ThreadError::MaxMonthlySpendReached => {
|
||||
self.render_max_monthly_spend_reached_error(cx)
|
||||
}
|
||||
ThreadError::ModelRequestLimitReached { plan } => {
|
||||
self.render_model_request_limit_reached_error(plan, cx)
|
||||
}
|
||||
@@ -2431,56 +2499,6 @@ impl AgentPanel {
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_max_monthly_spend_reached_error(&self, cx: &mut Context<Self>) -> AnyElement {
|
||||
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(ERROR_MESSAGE)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(ERROR_MESSAGE))
|
||||
.child(
|
||||
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
|
||||
cx.listener(|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.open_url(&zed_urls::account_url(cx));
|
||||
cx.notify();
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_model_request_limit_reached_error(
|
||||
&self,
|
||||
plan: Plan,
|
||||
@@ -2785,6 +2803,7 @@ impl Render for AgentPanel {
|
||||
.on_action(cx.listener(Self::increase_font_size))
|
||||
.on_action(cx.listener(Self::decrease_font_size))
|
||||
.on_action(cx.listener(Self::reset_font_size))
|
||||
.on_action(cx.listener(Self::toggle_zoom))
|
||||
.child(self.render_toolbar(window, cx))
|
||||
.children(self.render_trial_upsell(window, cx))
|
||||
.map(|parent| match &self.active_view {
|
||||
|
||||
@@ -586,10 +586,7 @@ impl ThreadContextHandle {
|
||||
}
|
||||
|
||||
pub fn title(&self, cx: &App) -> SharedString {
|
||||
self.thread
|
||||
.read(cx)
|
||||
.summary()
|
||||
.unwrap_or_else(|| "New thread".into())
|
||||
self.thread.read(cx).summary().or_default()
|
||||
}
|
||||
|
||||
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
|
||||
@@ -597,9 +594,7 @@ impl ThreadContextHandle {
|
||||
let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?;
|
||||
let title = self
|
||||
.thread
|
||||
.read_with(cx, |thread, _cx| {
|
||||
thread.summary().unwrap_or_else(|| "New thread".into())
|
||||
})
|
||||
.read_with(cx, |thread, _cx| thread.summary().or_default())
|
||||
.ok()?;
|
||||
let context = AgentContext::Thread(ThreadContext {
|
||||
title,
|
||||
@@ -642,7 +637,7 @@ impl TextThreadContextHandle {
|
||||
}
|
||||
|
||||
pub fn title(&self, cx: &App) -> SharedString {
|
||||
self.context.read(cx).summary_or_default()
|
||||
self.context.read(cx).summary().or_default()
|
||||
}
|
||||
|
||||
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
|
||||
@@ -830,23 +825,20 @@ pub fn load_context(
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
cx: &mut App,
|
||||
) -> Task<ContextLoadResult> {
|
||||
let mut load_tasks = Vec::new();
|
||||
|
||||
for context in contexts.iter().cloned() {
|
||||
match context {
|
||||
AgentContextHandle::File(context) => load_tasks.push(context.load(cx)),
|
||||
AgentContextHandle::Directory(context) => {
|
||||
load_tasks.push(context.load(project.clone(), cx))
|
||||
}
|
||||
AgentContextHandle::Symbol(context) => load_tasks.push(context.load(cx)),
|
||||
AgentContextHandle::Selection(context) => load_tasks.push(context.load(cx)),
|
||||
AgentContextHandle::FetchedUrl(context) => load_tasks.push(context.load()),
|
||||
AgentContextHandle::Thread(context) => load_tasks.push(context.load(cx)),
|
||||
AgentContextHandle::TextThread(context) => load_tasks.push(context.load(cx)),
|
||||
AgentContextHandle::Rules(context) => load_tasks.push(context.load(prompt_store, cx)),
|
||||
AgentContextHandle::Image(context) => load_tasks.push(context.load(cx)),
|
||||
}
|
||||
}
|
||||
let load_tasks: Vec<_> = contexts
|
||||
.into_iter()
|
||||
.map(|context| match context {
|
||||
AgentContextHandle::File(context) => context.load(cx),
|
||||
AgentContextHandle::Directory(context) => context.load(project.clone(), cx),
|
||||
AgentContextHandle::Symbol(context) => context.load(cx),
|
||||
AgentContextHandle::Selection(context) => context.load(cx),
|
||||
AgentContextHandle::FetchedUrl(context) => context.load(),
|
||||
AgentContextHandle::Thread(context) => context.load(cx),
|
||||
AgentContextHandle::TextThread(context) => context.load(cx),
|
||||
AgentContextHandle::Rules(context) => context.load(prompt_store, cx),
|
||||
AgentContextHandle::Image(context) => context.load(cx),
|
||||
})
|
||||
.collect();
|
||||
|
||||
cx.background_spawn(async move {
|
||||
let load_results = future::join_all(load_tasks).await;
|
||||
|
||||
@@ -381,6 +381,16 @@ impl ContextPicker {
|
||||
cx.focus_self(window);
|
||||
}
|
||||
|
||||
pub fn select_first(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
match &self.mode {
|
||||
ContextPickerState::Default(entity) => entity.update(cx, |entity, cx| {
|
||||
entity.select_first(&Default::default(), window, cx)
|
||||
}),
|
||||
// Other variants already select their first entry on open automatically
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn recent_menu_item(
|
||||
&self,
|
||||
context_picker: Entity<ContextPicker>,
|
||||
@@ -932,8 +942,8 @@ impl MentionLink {
|
||||
format!("[@{}]({}:{})", title, Self::THREAD, id)
|
||||
}
|
||||
ThreadContextEntry::Context { path, title } => {
|
||||
let filename = path.file_name().unwrap_or_default();
|
||||
let escaped_filename = urlencoding::encode(&filename.to_string_lossy()).to_string();
|
||||
let filename = path.file_name().unwrap_or_default().to_string_lossy();
|
||||
let escaped_filename = urlencoding::encode(&filename);
|
||||
format!(
|
||||
"[@{}]({}:{}{})",
|
||||
title,
|
||||
|
||||
@@ -84,6 +84,12 @@ impl ContextStrip {
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether or not the context strip has items to display
|
||||
pub fn has_context_items(&self, cx: &App) -> bool {
|
||||
self.context_store.read(cx).context().next().is_some()
|
||||
|| self.suggested_context(cx).is_some()
|
||||
}
|
||||
|
||||
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().read(cx);
|
||||
@@ -104,14 +110,14 @@ impl ContextStrip {
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
|
||||
fn suggested_context(&self, cx: &App) -> Option<SuggestedContext> {
|
||||
match self.suggest_context_kind {
|
||||
SuggestContextKind::File => self.suggested_file(cx),
|
||||
SuggestContextKind::Thread => self.suggested_thread(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_file(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
|
||||
fn suggested_file(&self, cx: &App) -> Option<SuggestedContext> {
|
||||
let workspace = self.workspace.upgrade()?;
|
||||
let active_item = workspace.read(cx).active_item(cx)?;
|
||||
|
||||
@@ -138,7 +144,7 @@ impl ContextStrip {
|
||||
})
|
||||
}
|
||||
|
||||
fn suggested_thread(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
|
||||
fn suggested_thread(&self, cx: &App) -> Option<SuggestedContext> {
|
||||
if !self.context_picker.read(cx).allow_threads() {
|
||||
return None;
|
||||
}
|
||||
@@ -160,7 +166,7 @@ impl ContextStrip {
|
||||
}
|
||||
|
||||
Some(SuggestedContext::Thread {
|
||||
name: active_thread.summary_or_default(),
|
||||
name: active_thread.summary().or_default(),
|
||||
thread: weak_active_thread,
|
||||
})
|
||||
} else if let Some(active_context_editor) = panel.active_context_editor() {
|
||||
@@ -174,7 +180,7 @@ impl ContextStrip {
|
||||
}
|
||||
|
||||
Some(SuggestedContext::TextThread {
|
||||
name: context.summary_or_default(),
|
||||
name: context.summary().or_default(),
|
||||
context: weak_context,
|
||||
})
|
||||
} else {
|
||||
@@ -420,12 +426,25 @@ impl Render for ContextStrip {
|
||||
})
|
||||
.child(
|
||||
PopoverMenu::new("context-picker")
|
||||
.menu(move |window, cx| {
|
||||
context_picker.update(cx, |this, cx| {
|
||||
this.init(window, cx);
|
||||
});
|
||||
.menu({
|
||||
let context_picker = context_picker.clone();
|
||||
move |window, cx| {
|
||||
context_picker.update(cx, |this, cx| {
|
||||
this.init(window, cx);
|
||||
});
|
||||
|
||||
Some(context_picker.clone())
|
||||
Some(context_picker.clone())
|
||||
}
|
||||
})
|
||||
.on_open({
|
||||
let context_picker = context_picker.downgrade();
|
||||
Rc::new(move |window, cx| {
|
||||
context_picker
|
||||
.update(cx, |context_picker, cx| {
|
||||
context_picker.select_first(window, cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
})
|
||||
.trigger_with_tooltip(
|
||||
IconButton::new("add-context", IconName::Plus)
|
||||
|
||||
@@ -75,7 +75,7 @@ impl Default for DebugAccountState {
|
||||
Self {
|
||||
enabled: false,
|
||||
trial_expired: false,
|
||||
plan: Plan::Free,
|
||||
plan: Plan::ZedFree,
|
||||
custom_prompt_usage: RequestUsage {
|
||||
limit: UsageLimit::Unlimited,
|
||||
amount: 0,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{collections::VecDeque, path::Path};
|
||||
use std::{collections::VecDeque, path::Path, sync::Arc};
|
||||
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use assistant_context_editor::{AssistantContext, SavedContextMetadata};
|
||||
@@ -34,6 +34,20 @@ impl HistoryEntry {
|
||||
HistoryEntry::Context(context) => context.mtime.to_utc(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> HistoryEntryId {
|
||||
match self {
|
||||
HistoryEntry::Thread(thread) => HistoryEntryId::Thread(thread.id.clone()),
|
||||
HistoryEntry::Context(context) => HistoryEntryId::Context(context.path.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic identifier for a history entry.
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub enum HistoryEntryId {
|
||||
Thread(ThreadId),
|
||||
Context(Arc<Path>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -57,8 +71,8 @@ impl Eq for RecentEntry {}
|
||||
impl RecentEntry {
|
||||
pub(crate) fn summary(&self, cx: &App) -> SharedString {
|
||||
match self {
|
||||
RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(),
|
||||
RecentEntry::Context(context) => context.read(cx).summary_or_default(),
|
||||
RecentEntry::Thread(_, thread) => thread.read(cx).summary().or_default(),
|
||||
RecentEntry::Context(context) => context.read(cx).summary().or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,13 +338,27 @@ impl InlineAssistant {
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
|
||||
(
|
||||
editor.snapshot(window, cx),
|
||||
editor.selections.all::<Point>(cx),
|
||||
)
|
||||
let (snapshot, initial_selections, newest_selection) = editor.update(cx, |editor, cx| {
|
||||
let selections = editor.selections.all::<Point>(cx);
|
||||
let newest_selection = editor.selections.newest::<Point>(cx);
|
||||
(editor.snapshot(window, cx), selections, newest_selection)
|
||||
});
|
||||
|
||||
// Check if there is already an inline assistant that contains the
|
||||
// newest selection, if there is, focus it
|
||||
if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
|
||||
for assist_id in &editor_assists.assist_ids {
|
||||
let assist = &self.assists[assist_id];
|
||||
let range = assist.range.to_point(&snapshot.buffer_snapshot);
|
||||
if range.start.row <= newest_selection.start.row
|
||||
&& newest_selection.end.row <= range.end.row
|
||||
{
|
||||
self.focus_assist(*assist_id, window, cx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut selections = Vec::<Selection<Point>>::new();
|
||||
let mut newest_selection = None;
|
||||
for mut selection in initial_selections {
|
||||
|
||||
@@ -451,7 +451,7 @@ impl<T: 'static> PromptEditor<T> {
|
||||
editor.move_to_end(&Default::default(), window, cx)
|
||||
});
|
||||
}
|
||||
} else {
|
||||
} else if self.context_strip.read(cx).has_context_items(cx) {
|
||||
self.context_strip.focus_handle(cx).focus(window);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,7 +200,13 @@ impl MessageEditor {
|
||||
});
|
||||
|
||||
let profile_selector = cx.new(|cx| {
|
||||
ProfileSelector::new(thread.clone(), thread_store, editor.focus_handle(cx), cx)
|
||||
ProfileSelector::new(
|
||||
fs,
|
||||
thread.clone(),
|
||||
thread_store,
|
||||
editor.focus_handle(cx),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
Self {
|
||||
@@ -395,7 +401,7 @@ impl MessageEditor {
|
||||
fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.context_picker_menu_handle.is_deployed() {
|
||||
cx.propagate();
|
||||
} else {
|
||||
} else if self.context_strip.read(cx).has_context_items(cx) {
|
||||
self.context_strip.focus_handle(cx).focus(window);
|
||||
}
|
||||
}
|
||||
@@ -1079,11 +1085,11 @@ impl MessageEditor {
|
||||
let plan = user_store
|
||||
.current_plan()
|
||||
.map(|plan| match plan {
|
||||
Plan::Free => zed_llm_client::Plan::Free,
|
||||
Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
})
|
||||
.unwrap_or(zed_llm_client::Plan::Free);
|
||||
.unwrap_or(zed_llm_client::Plan::ZedFree);
|
||||
let usage = self.thread.read(cx).last_usage().or_else(|| {
|
||||
maybe!({
|
||||
let amount = user_store.model_request_usage_amount()?;
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_settings::{
|
||||
AgentProfile, AgentProfileId, AssistantDockPosition, AssistantSettings, GroupedAgentProfiles,
|
||||
builtin_profiles,
|
||||
};
|
||||
use fs::Fs;
|
||||
use gpui::{Action, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use settings::{Settings as _, SettingsStore, update_settings_file};
|
||||
use ui::{
|
||||
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
|
||||
prelude::*,
|
||||
@@ -15,6 +18,7 @@ use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector};
|
||||
|
||||
pub struct ProfileSelector {
|
||||
profiles: GroupedAgentProfiles,
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
@@ -24,6 +28,7 @@ pub struct ProfileSelector {
|
||||
|
||||
impl ProfileSelector {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
focus_handle: FocusHandle,
|
||||
@@ -35,6 +40,7 @@ impl ProfileSelector {
|
||||
|
||||
Self {
|
||||
profiles: GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx)),
|
||||
fs,
|
||||
thread,
|
||||
thread_store,
|
||||
menu_handle: PopoverMenuHandle::default(),
|
||||
@@ -95,7 +101,7 @@ impl ProfileSelector {
|
||||
profile_id: AgentProfileId,
|
||||
profile: &AgentProfile,
|
||||
settings: &AssistantSettings,
|
||||
cx: &App,
|
||||
_cx: &App,
|
||||
) -> ContextMenuEntry {
|
||||
let documentation = match profile.name.to_lowercase().as_str() {
|
||||
builtin_profiles::WRITE => Some("Get help to write anything."),
|
||||
@@ -104,12 +110,8 @@ impl ProfileSelector {
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let current_profile_id = self.thread.read(cx).configured_profile_id();
|
||||
|
||||
let entry = ContextMenuEntry::new(profile.name.clone()).toggleable(
|
||||
IconPosition::End,
|
||||
Some(profile_id.clone()) == current_profile_id,
|
||||
);
|
||||
let entry = ContextMenuEntry::new(profile.name.clone())
|
||||
.toggleable(IconPosition::End, profile_id == settings.default_profile);
|
||||
|
||||
let entry = if let Some(doc_text) = documentation {
|
||||
entry.documentation_aside(documentation_side(settings.dock), move |_| {
|
||||
@@ -120,13 +122,15 @@ impl ProfileSelector {
|
||||
};
|
||||
|
||||
entry.handler({
|
||||
let fs = self.fs.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let profile_id = profile_id.clone();
|
||||
let thread = self.thread.clone();
|
||||
|
||||
move |_window, cx| {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_configured_profile_id(Some(profile_id.clone()), cx);
|
||||
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
|
||||
let profile_id = profile_id.clone();
|
||||
move |settings, _cx| {
|
||||
settings.set_profile(profile_id.clone());
|
||||
}
|
||||
});
|
||||
|
||||
thread_store
|
||||
@@ -142,12 +146,8 @@ impl ProfileSelector {
|
||||
impl Render for ProfileSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
let profile_id = self
|
||||
.thread
|
||||
.read(cx)
|
||||
.configured_profile_id()
|
||||
.unwrap_or(settings.default_profile.clone());
|
||||
let profile = settings.profiles.get(&profile_id).cloned();
|
||||
let profile_id = &settings.default_profile;
|
||||
let profile = settings.profiles.get(profile_id);
|
||||
|
||||
let selected_profile = profile
|
||||
.map(|profile| profile.name.clone())
|
||||
|
||||
@@ -191,7 +191,7 @@ impl TerminalInlineAssistant {
|
||||
};
|
||||
|
||||
self.prompt_history.retain(|prompt| *prompt != user_prompt);
|
||||
self.prompt_history.push_back(user_prompt.clone());
|
||||
self.prompt_history.push_back(user_prompt);
|
||||
if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
|
||||
self.prompt_history.pop_front();
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_settings::{AgentProfileId, AssistantSettings, CompletionMode};
|
||||
use assistant_settings::{AssistantSettings, CompletionMode};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
@@ -22,7 +22,7 @@ use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
|
||||
StopReason, TokenUsage,
|
||||
};
|
||||
@@ -36,7 +36,7 @@ use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use thiserror::Error;
|
||||
use ui::Window;
|
||||
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
||||
use util::{ResultExt as _, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::CompletionRequestStatus;
|
||||
|
||||
@@ -324,7 +324,7 @@ pub enum QueueState {
|
||||
pub struct Thread {
|
||||
id: ThreadId,
|
||||
updated_at: DateTime<Utc>,
|
||||
summary: Option<SharedString>,
|
||||
summary: ThreadSummary,
|
||||
pending_summary: Task<Option<()>>,
|
||||
detailed_summary_task: Task<Option<()>>,
|
||||
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
|
||||
@@ -359,7 +359,33 @@ pub struct Thread {
|
||||
>,
|
||||
remaining_turns: u32,
|
||||
configured_model: Option<ConfiguredModel>,
|
||||
configured_profile_id: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ThreadSummary {
|
||||
Pending,
|
||||
Generating,
|
||||
Ready(SharedString),
|
||||
Error,
|
||||
}
|
||||
|
||||
impl ThreadSummary {
|
||||
pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
|
||||
|
||||
pub fn or_default(&self) -> SharedString {
|
||||
self.unwrap_or(Self::DEFAULT)
|
||||
}
|
||||
|
||||
pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
|
||||
self.ready().unwrap_or_else(|| message.into())
|
||||
}
|
||||
|
||||
pub fn ready(&self) -> Option<SharedString> {
|
||||
match self {
|
||||
ThreadSummary::Ready(summary) => Some(summary.clone()),
|
||||
ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -380,13 +406,11 @@ impl Thread {
|
||||
) -> Self {
|
||||
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
|
||||
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
let configured_profile_id = assistant_settings.default_profile.clone();
|
||||
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
updated_at: Utc::now(),
|
||||
summary: None,
|
||||
summary: ThreadSummary::Pending,
|
||||
pending_summary: Task::ready(None),
|
||||
detailed_summary_task: Task::ready(None),
|
||||
detailed_summary_tx,
|
||||
@@ -424,7 +448,6 @@ impl Thread {
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
configured_profile_id: Some(configured_profile_id),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,7 +458,7 @@ impl Thread {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
project_context: SharedProjectContext,
|
||||
window: &mut Window,
|
||||
window: Option<&mut Window>, // None in headless mode
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let next_message_id = MessageId(
|
||||
@@ -472,12 +495,10 @@ impl Thread {
|
||||
.completion_mode
|
||||
.unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);
|
||||
|
||||
let configured_profile_id = serialized.profile.clone();
|
||||
|
||||
Self {
|
||||
id,
|
||||
updated_at: serialized.updated_at,
|
||||
summary: Some(serialized.summary),
|
||||
summary: ThreadSummary::Ready(serialized.summary),
|
||||
pending_summary: Task::ready(None),
|
||||
detailed_summary_task: Task::ready(None),
|
||||
detailed_summary_tx,
|
||||
@@ -547,7 +568,6 @@ impl Thread {
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
configured_profile_id,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -579,10 +599,6 @@ impl Thread {
|
||||
self.last_prompt_id = PromptId::new();
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> Option<SharedString> {
|
||||
self.summary.clone()
|
||||
}
|
||||
|
||||
pub fn project_context(&self) -> SharedProjectContext {
|
||||
self.project_context.clone()
|
||||
}
|
||||
@@ -603,39 +619,25 @@ impl Thread {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn configured_profile_id(&self) -> Option<AgentProfileId> {
|
||||
self.configured_profile_id.clone()
|
||||
}
|
||||
|
||||
pub fn set_configured_profile_id(
|
||||
&mut self,
|
||||
id: Option<AgentProfileId>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.configured_profile_id = id;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
|
||||
pub fn summary(&self) -> &ThreadSummary {
|
||||
&self.summary
|
||||
}
|
||||
|
||||
pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
|
||||
let Some(current_summary) = &self.summary else {
|
||||
// Don't allow setting summary until generated
|
||||
return;
|
||||
let current_summary = match &self.summary {
|
||||
ThreadSummary::Pending | ThreadSummary::Generating => return,
|
||||
ThreadSummary::Ready(summary) => summary,
|
||||
ThreadSummary::Error => &ThreadSummary::DEFAULT,
|
||||
};
|
||||
|
||||
let mut new_summary = new_summary.into();
|
||||
|
||||
if new_summary.is_empty() {
|
||||
new_summary = Self::DEFAULT_SUMMARY;
|
||||
new_summary = ThreadSummary::DEFAULT;
|
||||
}
|
||||
|
||||
if current_summary != &new_summary {
|
||||
self.summary = Some(new_summary);
|
||||
self.summary = ThreadSummary::Ready(new_summary);
|
||||
cx.emit(ThreadEvent::SummaryChanged);
|
||||
}
|
||||
}
|
||||
@@ -878,7 +880,13 @@ impl Thread {
|
||||
}
|
||||
|
||||
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
|
||||
Some(&self.tool_use.tool_result(id)?.content)
|
||||
match &self.tool_use.tool_result(id)?.content {
|
||||
LanguageModelToolResultContent::Text(str) => Some(str),
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
// TODO: We should display image
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
|
||||
@@ -1049,7 +1057,7 @@ impl Thread {
|
||||
let initial_project_snapshot = initial_project_snapshot.await;
|
||||
this.read_with(cx, |this, cx| SerializedThread {
|
||||
version: SerializedThread::VERSION.to_string(),
|
||||
summary: this.summary_or_default(),
|
||||
summary: this.summary().or_default(),
|
||||
updated_at: this.updated_at(),
|
||||
messages: this
|
||||
.messages()
|
||||
@@ -1120,7 +1128,6 @@ impl Thread {
|
||||
provider: model.provider.id().0.to_string(),
|
||||
model: model.model.id().0.to_string(),
|
||||
}),
|
||||
profile: this.configured_profile_id.clone(),
|
||||
completion_mode: Some(this.completion_mode),
|
||||
})
|
||||
})
|
||||
@@ -1646,7 +1653,7 @@ impl Thread {
|
||||
|
||||
// If there is a response without tool use, summarize the message. Otherwise,
|
||||
// allow two tool uses before summarizing.
|
||||
if thread.summary.is_none()
|
||||
if matches!(thread.summary, ThreadSummary::Pending)
|
||||
&& thread.messages.len() >= 2
|
||||
&& (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
|
||||
{
|
||||
@@ -1681,10 +1688,6 @@ impl Thread {
|
||||
|
||||
if error.is::<PaymentRequiredError>() {
|
||||
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
|
||||
} else if error.is::<MaxMonthlySpendReachedError>() {
|
||||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::MaxMonthlySpendReached,
|
||||
));
|
||||
} else if let Some(error) =
|
||||
error.downcast_ref::<ModelRequestLimitReachedError>()
|
||||
{
|
||||
@@ -1760,6 +1763,7 @@ impl Thread {
|
||||
|
||||
pub fn summarize(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
|
||||
println!("No thread summary model");
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -1774,13 +1778,17 @@ impl Thread {
|
||||
|
||||
let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
|
||||
|
||||
self.summary = ThreadSummary::Generating;
|
||||
|
||||
self.pending_summary = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
let result = async {
|
||||
let mut messages = model.model.stream_completion(request, &cx).await?;
|
||||
|
||||
let mut new_summary = String::new();
|
||||
while let Some(event) = messages.next().await {
|
||||
let event = event?;
|
||||
let Ok(event) = event else {
|
||||
continue;
|
||||
};
|
||||
let text = match event {
|
||||
LanguageModelCompletionEvent::Text(text) => text,
|
||||
LanguageModelCompletionEvent::StatusUpdate(
|
||||
@@ -1806,18 +1814,29 @@ impl Thread {
|
||||
}
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
if !new_summary.is_empty() {
|
||||
this.summary = Some(new_summary.into());
|
||||
}
|
||||
|
||||
cx.emit(ThreadEvent::SummaryGenerated);
|
||||
})?;
|
||||
|
||||
anyhow::Ok(())
|
||||
anyhow::Ok(new_summary)
|
||||
}
|
||||
.log_err()
|
||||
.await
|
||||
.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
match result {
|
||||
Ok(new_summary) => {
|
||||
if new_summary.is_empty() {
|
||||
this.summary = ThreadSummary::Error;
|
||||
} else {
|
||||
this.summary = ThreadSummary::Ready(new_summary.into());
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
this.summary = ThreadSummary::Error;
|
||||
log::error!("Failed to generate thread summary: {}", err);
|
||||
}
|
||||
}
|
||||
cx.emit(ThreadEvent::SummaryGenerated);
|
||||
})
|
||||
.log_err()?;
|
||||
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2224,7 +2243,7 @@ impl Thread {
|
||||
.read(cx)
|
||||
.enabled_tools(cx)
|
||||
.iter()
|
||||
.map(|tool| tool.name().to_string())
|
||||
.map(|tool| tool.name())
|
||||
.collect();
|
||||
|
||||
self.message_feedback.insert(message_id, feedback);
|
||||
@@ -2427,9 +2446,8 @@ impl Thread {
|
||||
pub fn to_markdown(&self, cx: &App) -> Result<String> {
|
||||
let mut markdown = Vec::new();
|
||||
|
||||
if let Some(summary) = self.summary() {
|
||||
writeln!(markdown, "# {summary}\n")?;
|
||||
};
|
||||
let summary = self.summary().or_default();
|
||||
writeln!(markdown, "# {summary}\n")?;
|
||||
|
||||
for message in self.messages() {
|
||||
writeln!(
|
||||
@@ -2486,7 +2504,15 @@ impl Thread {
|
||||
}
|
||||
|
||||
writeln!(markdown, "**\n")?;
|
||||
writeln!(markdown, "{}", tool_result.content)?;
|
||||
match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(str) => {
|
||||
writeln!(markdown, "{}", str)?;
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
writeln!(markdown, "", image.source)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(output) = tool_result.output.as_ref() {
|
||||
writeln!(
|
||||
markdown,
|
||||
@@ -2557,7 +2583,7 @@ impl Thread {
|
||||
.read(cx)
|
||||
.current_user()
|
||||
.map(|user| user.github_login.clone());
|
||||
let client = self.project.read(cx).client().clone();
|
||||
let client = self.project.read(cx).client();
|
||||
let serialize_task = self.serialize(cx);
|
||||
|
||||
cx.background_executor()
|
||||
@@ -2676,8 +2702,6 @@ impl Thread {
|
||||
pub enum ThreadError {
|
||||
#[error("Payment required")]
|
||||
PaymentRequired,
|
||||
#[error("Max monthly spend reached")]
|
||||
MaxMonthlySpendReached,
|
||||
#[error("Model request limit reached")]
|
||||
ModelRequestLimitReached { plan: Plan },
|
||||
#[error("Message {header}: {message}")]
|
||||
@@ -2746,7 +2770,7 @@ mod tests {
|
||||
use assistant_tool::ToolRegistry;
|
||||
use editor::EditorSettings;
|
||||
use gpui::TestAppContext;
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
|
||||
use project::{FakeFs, Project};
|
||||
use prompt_store::PromptBuilder;
|
||||
use serde_json::json;
|
||||
@@ -3247,6 +3271,196 @@ fn main() {{
|
||||
assert_eq!(request.temperature, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_thread_summary(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(cx, json!({})).await;
|
||||
|
||||
let (_, _thread_store, thread, _context_store, model) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Initial state should be pending
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Pending));
|
||||
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
|
||||
});
|
||||
|
||||
// Manually setting the summary should not be allowed in this state
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_summary("This should not work", cx);
|
||||
});
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Pending));
|
||||
});
|
||||
|
||||
// Send a message
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
|
||||
thread.send_to_model(model.clone(), None, cx);
|
||||
});
|
||||
|
||||
let fake_model = model.as_fake();
|
||||
simulate_successful_response(&fake_model, cx);
|
||||
|
||||
// Should start generating summary when there are >= 2 messages
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(*thread.summary(), ThreadSummary::Generating);
|
||||
});
|
||||
|
||||
// Should not be able to set the summary while generating
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_summary("This should not work either", cx);
|
||||
});
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Generating));
|
||||
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Brief".into());
|
||||
fake_model.stream_last_completion_response(" Introduction".into());
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Summary should be set
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
|
||||
assert_eq!(thread.summary().or_default(), "Brief Introduction");
|
||||
});
|
||||
|
||||
// Now we should be able to set a summary
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_summary("Brief Intro", cx);
|
||||
});
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.summary().or_default(), "Brief Intro");
|
||||
});
|
||||
|
||||
// Test setting an empty summary (should default to DEFAULT)
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_summary("", cx);
|
||||
});
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
|
||||
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(cx, json!({})).await;
|
||||
|
||||
let (_, _thread_store, thread, _context_store, model) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
test_summarize_error(&model, &thread, cx);
|
||||
|
||||
// Now we should be able to set a summary
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_summary("Brief Intro", cx);
|
||||
});
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
|
||||
assert_eq!(thread.summary().or_default(), "Brief Intro");
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(cx, json!({})).await;
|
||||
|
||||
let (_, _thread_store, thread, _context_store, model) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
test_summarize_error(&model, &thread, cx);
|
||||
|
||||
// Sending another message should not trigger another summarize request
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
"How are you?",
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
vec![],
|
||||
cx,
|
||||
);
|
||||
thread.send_to_model(model.clone(), None, cx);
|
||||
});
|
||||
|
||||
let fake_model = model.as_fake();
|
||||
simulate_successful_response(&fake_model, cx);
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
// State is still Error, not Generating
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Error));
|
||||
});
|
||||
|
||||
// But the summarize request can be invoked manually
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.summarize(cx);
|
||||
});
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Generating));
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("A successful summary".into());
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
|
||||
assert_eq!(thread.summary().or_default(), "A successful summary");
|
||||
});
|
||||
}
|
||||
|
||||
fn test_summarize_error(
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
thread: &Entity<Thread>,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
|
||||
thread.send_to_model(model.clone(), None, cx);
|
||||
});
|
||||
|
||||
let fake_model = model.as_fake();
|
||||
simulate_successful_response(&fake_model, cx);
|
||||
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Generating));
|
||||
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
|
||||
});
|
||||
|
||||
// Simulate summary request ending
|
||||
cx.run_until_parked();
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
// State is set to Error and default message
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert!(matches!(thread.summary(), ThreadSummary::Error));
|
||||
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
|
||||
});
|
||||
}
|
||||
|
||||
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Assistant response".into());
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
@@ -3303,9 +3517,29 @@ fn main() {{
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||
|
||||
let model = FakeLanguageModel::default();
|
||||
let provider = Arc::new(FakeLanguageModelProvider);
|
||||
let model = provider.test_model();
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(model);
|
||||
|
||||
cx.update(|_, cx| {
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: provider.clone(),
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
registry.set_thread_summary_model(
|
||||
Some(ConfiguredModel {
|
||||
provider,
|
||||
model: model.clone(),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
});
|
||||
|
||||
(workspace, thread_store, thread, context_store, model)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,11 @@ use std::fmt::Display;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_context_editor::SavedContextMetadata;
|
||||
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
|
||||
use editor::{Editor, EditorEvent};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{
|
||||
App, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
|
||||
App, ClickEvent, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
|
||||
UniformListScrollHandle, WeakEntity, Window, uniform_list,
|
||||
};
|
||||
use time::{OffsetDateTime, UtcOffset};
|
||||
@@ -18,7 +17,6 @@ use ui::{
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::history_store::{HistoryEntry, HistoryStore};
|
||||
use crate::thread_store::SerializedThreadMetadata;
|
||||
use crate::{AgentPanel, RemoveSelectedThread};
|
||||
|
||||
pub struct ThreadHistory {
|
||||
@@ -26,11 +24,12 @@ pub struct ThreadHistory {
|
||||
history_store: Entity<HistoryStore>,
|
||||
scroll_handle: UniformListScrollHandle,
|
||||
selected_index: usize,
|
||||
hovered_index: Option<usize>,
|
||||
search_editor: Entity<Editor>,
|
||||
all_entries: Arc<Vec<HistoryEntry>>,
|
||||
// When the search is empty, we display date separators between history entries
|
||||
// This vector contains an enum of either a separator or an actual entry
|
||||
separated_items: Vec<HistoryListItem>,
|
||||
separated_items: Vec<ListItemType>,
|
||||
// Maps entry indexes to list item indexes
|
||||
separated_item_indexes: Vec<u32>,
|
||||
_separated_items_task: Option<Task<()>>,
|
||||
@@ -52,7 +51,7 @@ enum SearchState {
|
||||
},
|
||||
}
|
||||
|
||||
enum HistoryListItem {
|
||||
enum ListItemType {
|
||||
BucketSeparator(TimeBucket),
|
||||
Entry {
|
||||
index: usize,
|
||||
@@ -60,11 +59,11 @@ enum HistoryListItem {
|
||||
},
|
||||
}
|
||||
|
||||
impl HistoryListItem {
|
||||
impl ListItemType {
|
||||
fn entry_index(&self) -> Option<usize> {
|
||||
match self {
|
||||
HistoryListItem::BucketSeparator(_) => None,
|
||||
HistoryListItem::Entry { index, .. } => Some(*index),
|
||||
ListItemType::BucketSeparator(_) => None,
|
||||
ListItemType::Entry { index, .. } => Some(*index),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -102,6 +101,7 @@ impl ThreadHistory {
|
||||
history_store,
|
||||
scroll_handle,
|
||||
selected_index: 0,
|
||||
hovered_index: None,
|
||||
search_state: SearchState::Empty,
|
||||
all_entries: Default::default(),
|
||||
separated_items: Default::default(),
|
||||
@@ -117,40 +117,21 @@ impl ThreadHistory {
|
||||
}
|
||||
|
||||
fn update_all_entries(&mut self, cx: &mut Context<Self>) {
|
||||
self.all_entries = self
|
||||
let new_entries: Arc<Vec<HistoryEntry>> = self
|
||||
.history_store
|
||||
.update(cx, |store, cx| store.entries(cx))
|
||||
.into();
|
||||
|
||||
self.set_selected_entry_index(0, cx);
|
||||
self.update_separated_items(cx);
|
||||
|
||||
match &self.search_state {
|
||||
SearchState::Empty => {}
|
||||
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
|
||||
self.search(query.clone(), cx);
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn update_separated_items(&mut self, cx: &mut Context<Self>) {
|
||||
self._separated_items_task.take();
|
||||
let all_entries = self.all_entries.clone();
|
||||
|
||||
let mut items = std::mem::take(&mut self.separated_items);
|
||||
let mut indexes = std::mem::take(&mut self.separated_item_indexes);
|
||||
items.clear();
|
||||
indexes.clear();
|
||||
// We know there's going to be at least one bucket separator
|
||||
items.reserve(all_entries.len() + 1);
|
||||
indexes.reserve(all_entries.len() + 1);
|
||||
let mut items = Vec::with_capacity(new_entries.len() + 1);
|
||||
let mut indexes = Vec::with_capacity(new_entries.len() + 1);
|
||||
|
||||
let bg_task = cx.background_spawn(async move {
|
||||
let mut bucket = None;
|
||||
let today = Local::now().naive_local().date();
|
||||
|
||||
for (index, entry) in all_entries.iter().enumerate() {
|
||||
for (index, entry) in new_entries.iter().enumerate() {
|
||||
let entry_date = entry
|
||||
.updated_at()
|
||||
.with_timezone(&Local)
|
||||
@@ -160,23 +141,50 @@ impl ThreadHistory {
|
||||
|
||||
if Some(entry_bucket) != bucket {
|
||||
bucket = Some(entry_bucket);
|
||||
items.push(HistoryListItem::BucketSeparator(entry_bucket));
|
||||
items.push(ListItemType::BucketSeparator(entry_bucket));
|
||||
}
|
||||
|
||||
indexes.push(items.len() as u32);
|
||||
items.push(HistoryListItem::Entry {
|
||||
items.push(ListItemType::Entry {
|
||||
index,
|
||||
format: entry_bucket.into(),
|
||||
});
|
||||
}
|
||||
(items, indexes)
|
||||
(new_entries, items, indexes)
|
||||
});
|
||||
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
let (items, indexes) = bg_task.await;
|
||||
let (new_entries, items, indexes) = bg_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
let previously_selected_entry =
|
||||
this.all_entries.get(this.selected_index).map(|e| e.id());
|
||||
|
||||
this.all_entries = new_entries;
|
||||
this.separated_items = items;
|
||||
this.separated_item_indexes = indexes;
|
||||
|
||||
match &this.search_state {
|
||||
SearchState::Empty => {
|
||||
if this.selected_index >= this.all_entries.len() {
|
||||
this.set_selected_entry_index(
|
||||
this.all_entries.len().saturating_sub(1),
|
||||
cx,
|
||||
);
|
||||
} else if let Some(prev_id) = previously_selected_entry {
|
||||
if let Some(new_ix) = this
|
||||
.all_entries
|
||||
.iter()
|
||||
.position(|probe| probe.id() == prev_id)
|
||||
{
|
||||
this.set_selected_entry_index(new_ix, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
|
||||
this.search(query.clone(), cx);
|
||||
}
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
@@ -252,10 +260,7 @@ impl ThreadHistory {
|
||||
}
|
||||
});
|
||||
|
||||
self.search_state = SearchState::Searching {
|
||||
query: query.clone(),
|
||||
_task: task,
|
||||
};
|
||||
self.search_state = SearchState::Searching { query, _task: task };
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
@@ -467,7 +472,7 @@ impl ThreadHistory {
|
||||
.map(|(ix, m)| {
|
||||
self.render_list_item(
|
||||
Some(range_start + ix),
|
||||
&HistoryListItem::Entry {
|
||||
&ListItemType::Entry {
|
||||
index: m.candidate_id,
|
||||
format: EntryTimeFormat::DateAndTime,
|
||||
},
|
||||
@@ -485,25 +490,36 @@ impl ThreadHistory {
|
||||
fn render_list_item(
|
||||
&self,
|
||||
list_entry_ix: Option<usize>,
|
||||
item: &HistoryListItem,
|
||||
item: &ListItemType,
|
||||
highlight_positions: Vec<usize>,
|
||||
cx: &App,
|
||||
cx: &Context<Self>,
|
||||
) -> AnyElement {
|
||||
match item {
|
||||
HistoryListItem::Entry { index, format } => match self.all_entries.get(*index) {
|
||||
ListItemType::Entry { index, format } => match self.all_entries.get(*index) {
|
||||
Some(entry) => h_flex()
|
||||
.w_full()
|
||||
.pb_1()
|
||||
.child(self.render_history_entry(
|
||||
entry,
|
||||
list_entry_ix == Some(self.selected_index),
|
||||
highlight_positions,
|
||||
*format,
|
||||
))
|
||||
.child(
|
||||
HistoryEntryElement::new(entry.clone(), self.agent_panel.clone())
|
||||
.highlight_positions(highlight_positions)
|
||||
.timestamp_format(*format)
|
||||
.selected(list_entry_ix == Some(self.selected_index))
|
||||
.hovered(list_entry_ix == self.hovered_index)
|
||||
.on_hover(cx.listener(move |this, is_hovered, _window, cx| {
|
||||
if *is_hovered {
|
||||
this.hovered_index = list_entry_ix;
|
||||
} else if this.hovered_index == list_entry_ix {
|
||||
this.hovered_index = None;
|
||||
}
|
||||
|
||||
cx.notify();
|
||||
}))
|
||||
.into_any_element(),
|
||||
)
|
||||
.into_any(),
|
||||
None => Empty.into_any_element(),
|
||||
},
|
||||
HistoryListItem::BucketSeparator(bucket) => div()
|
||||
ListItemType::BucketSeparator(bucket) => div()
|
||||
.px(DynamicSpacing::Base06.rems(cx))
|
||||
.pt_2()
|
||||
.pb_1()
|
||||
@@ -515,33 +531,6 @@ impl ThreadHistory {
|
||||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_history_entry(
|
||||
&self,
|
||||
entry: &HistoryEntry,
|
||||
is_active: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
format: EntryTimeFormat,
|
||||
) -> AnyElement {
|
||||
match entry {
|
||||
HistoryEntry::Thread(thread) => PastThread::new(
|
||||
thread.clone(),
|
||||
self.agent_panel.clone(),
|
||||
is_active,
|
||||
highlight_positions,
|
||||
format,
|
||||
)
|
||||
.into_any_element(),
|
||||
HistoryEntry::Context(context) => PastContext::new(
|
||||
context.clone(),
|
||||
self.agent_panel.clone(),
|
||||
is_active,
|
||||
highlight_positions,
|
||||
format,
|
||||
)
|
||||
.into_any_element(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for ThreadHistory {
|
||||
@@ -623,155 +612,97 @@ impl Render for ThreadHistory {
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct PastThread {
|
||||
thread: SerializedThreadMetadata,
|
||||
pub struct HistoryEntryElement {
|
||||
entry: HistoryEntry,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
hovered: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
on_hover: Box<dyn Fn(&bool, &mut Window, &mut App) + 'static>,
|
||||
}
|
||||
|
||||
impl PastThread {
|
||||
pub fn new(
|
||||
thread: SerializedThreadMetadata,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
) -> Self {
|
||||
impl HistoryEntryElement {
|
||||
pub fn new(entry: HistoryEntry, agent_panel: WeakEntity<AgentPanel>) -> Self {
|
||||
Self {
|
||||
thread,
|
||||
entry,
|
||||
agent_panel,
|
||||
selected,
|
||||
highlight_positions,
|
||||
timestamp_format,
|
||||
selected: false,
|
||||
hovered: false,
|
||||
highlight_positions: vec![],
|
||||
timestamp_format: EntryTimeFormat::DateAndTime,
|
||||
on_hover: Box::new(|_, _, _| {}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn selected(mut self, selected: bool) -> Self {
|
||||
self.selected = selected;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn hovered(mut self, hovered: bool) -> Self {
|
||||
self.hovered = hovered;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn highlight_positions(mut self, positions: Vec<usize>) -> Self {
|
||||
self.highlight_positions = positions;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn on_hover(mut self, on_hover: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self {
|
||||
self.on_hover = Box::new(on_hover);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn timestamp_format(mut self, format: EntryTimeFormat) -> Self {
|
||||
self.timestamp_format = format;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for PastThread {
|
||||
impl RenderOnce for HistoryEntryElement {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let summary = self.thread.summary;
|
||||
let (id, summary, timestamp) = match &self.entry {
|
||||
HistoryEntry::Thread(thread) => (
|
||||
thread.id.to_string(),
|
||||
thread.summary.clone(),
|
||||
thread.updated_at.timestamp(),
|
||||
),
|
||||
HistoryEntry::Context(context) => (
|
||||
context.path.to_string_lossy().to_string(),
|
||||
context.title.clone().into(),
|
||||
context.mtime.timestamp(),
|
||||
),
|
||||
};
|
||||
|
||||
let thread_timestamp = self.timestamp_format.format_timestamp(
|
||||
&self.agent_panel,
|
||||
self.thread.updated_at.timestamp(),
|
||||
cx,
|
||||
);
|
||||
let thread_timestamp =
|
||||
self.timestamp_format
|
||||
.format_timestamp(&self.agent_panel, timestamp, cx);
|
||||
|
||||
ListItem::new(SharedString::from(self.thread.id.to_string()))
|
||||
ListItem::new(SharedString::from(id))
|
||||
.rounded()
|
||||
.toggle_state(self.selected)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.start_slot(
|
||||
div().max_w_4_5().child(
|
||||
HighlightedLabel::new(summary, self.highlight_positions)
|
||||
.size(LabelSize::Small)
|
||||
.truncate(),
|
||||
),
|
||||
)
|
||||
.end_slot(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.w_full()
|
||||
.gap_2()
|
||||
.justify_between()
|
||||
.child(
|
||||
HighlightedLabel::new(summary, self.highlight_positions)
|
||||
.size(LabelSize::Small)
|
||||
.truncate(),
|
||||
)
|
||||
.child(
|
||||
Label::new(thread_timestamp)
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::XSmall),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("delete", IconName::TrashAlt)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
|
||||
})
|
||||
.on_click({
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let id = self.thread.id.clone();
|
||||
move |_event, _window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.delete_thread(&id, cx).detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}),
|
||||
),
|
||||
)
|
||||
.on_click({
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let id = self.thread.id.clone();
|
||||
move |_event, window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.open_thread_by_id(&id, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct PastContext {
|
||||
context: SavedContextMetadata,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
}
|
||||
|
||||
impl PastContext {
|
||||
pub fn new(
|
||||
context: SavedContextMetadata,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
) -> Self {
|
||||
Self {
|
||||
context,
|
||||
agent_panel,
|
||||
selected,
|
||||
highlight_positions,
|
||||
timestamp_format,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for PastContext {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let summary = self.context.title;
|
||||
let context_timestamp = self.timestamp_format.format_timestamp(
|
||||
&self.agent_panel,
|
||||
self.context.mtime.timestamp(),
|
||||
cx,
|
||||
);
|
||||
|
||||
ListItem::new(SharedString::from(
|
||||
self.context.path.to_string_lossy().to_string(),
|
||||
))
|
||||
.rounded()
|
||||
.toggle_state(self.selected)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.start_slot(
|
||||
div().max_w_4_5().child(
|
||||
HighlightedLabel::new(summary, self.highlight_positions)
|
||||
.size(LabelSize::Small)
|
||||
.truncate(),
|
||||
),
|
||||
)
|
||||
.end_slot(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Label::new(context_timestamp)
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::XSmall),
|
||||
)
|
||||
.child(
|
||||
.on_hover(self.on_hover)
|
||||
.end_slot::<IconButton>(if self.hovered || self.selected {
|
||||
Some(
|
||||
IconButton::new("delete", IconName::TrashAlt)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
@@ -781,30 +712,70 @@ impl RenderOnce for PastContext {
|
||||
})
|
||||
.on_click({
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let path = self.context.path.clone();
|
||||
move |_event, _window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.delete_context(path.clone(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
let f: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static> =
|
||||
match &self.entry {
|
||||
HistoryEntry::Thread(thread) => {
|
||||
let id = thread.id.clone();
|
||||
|
||||
Box::new(move |_event, _window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.delete_thread(&id, cx)
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
HistoryEntry::Context(context) => {
|
||||
let path = context.path.clone();
|
||||
|
||||
Box::new(move |_event, _window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.delete_context(path.clone(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
};
|
||||
f
|
||||
}),
|
||||
),
|
||||
)
|
||||
.on_click({
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let path = self.context.path.clone();
|
||||
move |_event, window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.open_saved_prompt_editor(path.clone(), window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
.on_click({
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
|
||||
let f: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static> = match &self.entry
|
||||
{
|
||||
HistoryEntry::Thread(thread) => {
|
||||
let id = thread.id.clone();
|
||||
Box::new(move |_event, window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.open_thread_by_id(&id, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
HistoryEntry::Context(context) => {
|
||||
let path = context.path.clone();
|
||||
Box::new(move |_event, window, cx| {
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.open_saved_prompt_editor(path.clone(), window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
};
|
||||
f
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ use gpui::{
|
||||
};
|
||||
use heed::Database;
|
||||
use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
@@ -386,6 +386,25 @@ impl ThreadStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_thread_from_serialized(
|
||||
&mut self,
|
||||
serialized: SerializedThread,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Entity<Thread> {
|
||||
cx.new(|cx| {
|
||||
Thread::deserialize(
|
||||
ThreadId::new(),
|
||||
serialized,
|
||||
self.project.clone(),
|
||||
self.tools.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
self.project_context.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_thread(
|
||||
&self,
|
||||
id: &ThreadId,
|
||||
@@ -411,7 +430,7 @@ impl ThreadStore {
|
||||
this.tools.clone(),
|
||||
this.prompt_builder.clone(),
|
||||
this.project_context.clone(),
|
||||
window,
|
||||
Some(window),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -486,8 +505,8 @@ impl ThreadStore {
|
||||
ToolSource::Native,
|
||||
&profile
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
);
|
||||
@@ -511,32 +530,32 @@ impl ThreadStore {
|
||||
});
|
||||
}
|
||||
// Enable all the tools from all context servers, but disable the ones that are explicitly disabled
|
||||
for (context_server_id, preset) in &profile.context_servers {
|
||||
for (context_server_id, preset) in profile.context_servers {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.disable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.clone().into(),
|
||||
id: context_server_id.into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| (!enabled).then(|| tool))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
for (context_server_id, preset) in &profile.context_servers {
|
||||
for (context_server_id, preset) in profile.context_servers {
|
||||
self.tools.update(cx, |tools, cx| {
|
||||
tools.enable(
|
||||
ToolSource::ContextServer {
|
||||
id: context_server_id.clone().into(),
|
||||
id: context_server_id.into(),
|
||||
},
|
||||
&preset
|
||||
.tools
|
||||
.iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
|
||||
.into_iter()
|
||||
.filter_map(|(tool, enabled)| enabled.then(|| tool))
|
||||
.collect::<Vec<_>>(),
|
||||
cx,
|
||||
)
|
||||
@@ -657,8 +676,6 @@ pub struct SerializedThread {
|
||||
pub model: Option<SerializedLanguageModel>,
|
||||
#[serde(default)]
|
||||
pub completion_mode: Option<CompletionMode>,
|
||||
#[serde(default)]
|
||||
pub profile: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
@@ -777,7 +794,7 @@ pub struct SerializedToolUse {
|
||||
pub struct SerializedToolResult {
|
||||
pub tool_use_id: LanguageModelToolUseId,
|
||||
pub is_error: bool,
|
||||
pub content: Arc<str>,
|
||||
pub content: LanguageModelToolResultContent,
|
||||
pub output: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
@@ -804,7 +821,6 @@ impl LegacySerializedThread {
|
||||
exceeded_window_error: None,
|
||||
model: None,
|
||||
completion_mode: None,
|
||||
profile: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
|
||||
use assistant_tool::{
|
||||
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::Shared;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role,
|
||||
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
|
||||
};
|
||||
use project::Project;
|
||||
use ui::{IconName, Window};
|
||||
@@ -52,15 +54,19 @@ impl ToolUseState {
|
||||
/// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
|
||||
///
|
||||
/// Accepts a function to filter the tools that should be used to populate the state.
|
||||
///
|
||||
/// If `window` is `None` (e.g., when in headless mode or when running evals),
|
||||
/// tool cards won't be deserialized
|
||||
pub fn from_serialized_messages(
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
messages: &[SerializedMessage],
|
||||
project: Entity<Project>,
|
||||
window: &mut Window,
|
||||
window: Option<&mut Window>, // None in headless mode
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let mut this = Self::new(tools);
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
let mut window = window;
|
||||
|
||||
for message in messages {
|
||||
match message.role {
|
||||
@@ -105,12 +111,17 @@ impl ToolUseState {
|
||||
},
|
||||
);
|
||||
|
||||
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
|
||||
if let Some(output) = tool_result.output.clone() {
|
||||
if let Some(card) =
|
||||
tool.deserialize_card(output, project.clone(), window, cx)
|
||||
{
|
||||
this.tool_result_cards.insert(tool_use_id, card);
|
||||
if let Some(window) = &mut window {
|
||||
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
|
||||
if let Some(output) = tool_result.output.clone() {
|
||||
if let Some(card) = tool.deserialize_card(
|
||||
output,
|
||||
project.clone(),
|
||||
window,
|
||||
cx,
|
||||
) {
|
||||
this.tool_result_cards.insert(tool_use_id, card);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -165,10 +176,16 @@ impl ToolUseState {
|
||||
|
||||
let status = (|| {
|
||||
if let Some(tool_result) = tool_result {
|
||||
let content = tool_result
|
||||
.content
|
||||
.to_str()
|
||||
.map(|str| str.to_owned().into())
|
||||
.unwrap_or_default();
|
||||
|
||||
return if tool_result.is_error {
|
||||
ToolUseStatus::Error(tool_result.content.clone().into())
|
||||
ToolUseStatus::Error(content)
|
||||
} else {
|
||||
ToolUseStatus::Finished(tool_result.content.clone().into())
|
||||
ToolUseStatus::Finished(content)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -399,21 +416,45 @@ impl ToolUseState {
|
||||
let tool_result = output.content;
|
||||
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
|
||||
|
||||
// Protect from clearly large output
|
||||
let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
|
||||
|
||||
// Protect from overly large output
|
||||
let tool_output_limit = configured_model
|
||||
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
let tool_result = if tool_result.len() <= tool_output_limit {
|
||||
tool_result
|
||||
} else {
|
||||
let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
|
||||
let content = match tool_result {
|
||||
ToolResultContent::Text(text) => {
|
||||
let text = if text.len() < tool_output_limit {
|
||||
text
|
||||
} else {
|
||||
let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
|
||||
format!(
|
||||
"Tool result too long. The first {} bytes:\n\n{}",
|
||||
truncated.len(),
|
||||
truncated
|
||||
)
|
||||
};
|
||||
LanguageModelToolResultContent::Text(text.into())
|
||||
}
|
||||
ToolResultContent::Image(language_model_image) => {
|
||||
if language_model_image.estimate_tokens() < tool_output_limit {
|
||||
LanguageModelToolResultContent::Image(language_model_image)
|
||||
} else {
|
||||
self.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name,
|
||||
content: "Tool responded with an image that would exceeded the remaining tokens".into(),
|
||||
is_error: true,
|
||||
output: None,
|
||||
},
|
||||
);
|
||||
|
||||
format!(
|
||||
"Tool result too long. The first {} bytes:\n\n{}",
|
||||
truncated.len(),
|
||||
truncated
|
||||
)
|
||||
return old_use;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.tool_results.insert(
|
||||
@@ -421,12 +462,13 @@ impl ToolUseState {
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name,
|
||||
content: tool_result.into(),
|
||||
content,
|
||||
is_error: false,
|
||||
output: output.output,
|
||||
},
|
||||
);
|
||||
self.pending_tool_uses_by_id.remove(&tool_use_id)
|
||||
|
||||
old_use
|
||||
}
|
||||
Err(err) => {
|
||||
self.tool_results.insert(
|
||||
@@ -434,7 +476,7 @@ impl ToolUseState {
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name,
|
||||
content: err.to_string().into(),
|
||||
content: LanguageModelToolResultContent::Text(err.to_string().into()),
|
||||
is_error: true,
|
||||
output: None,
|
||||
},
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use collections::HashMap;
|
||||
use component::ComponentId;
|
||||
use gpui::{App, Entity, WeakEntity};
|
||||
use linkme::distributed_slice;
|
||||
use std::sync::OnceLock;
|
||||
use ui::{AnyElement, Component, ComponentScope, Window};
|
||||
use workspace::Workspace;
|
||||
|
||||
@@ -12,9 +12,15 @@ use crate::ActiveThread;
|
||||
pub type PreviewFn =
|
||||
fn(WeakEntity<Workspace>, Entity<ActiveThread>, &mut Window, &mut App) -> Option<AnyElement>;
|
||||
|
||||
/// Distributed slice for preview registration functions
|
||||
#[distributed_slice]
|
||||
pub static __ALL_AGENT_PREVIEWS: [fn() -> (ComponentId, PreviewFn)] = [..];
|
||||
pub struct AgentPreviewFn(fn() -> (ComponentId, PreviewFn));
|
||||
|
||||
impl AgentPreviewFn {
|
||||
pub const fn new(f: fn() -> (ComponentId, PreviewFn)) -> Self {
|
||||
Self(f)
|
||||
}
|
||||
}
|
||||
|
||||
inventory::collect!(AgentPreviewFn);
|
||||
|
||||
/// Trait that must be implemented by components that provide agent previews.
|
||||
pub trait AgentPreview: Component + Sized {
|
||||
@@ -36,16 +42,14 @@ pub trait AgentPreview: Component + Sized {
|
||||
#[macro_export]
|
||||
macro_rules! register_agent_preview {
|
||||
($type:ty) => {
|
||||
#[linkme::distributed_slice($crate::ui::preview::__ALL_AGENT_PREVIEWS)]
|
||||
static __REGISTER_AGENT_PREVIEW: fn() -> (
|
||||
component::ComponentId,
|
||||
$crate::ui::preview::PreviewFn,
|
||||
) = || {
|
||||
(
|
||||
<$type as component::Component>::id(),
|
||||
<$type as $crate::ui::preview::AgentPreview>::agent_preview,
|
||||
)
|
||||
};
|
||||
inventory::submit! {
|
||||
$crate::ui::preview::AgentPreviewFn::new(|| {
|
||||
(
|
||||
<$type as component::Component>::id(),
|
||||
<$type as $crate::ui::preview::AgentPreview>::agent_preview,
|
||||
)
|
||||
})
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -56,8 +60,8 @@ static AGENT_PREVIEW_REGISTRY: OnceLock<HashMap<ComponentId, PreviewFn>> = OnceL
|
||||
fn get_or_init_registry() -> &'static HashMap<ComponentId, PreviewFn> {
|
||||
AGENT_PREVIEW_REGISTRY.get_or_init(|| {
|
||||
let mut map = HashMap::default();
|
||||
for register_fn in __ALL_AGENT_PREVIEWS.iter() {
|
||||
let (id, preview_fn) = register_fn();
|
||||
for register_fn in inventory::iter::<AgentPreviewFn>() {
|
||||
let (id, preview_fn) = (register_fn.0)();
|
||||
map.insert(id, preview_fn);
|
||||
}
|
||||
map
|
||||
|
||||
@@ -39,7 +39,7 @@ impl RenderOnce for UsageCallout {
|
||||
|
||||
let (title, message, button_text, url) = if is_limit_reached {
|
||||
match self.plan {
|
||||
Plan::Free => (
|
||||
Plan::ZedFree => (
|
||||
"Out of free prompts",
|
||||
"Upgrade to continue, wait for the next reset, or switch to API key."
|
||||
.to_string(),
|
||||
@@ -61,7 +61,7 @@ impl RenderOnce for UsageCallout {
|
||||
}
|
||||
} else {
|
||||
match self.plan {
|
||||
Plan::Free => (
|
||||
Plan::ZedFree => (
|
||||
"Reaching free plan limit soon",
|
||||
format!(
|
||||
"{remaining} remaining - Upgrade to increase limit, or switch providers",
|
||||
@@ -120,7 +120,7 @@ impl Component for UsageCallout {
|
||||
single_example(
|
||||
"Approaching limit (90%)",
|
||||
UsageCallout::new(
|
||||
Plan::Free,
|
||||
Plan::ZedFree,
|
||||
RequestUsage {
|
||||
limit: UsageLimit::Limited(50),
|
||||
amount: 45, // 90% of limit
|
||||
@@ -131,7 +131,7 @@ impl Component for UsageCallout {
|
||||
single_example(
|
||||
"Limit reached (100%)",
|
||||
UsageCallout::new(
|
||||
Plan::Free,
|
||||
Plan::ZedFree,
|
||||
RequestUsage {
|
||||
limit: UsageLimit::Limited(50),
|
||||
amount: 50, // 100% of limit
|
||||
|
||||
@@ -534,12 +534,26 @@ pub enum RequestContent {
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
is_error: bool,
|
||||
content: String,
|
||||
content: ToolResultContent,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolResultContent {
|
||||
Plain(String),
|
||||
Multipart(Vec<ToolResultPart>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ToolResultPart {
|
||||
Text { text: String },
|
||||
Image { source: ImageSource },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ResponseContent {
|
||||
|
||||
@@ -163,20 +163,26 @@ impl AskPassSession {
|
||||
#[cfg(unix)]
|
||||
fn get_shell_safe_zed_path() -> anyhow::Result<String> {
|
||||
let zed_path = std::env::current_exe()
|
||||
.context("Failed to figure out current executable path for use in askpass")?
|
||||
.context("Failed to determine current executable path for use in askpass")?
|
||||
.to_string_lossy()
|
||||
// see https://github.com/rust-lang/rust/issues/69343
|
||||
.trim_end_matches(" (deleted)")
|
||||
.to_string();
|
||||
|
||||
// sanity check on unix systems that the path exists and is executable
|
||||
// todo(windows): implement this check for windows (or just use `is-executable` crate)
|
||||
use std::os::unix::fs::MetadataExt;
|
||||
let metadata = std::fs::metadata(&zed_path)
|
||||
.context("Failed to check metadata of Zed executable path for use in askpass")?;
|
||||
let is_executable = metadata.is_file() && metadata.mode() & 0o111 != 0;
|
||||
anyhow::ensure!(
|
||||
is_executable,
|
||||
"Failed to verify Zed executable path for use in askpass"
|
||||
);
|
||||
// NOTE: this was previously enabled, however, it caused errors when it shouldn't have
|
||||
// (see https://github.com/zed-industries/zed/issues/29819)
|
||||
// The zed path failing to execute within the askpass script results in very vague ssh
|
||||
// authentication failed errors, so this was done to try and surface a better error
|
||||
//
|
||||
// use std::os::unix::fs::MetadataExt;
|
||||
// let metadata = std::fs::metadata(&zed_path)
|
||||
// .context("Failed to check metadata of Zed executable path for use in askpass")?;
|
||||
// let is_executable = metadata.is_file() && metadata.mode() & 0o111 != 0;
|
||||
// anyhow::ensure!(
|
||||
// is_executable,
|
||||
// "Failed to verify Zed executable path for use in askpass"
|
||||
// );
|
||||
|
||||
// As of writing, this can only be fail if the path contains a null byte, which shouldn't be possible
|
||||
// but shlex has annotated the error as #[non_exhaustive] so we can't make it a compile error if other
|
||||
// errors are introduced in the future :(
|
||||
|
||||
@@ -8,7 +8,8 @@ mod slash_command_picker;
|
||||
use std::sync::Arc;
|
||||
|
||||
use client::Client;
|
||||
use gpui::App;
|
||||
use gpui::{App, Context};
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use crate::context::*;
|
||||
pub use crate::context_editor::*;
|
||||
@@ -16,6 +17,18 @@ pub use crate::context_history::*;
|
||||
pub use crate::context_store::*;
|
||||
pub use crate::slash_command::*;
|
||||
|
||||
pub fn init(client: Arc<Client>, _cx: &mut App) {
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
context_store::init(&client.into());
|
||||
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
|
||||
|
||||
cx.observe_new(
|
||||
|workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
|
||||
workspace
|
||||
.register_action(ContextEditor::quote_selection)
|
||||
.register_action(ContextEditor::insert_selection)
|
||||
.register_action(ContextEditor::copy_code)
|
||||
.register_action(ContextEditor::handle_insert_dragged_files);
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#[cfg(test)]
|
||||
mod context_tests;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_slash_command::{
|
||||
SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection,
|
||||
@@ -21,8 +21,8 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
|
||||
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||
Role, StopReason, report_assistant_event,
|
||||
LanguageModelToolUseId, MessageContent, PaymentRequiredError, Role, StopReason,
|
||||
report_assistant_event,
|
||||
};
|
||||
use open_ai::Model as OpenAiModel;
|
||||
use paths::contexts_dir;
|
||||
@@ -133,7 +133,7 @@ pub enum ContextOperation {
|
||||
version: clock::Global,
|
||||
},
|
||||
UpdateSummary {
|
||||
summary: ContextSummary,
|
||||
summary: ContextSummaryContent,
|
||||
version: clock::Global,
|
||||
},
|
||||
SlashCommandStarted {
|
||||
@@ -203,7 +203,7 @@ impl ContextOperation {
|
||||
version: language::proto::deserialize_version(&update.version),
|
||||
}),
|
||||
proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
|
||||
summary: ContextSummary {
|
||||
summary: ContextSummaryContent {
|
||||
text: update.summary,
|
||||
done: update.done,
|
||||
timestamp: language::proto::deserialize_timestamp(
|
||||
@@ -447,7 +447,6 @@ impl ContextOperation {
|
||||
pub enum ContextEvent {
|
||||
ShowAssistError(SharedString),
|
||||
ShowPaymentRequiredError,
|
||||
ShowMaxMonthlySpendReachedError,
|
||||
MessagesEdited,
|
||||
SummaryChanged,
|
||||
SummaryGenerated,
|
||||
@@ -467,11 +466,73 @@ pub enum ContextEvent {
|
||||
Operation(ContextOperation),
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct ContextSummary {
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum ContextSummary {
|
||||
Pending,
|
||||
Content(ContextSummaryContent),
|
||||
Error,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ContextSummaryContent {
|
||||
pub text: String,
|
||||
pub done: bool,
|
||||
timestamp: clock::Lamport,
|
||||
pub timestamp: clock::Lamport,
|
||||
}
|
||||
|
||||
impl ContextSummary {
|
||||
pub const DEFAULT: &str = "New Text Thread";
|
||||
|
||||
pub fn or_default(&self) -> SharedString {
|
||||
self.unwrap_or(Self::DEFAULT)
|
||||
}
|
||||
|
||||
pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
|
||||
self.content()
|
||||
.map_or_else(|| message.into(), |content| content.text.clone().into())
|
||||
}
|
||||
|
||||
pub fn content(&self) -> Option<&ContextSummaryContent> {
|
||||
match self {
|
||||
ContextSummary::Content(content) => Some(content),
|
||||
ContextSummary::Pending | ContextSummary::Error => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn content_as_mut(&mut self) -> Option<&mut ContextSummaryContent> {
|
||||
match self {
|
||||
ContextSummary::Content(content) => Some(content),
|
||||
ContextSummary::Pending | ContextSummary::Error => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn content_or_set_empty(&mut self) -> &mut ContextSummaryContent {
|
||||
match self {
|
||||
ContextSummary::Content(content) => content,
|
||||
ContextSummary::Pending | ContextSummary::Error => {
|
||||
let content = ContextSummaryContent::default();
|
||||
*self = ContextSummary::Content(content);
|
||||
self.content_as_mut().unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_pending(&self) -> bool {
|
||||
matches!(self, ContextSummary::Pending)
|
||||
}
|
||||
|
||||
fn timestamp(&self) -> Option<clock::Lamport> {
|
||||
match self {
|
||||
ContextSummary::Content(content) => Some(content.timestamp),
|
||||
ContextSummary::Pending | ContextSummary::Error => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for ContextSummary {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.timestamp().partial_cmp(&other.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
@@ -607,7 +668,7 @@ pub struct AssistantContext {
|
||||
message_anchors: Vec<MessageAnchor>,
|
||||
contents: Vec<Content>,
|
||||
messages_metadata: HashMap<MessageId, MessageMetadata>,
|
||||
summary: Option<ContextSummary>,
|
||||
summary: ContextSummary,
|
||||
summary_task: Task<Option<()>>,
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
@@ -694,7 +755,7 @@ impl AssistantContext {
|
||||
slash_command_output_sections: Vec::new(),
|
||||
thought_process_output_sections: Vec::new(),
|
||||
edits_since_last_parse: edits_since_last_slash_command_parse,
|
||||
summary: None,
|
||||
summary: ContextSummary::Pending,
|
||||
summary_task: Task::ready(None),
|
||||
completion_count: Default::default(),
|
||||
pending_completions: Default::default(),
|
||||
@@ -753,7 +814,7 @@ impl AssistantContext {
|
||||
.collect(),
|
||||
summary: self
|
||||
.summary
|
||||
.as_ref()
|
||||
.content()
|
||||
.map(|summary| summary.text.clone())
|
||||
.unwrap_or_default(),
|
||||
slash_command_output_sections: self
|
||||
@@ -939,12 +1000,10 @@ impl AssistantContext {
|
||||
summary: new_summary,
|
||||
..
|
||||
} => {
|
||||
if self
|
||||
.summary
|
||||
.as_ref()
|
||||
.map_or(true, |summary| new_summary.timestamp > summary.timestamp)
|
||||
{
|
||||
self.summary = Some(new_summary);
|
||||
if self.summary.timestamp().map_or(true, |current_timestamp| {
|
||||
new_summary.timestamp > current_timestamp
|
||||
}) {
|
||||
self.summary = ContextSummary::Content(new_summary);
|
||||
summary_generated = true;
|
||||
}
|
||||
}
|
||||
@@ -1102,8 +1161,8 @@ impl AssistantContext {
|
||||
self.path.as_ref()
|
||||
}
|
||||
|
||||
pub fn summary(&self) -> Option<&ContextSummary> {
|
||||
self.summary.as_ref()
|
||||
pub fn summary(&self) -> &ContextSummary {
|
||||
&self.summary
|
||||
}
|
||||
|
||||
pub fn parsed_slash_commands(&self) -> &[ParsedSlashCommand] {
|
||||
@@ -2095,12 +2154,6 @@ impl AssistantContext {
|
||||
metadata.status = MessageStatus::Canceled;
|
||||
});
|
||||
Some(error.to_string())
|
||||
} else if error.is::<MaxMonthlySpendReachedError>() {
|
||||
cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
|
||||
this.update_metadata(assistant_message_id, cx, |metadata| {
|
||||
metadata.status = MessageStatus::Canceled;
|
||||
});
|
||||
Some(error.to_string())
|
||||
} else {
|
||||
let error_message = error
|
||||
.chain()
|
||||
@@ -2576,7 +2629,7 @@ impl AssistantContext {
|
||||
return;
|
||||
};
|
||||
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
|
||||
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_pending()) {
|
||||
if !model.provider.is_authenticated(cx) {
|
||||
return;
|
||||
}
|
||||
@@ -2593,17 +2646,20 @@ impl AssistantContext {
|
||||
|
||||
// If there is no summary, it is set with `done: false` so that "Loading Summary…" can
|
||||
// be displayed.
|
||||
if self.summary.is_none() {
|
||||
self.summary = Some(ContextSummary {
|
||||
text: "".to_string(),
|
||||
done: false,
|
||||
timestamp: clock::Lamport::default(),
|
||||
});
|
||||
replace_old = true;
|
||||
match self.summary {
|
||||
ContextSummary::Pending | ContextSummary::Error => {
|
||||
self.summary = ContextSummary::Content(ContextSummaryContent {
|
||||
text: "".to_string(),
|
||||
done: false,
|
||||
timestamp: clock::Lamport::default(),
|
||||
});
|
||||
replace_old = true;
|
||||
}
|
||||
ContextSummary::Content(_) => {}
|
||||
}
|
||||
|
||||
self.summary_task = cx.spawn(async move |this, cx| {
|
||||
async move {
|
||||
let result = async {
|
||||
let stream = model.model.stream_completion_text(request, &cx);
|
||||
let mut messages = stream.await?;
|
||||
|
||||
@@ -2614,7 +2670,7 @@ impl AssistantContext {
|
||||
this.update(cx, |this, cx| {
|
||||
let version = this.version.clone();
|
||||
let timestamp = this.next_timestamp();
|
||||
let summary = this.summary.get_or_insert(ContextSummary::default());
|
||||
let summary = this.summary.content_or_set_empty();
|
||||
if !replaced && replace_old {
|
||||
summary.text.clear();
|
||||
replaced = true;
|
||||
@@ -2636,10 +2692,19 @@ impl AssistantContext {
|
||||
}
|
||||
}
|
||||
|
||||
this.read_with(cx, |this, _cx| {
|
||||
if let Some(summary) = this.summary.content() {
|
||||
if summary.text.is_empty() {
|
||||
bail!("Model generated an empty summary");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})??;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let version = this.version.clone();
|
||||
let timestamp = this.next_timestamp();
|
||||
if let Some(summary) = this.summary.as_mut() {
|
||||
if let Some(summary) = this.summary.content_as_mut() {
|
||||
summary.done = true;
|
||||
summary.timestamp = timestamp;
|
||||
let operation = ContextOperation::UpdateSummary {
|
||||
@@ -2654,8 +2719,18 @@ impl AssistantContext {
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
.await
|
||||
.await;
|
||||
|
||||
if let Err(err) = result {
|
||||
this.update(cx, |this, cx| {
|
||||
this.summary = ContextSummary::Error;
|
||||
cx.emit(ContextEvent::SummaryChanged);
|
||||
})
|
||||
.log_err();
|
||||
log::error!("Error generating context summary: {}", err);
|
||||
}
|
||||
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -2769,7 +2844,7 @@ impl AssistantContext {
|
||||
|
||||
let (old_path, summary) = this.read_with(cx, |this, _| {
|
||||
let path = this.path.clone();
|
||||
let summary = if let Some(summary) = this.summary.as_ref() {
|
||||
let summary = if let Some(summary) = this.summary.content() {
|
||||
if summary.done {
|
||||
Some(summary.text.clone())
|
||||
} else {
|
||||
@@ -2823,21 +2898,12 @@ impl AssistantContext {
|
||||
|
||||
pub fn set_custom_summary(&mut self, custom_summary: String, cx: &mut Context<Self>) {
|
||||
let timestamp = self.next_timestamp();
|
||||
let summary = self.summary.get_or_insert(ContextSummary::default());
|
||||
let summary = self.summary.content_or_set_empty();
|
||||
summary.timestamp = timestamp;
|
||||
summary.done = true;
|
||||
summary.text = custom_summary;
|
||||
cx.emit(ContextEvent::SummaryChanged);
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Text Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
self.summary
|
||||
.as_ref()
|
||||
.map(|summary| summary.text.clone().into())
|
||||
.unwrap_or(Self::DEFAULT_SUMMARY)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -3053,7 +3119,7 @@ impl SavedContext {
|
||||
|
||||
let timestamp = next_timestamp.tick();
|
||||
operations.push(ContextOperation::UpdateSummary {
|
||||
summary: ContextSummary {
|
||||
summary: ContextSummaryContent {
|
||||
text: self.summary,
|
||||
done: true,
|
||||
timestamp,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation,
|
||||
AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, ContextSummary,
|
||||
InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
|
||||
};
|
||||
use anyhow::Result;
|
||||
@@ -16,7 +16,10 @@ use futures::{
|
||||
};
|
||||
use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
|
||||
use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
|
||||
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
|
||||
fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::Project;
|
||||
@@ -1177,6 +1180,187 @@ fn test_mark_cache_anchors(cx: &mut App) {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_summarization(cx: &mut TestAppContext) {
|
||||
let (context, fake_model) = setup_context_editor_with_fake_model(cx);
|
||||
|
||||
// Initial state should be pending
|
||||
context.read_with(cx, |context, _| {
|
||||
assert!(matches!(context.summary(), ContextSummary::Pending));
|
||||
assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
|
||||
});
|
||||
|
||||
let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
|
||||
context.update(cx, |context, cx| {
|
||||
context
|
||||
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
// Send a message
|
||||
context.update(cx, |context, cx| {
|
||||
context.assist(cx);
|
||||
});
|
||||
|
||||
simulate_successful_response(&fake_model, cx);
|
||||
|
||||
// Should start generating summary when there are >= 2 messages
|
||||
context.read_with(cx, |context, _| {
|
||||
assert!(!context.summary().content().unwrap().done);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Brief".into());
|
||||
fake_model.stream_last_completion_response(" Introduction".into());
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
// Summary should be set
|
||||
context.read_with(cx, |context, _| {
|
||||
assert_eq!(context.summary().or_default(), "Brief Introduction");
|
||||
});
|
||||
|
||||
// We should be able to manually set a summary
|
||||
context.update(cx, |context, cx| {
|
||||
context.set_custom_summary("Brief Intro".into(), cx);
|
||||
});
|
||||
|
||||
context.read_with(cx, |context, _| {
|
||||
assert_eq!(context.summary().or_default(), "Brief Intro");
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
|
||||
let (context, fake_model) = setup_context_editor_with_fake_model(cx);
|
||||
|
||||
test_summarize_error(&fake_model, &context, cx);
|
||||
|
||||
// Now we should be able to set a summary
|
||||
context.update(cx, |context, cx| {
|
||||
context.set_custom_summary("Brief Intro".into(), cx);
|
||||
});
|
||||
|
||||
context.read_with(cx, |context, _| {
|
||||
assert_eq!(context.summary().or_default(), "Brief Intro");
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
|
||||
let (context, fake_model) = setup_context_editor_with_fake_model(cx);
|
||||
|
||||
test_summarize_error(&fake_model, &context, cx);
|
||||
|
||||
// Sending another message should not trigger another summarize request
|
||||
context.update(cx, |context, cx| {
|
||||
context.assist(cx);
|
||||
});
|
||||
|
||||
simulate_successful_response(&fake_model, cx);
|
||||
|
||||
context.read_with(cx, |context, _| {
|
||||
// State is still Error, not Generating
|
||||
assert!(matches!(context.summary(), ContextSummary::Error));
|
||||
});
|
||||
|
||||
// But the summarize request can be invoked manually
|
||||
context.update(cx, |context, cx| {
|
||||
context.summarize(true, cx);
|
||||
});
|
||||
|
||||
context.read_with(cx, |context, _| {
|
||||
assert!(!context.summary().content().unwrap().done);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("A successful summary".into());
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
context.read_with(cx, |context, _| {
|
||||
assert_eq!(context.summary().or_default(), "A successful summary");
|
||||
});
|
||||
}
|
||||
|
||||
fn test_summarize_error(
|
||||
model: &Arc<FakeLanguageModel>,
|
||||
context: &Entity<AssistantContext>,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
|
||||
context.update(cx, |context, cx| {
|
||||
context
|
||||
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
// Send a message
|
||||
context.update(cx, |context, cx| {
|
||||
context.assist(cx);
|
||||
});
|
||||
|
||||
simulate_successful_response(&model, cx);
|
||||
|
||||
context.read_with(cx, |context, _| {
|
||||
assert!(!context.summary().content().unwrap().done);
|
||||
});
|
||||
|
||||
// Simulate summary request ending
|
||||
cx.run_until_parked();
|
||||
model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
|
||||
// State is set to Error and default message
|
||||
context.read_with(cx, |context, _| {
|
||||
assert_eq!(*context.summary(), ContextSummary::Error);
|
||||
assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
|
||||
});
|
||||
}
|
||||
|
||||
fn setup_context_editor_with_fake_model(
|
||||
cx: &mut TestAppContext,
|
||||
) -> (Entity<AssistantContext>, Arc<FakeLanguageModel>) {
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor().clone()));
|
||||
|
||||
let fake_provider = Arc::new(FakeLanguageModelProvider);
|
||||
let fake_model = Arc::new(fake_provider.test_model());
|
||||
|
||||
cx.update(|cx| {
|
||||
init_test(cx);
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.set_default_model(
|
||||
Some(ConfiguredModel {
|
||||
provider: fake_provider.clone(),
|
||||
model: fake_model.clone(),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context = cx.new(|cx| {
|
||||
AssistantContext::local(
|
||||
registry,
|
||||
None,
|
||||
None,
|
||||
prompt_builder.clone(),
|
||||
Arc::new(SlashCommandWorkingSet::default()),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
(context, fake_model)
|
||||
}
|
||||
|
||||
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
|
||||
cx.run_until_parked();
|
||||
fake_model.stream_last_completion_response("Assistant response".into());
|
||||
fake_model.end_last_completion_stream();
|
||||
cx.run_until_parked();
|
||||
}
|
||||
|
||||
fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
|
||||
context
|
||||
.read(cx)
|
||||
|
||||
@@ -114,7 +114,6 @@ type MessageHeader = MessageMetadata;
|
||||
#[derive(Clone)]
|
||||
enum AssistError {
|
||||
PaymentRequired,
|
||||
MaxMonthlySpendReached,
|
||||
Message(SharedString),
|
||||
}
|
||||
|
||||
@@ -732,9 +731,6 @@ impl ContextEditor {
|
||||
ContextEvent::ShowPaymentRequiredError => {
|
||||
self.last_error = Some(AssistError::PaymentRequired);
|
||||
}
|
||||
ContextEvent::ShowMaxMonthlySpendReachedError => {
|
||||
self.last_error = Some(AssistError::MaxMonthlySpendReached);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1594,7 +1590,7 @@ impl ContextEditor {
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (String, CopyMetadata, Vec<text::Selection<usize>>) {
|
||||
let (selection, creases) = self.editor.update(cx, |editor, cx| {
|
||||
let (mut selection, creases) = self.editor.update(cx, |editor, cx| {
|
||||
let mut selection = editor.selections.newest_adjusted(cx);
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
|
||||
@@ -1652,7 +1648,18 @@ impl ContextEditor {
|
||||
} else if message.offset_range.end >= selection.range().start {
|
||||
let range = cmp::max(message.offset_range.start, selection.range().start)
|
||||
..cmp::min(message.offset_range.end, selection.range().end);
|
||||
if !range.is_empty() {
|
||||
if range.is_empty() {
|
||||
let snapshot = context.buffer().read(cx).snapshot();
|
||||
let point = snapshot.offset_to_point(range.start);
|
||||
selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
|
||||
selection.end = snapshot.point_to_offset(cmp::min(
|
||||
Point::new(point.row + 1, 0),
|
||||
snapshot.max_point(),
|
||||
));
|
||||
for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
|
||||
text.push_str(chunk);
|
||||
}
|
||||
} else {
|
||||
for chunk in context.buffer().read(cx).text_for_range(range) {
|
||||
text.push_str(chunk);
|
||||
}
|
||||
@@ -1860,7 +1867,12 @@ impl ContextEditor {
|
||||
}
|
||||
|
||||
pub fn title(&self, cx: &App) -> SharedString {
|
||||
self.context.read(cx).summary_or_default()
|
||||
self.context.read(cx).summary().or_default()
|
||||
}
|
||||
|
||||
pub fn regenerate_summary(&mut self, cx: &mut Context<Self>) {
|
||||
self.context
|
||||
.update(cx, |context, cx| context.summarize(true, cx));
|
||||
}
|
||||
|
||||
fn render_notice(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||
@@ -1894,11 +1906,24 @@ impl ContextEditor {
|
||||
.log_err();
|
||||
|
||||
if let Some(client) = client {
|
||||
cx.spawn(async move |this, cx| {
|
||||
client.authenticate_and_connect(true, cx).await?;
|
||||
this.update(cx, |_, cx| cx.notify())
|
||||
cx.spawn(async move |context_editor, cx| {
|
||||
match client.authenticate_and_connect(true, cx).await {
|
||||
util::ConnectionResult::Timeout => {
|
||||
log::error!("Authentication timeout")
|
||||
}
|
||||
util::ConnectionResult::ConnectionReset => {
|
||||
log::error!("Connection reset")
|
||||
}
|
||||
util::ConnectionResult::Result(r) => {
|
||||
if r.log_err().is_some() {
|
||||
context_editor
|
||||
.update(cx, |_, cx| cx.notify())
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach_and_log_err(cx)
|
||||
.detach()
|
||||
}
|
||||
})),
|
||||
)
|
||||
@@ -2089,9 +2114,6 @@ impl ContextEditor {
|
||||
.occlude()
|
||||
.child(match last_error {
|
||||
AssistError::PaymentRequired => self.render_payment_required_error(cx),
|
||||
AssistError::MaxMonthlySpendReached => {
|
||||
self.render_max_monthly_spend_reached_error(cx)
|
||||
}
|
||||
AssistError::Message(error_message) => {
|
||||
self.render_assist_error(error_message, cx)
|
||||
}
|
||||
@@ -2140,48 +2162,6 @@ impl ContextEditor {
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_max_monthly_spend_reached_error(&self, cx: &mut Context<Self>) -> AnyElement {
|
||||
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(ERROR_MESSAGE)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(
|
||||
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
|
||||
cx.listener(|this, _, _window, cx| {
|
||||
this.last_error = None;
|
||||
cx.open_url(&zed_urls::account_url(cx));
|
||||
cx.notify();
|
||||
}),
|
||||
),
|
||||
)
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, _window, cx| {
|
||||
this.last_error = None;
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_assist_error(
|
||||
&self,
|
||||
error_message: &SharedString,
|
||||
@@ -3064,7 +3044,7 @@ fn invoked_slash_command_fold_placeholder(
|
||||
.gap_2()
|
||||
.bg(cx.theme().colors().surface_background)
|
||||
.rounded_sm()
|
||||
.child(Label::new(format!("/{}", command.name.clone())))
|
||||
.child(Label::new(format!("/{}", command.name)))
|
||||
.map(|parent| match &command.status {
|
||||
InvokedSlashCommandStatus::Running(_) => {
|
||||
parent.child(Icon::new(IconName::ArrowCircle).with_animation(
|
||||
@@ -3233,9 +3213,77 @@ pub fn make_lsp_adapter_delegate(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::App;
|
||||
use language::Buffer;
|
||||
use fs::FakeFs;
|
||||
use gpui::{App, TestAppContext, VisualTestContext};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use prompt_store::PromptBuilder;
|
||||
use unindent::Unindent;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_copy_paste_no_selection(cx: &mut TestAppContext) {
|
||||
cx.update(init_test);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let context = cx.new(|cx| {
|
||||
AssistantContext::local(
|
||||
registry,
|
||||
None,
|
||||
None,
|
||||
prompt_builder.clone(),
|
||||
Arc::new(SlashCommandWorkingSet::default()),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
let workspace = window.root(cx).unwrap();
|
||||
let cx = &mut VisualTestContext::from_window(*window, cx);
|
||||
|
||||
let context_editor = window
|
||||
.update(cx, |_, window, cx| {
|
||||
cx.new(|cx| {
|
||||
ContextEditor::for_context(
|
||||
context,
|
||||
fs,
|
||||
workspace.downgrade(),
|
||||
project,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
context_editor.update_in(cx, |context_editor, window, cx| {
|
||||
context_editor.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("abc\ndef\nghi", window, cx);
|
||||
editor.move_to_beginning(&Default::default(), window, cx);
|
||||
})
|
||||
});
|
||||
|
||||
context_editor.update_in(cx, |context_editor, window, cx| {
|
||||
context_editor.editor.update(cx, |editor, cx| {
|
||||
editor.copy(&Default::default(), window, cx);
|
||||
editor.paste(&Default::default(), window, cx);
|
||||
|
||||
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
|
||||
})
|
||||
});
|
||||
|
||||
context_editor.update_in(cx, |context_editor, window, cx| {
|
||||
context_editor.editor.update(cx, |editor, cx| {
|
||||
editor.cut(&Default::default(), window, cx);
|
||||
assert_eq!(editor.text(cx), "abc\ndef\nghi");
|
||||
|
||||
editor.paste(&Default::default(), window, cx);
|
||||
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_find_code_blocks(cx: &mut App) {
|
||||
@@ -3310,4 +3358,17 @@ mod tests {
|
||||
assert_eq!(range, expected, "unexpected result on row {:?}", row);
|
||||
}
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut App) {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
prompt_store::init(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
assistant_settings::init(cx);
|
||||
Project::init_settings(cx);
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
workspace::init_settings(cx);
|
||||
editor::init_settings(cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -648,7 +648,10 @@ impl ContextStore {
|
||||
if context.replica_id() == ReplicaId::default() {
|
||||
Some(proto::ContextMetadata {
|
||||
context_id: context.id().to_proto(),
|
||||
summary: context.summary().map(|summary| summary.text.clone()),
|
||||
summary: context
|
||||
.summary()
|
||||
.content()
|
||||
.map(|summary| summary.text.clone()),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -278,8 +278,8 @@ impl CompletionProvider for SlashCommandCompletionProvider {
|
||||
buffer.anchor_after(Point::new(position.row, first_arg_start.start as u32));
|
||||
let arguments = call
|
||||
.arguments
|
||||
.iter()
|
||||
.filter_map(|argument| Some(line.get(argument.clone())?.to_string()))
|
||||
.into_iter()
|
||||
.filter_map(|argument| Some(line.get(argument)?.to_string()))
|
||||
.collect::<Vec<_>>();
|
||||
let argument_range = first_arg_start..buffer_position;
|
||||
(
|
||||
|
||||
@@ -41,6 +41,7 @@ pub enum NotifyWhenAgentWaiting {
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[serde(tag = "name", rename_all = "snake_case")]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub enum AssistantProviderContentV1 {
|
||||
#[serde(rename = "zed.dev")]
|
||||
ZedDotDev { default_model: Option<CloudModel> },
|
||||
@@ -93,6 +94,7 @@ pub struct AssistantSettings {
|
||||
pub single_file_review: bool,
|
||||
pub model_parameters: Vec<LanguageModelParameters>,
|
||||
pub preferred_completion_mode: CompletionMode,
|
||||
pub enable_feedback: bool,
|
||||
}
|
||||
|
||||
impl AssistantSettings {
|
||||
@@ -260,6 +262,7 @@ impl AssistantSettingsContent {
|
||||
single_file_review: None,
|
||||
model_parameters: Vec::new(),
|
||||
preferred_completion_mode: None,
|
||||
enable_feedback: None,
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
|
||||
},
|
||||
@@ -290,6 +293,7 @@ impl AssistantSettingsContent {
|
||||
single_file_review: None,
|
||||
model_parameters: Vec::new(),
|
||||
preferred_completion_mode: None,
|
||||
enable_feedback: None,
|
||||
},
|
||||
None => AssistantSettingsContentV2::default(),
|
||||
}
|
||||
@@ -543,6 +547,7 @@ impl AssistantSettingsContent {
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[serde(tag = "version")]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub enum VersionedAssistantSettingsContent {
|
||||
#[serde(rename = "1")]
|
||||
V1(AssistantSettingsContentV1),
|
||||
@@ -571,11 +576,13 @@ impl Default for VersionedAssistantSettingsContent {
|
||||
single_file_review: None,
|
||||
model_parameters: Vec::new(),
|
||||
preferred_completion_mode: None,
|
||||
enable_feedback: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct AssistantSettingsContentV2 {
|
||||
/// Whether the Assistant is enabled.
|
||||
///
|
||||
@@ -644,6 +651,10 @@ pub struct AssistantSettingsContentV2 {
|
||||
///
|
||||
/// Default: normal
|
||||
preferred_completion_mode: Option<CompletionMode>,
|
||||
/// Whether to show thumb buttons for feedback in the agent panel.
|
||||
///
|
||||
/// Default: true
|
||||
enable_feedback: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
|
||||
@@ -681,7 +692,7 @@ impl JsonSchema for LanguageModelProviderSetting {
|
||||
schemars::schema::SchemaObject {
|
||||
enum_values: Some(vec![
|
||||
"anthropic".into(),
|
||||
"bedrock".into(),
|
||||
"amazon-bedrock".into(),
|
||||
"google".into(),
|
||||
"lmstudio".into(),
|
||||
"ollama".into(),
|
||||
@@ -734,6 +745,7 @@ pub struct ContextServerPresetContent {
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct AssistantSettingsContentV1 {
|
||||
/// Whether the Assistant is enabled.
|
||||
///
|
||||
@@ -763,6 +775,7 @@ pub struct AssistantSettingsContentV1 {
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct LegacyAssistantSettingsContent {
|
||||
/// Whether to show the assistant panel button in the status bar.
|
||||
///
|
||||
@@ -848,6 +861,7 @@ impl Settings for AssistantSettings {
|
||||
&mut settings.preferred_completion_mode,
|
||||
value.preferred_completion_mode,
|
||||
);
|
||||
merge(&mut settings.enable_feedback, value.enable_feedback);
|
||||
|
||||
settings
|
||||
.model_parameters
|
||||
@@ -984,6 +998,7 @@ mod tests {
|
||||
notify_when_agent_waiting: None,
|
||||
stream_edits: None,
|
||||
single_file_review: None,
|
||||
enable_feedback: None,
|
||||
model_parameters: Vec::new(),
|
||||
preferred_completion_mode: None,
|
||||
},
|
||||
|
||||
@@ -49,6 +49,37 @@ impl ActionLog {
|
||||
is_created: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut TrackedBuffer {
|
||||
let status = if is_created {
|
||||
if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
|
||||
match tracked.status {
|
||||
TrackedBufferStatus::Created {
|
||||
existing_file_content,
|
||||
} => TrackedBufferStatus::Created {
|
||||
existing_file_content,
|
||||
},
|
||||
TrackedBufferStatus::Modified | TrackedBufferStatus::Deleted => {
|
||||
TrackedBufferStatus::Created {
|
||||
existing_file_content: Some(tracked.diff_base),
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map_or(false, |file| file.disk_state().exists())
|
||||
{
|
||||
TrackedBufferStatus::Created {
|
||||
existing_file_content: Some(buffer.read(cx).as_rope().clone()),
|
||||
}
|
||||
} else {
|
||||
TrackedBufferStatus::Created {
|
||||
existing_file_content: None,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TrackedBufferStatus::Modified
|
||||
};
|
||||
|
||||
let tracked_buffer = self
|
||||
.tracked_buffers
|
||||
.entry(buffer.clone())
|
||||
@@ -60,36 +91,21 @@ impl ActionLog {
|
||||
let text_snapshot = buffer.read(cx).text_snapshot();
|
||||
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
|
||||
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
|
||||
let base_text;
|
||||
let status;
|
||||
let diff_base;
|
||||
let unreviewed_changes;
|
||||
if is_created {
|
||||
let existing_file_content = if buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map_or(false, |file| file.disk_state().exists())
|
||||
{
|
||||
Some(text_snapshot.as_rope().clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
base_text = Rope::default();
|
||||
status = TrackedBufferStatus::Created {
|
||||
existing_file_content,
|
||||
};
|
||||
diff_base = Rope::default();
|
||||
unreviewed_changes = Patch::new(vec![Edit {
|
||||
old: 0..1,
|
||||
new: 0..text_snapshot.max_point().row + 1,
|
||||
}])
|
||||
} else {
|
||||
base_text = buffer.read(cx).as_rope().clone();
|
||||
status = TrackedBufferStatus::Modified;
|
||||
diff_base = buffer.read(cx).as_rope().clone();
|
||||
unreviewed_changes = Patch::default();
|
||||
}
|
||||
TrackedBuffer {
|
||||
buffer: buffer.clone(),
|
||||
base_text,
|
||||
diff_base,
|
||||
unreviewed_changes,
|
||||
snapshot: text_snapshot.clone(),
|
||||
status,
|
||||
@@ -184,7 +200,7 @@ impl ActionLog {
|
||||
.context("buffer not tracked")?;
|
||||
|
||||
let rebase = cx.background_spawn({
|
||||
let mut base_text = tracked_buffer.base_text.clone();
|
||||
let mut base_text = tracked_buffer.diff_base.clone();
|
||||
let old_snapshot = tracked_buffer.snapshot.clone();
|
||||
let new_snapshot = buffer_snapshot.clone();
|
||||
let unreviewed_changes = tracked_buffer.unreviewed_changes.clone();
|
||||
@@ -210,7 +226,7 @@ impl ActionLog {
|
||||
))
|
||||
})??;
|
||||
|
||||
let (new_base_text, new_base_text_rope) = rebase.await;
|
||||
let (new_base_text, new_diff_base) = rebase.await;
|
||||
let diff_snapshot = BufferDiff::update_diff(
|
||||
diff.clone(),
|
||||
buffer_snapshot.clone(),
|
||||
@@ -229,24 +245,23 @@ impl ActionLog {
|
||||
.background_spawn({
|
||||
let diff_snapshot = diff_snapshot.clone();
|
||||
let buffer_snapshot = buffer_snapshot.clone();
|
||||
let new_base_text_rope = new_base_text_rope.clone();
|
||||
let new_diff_base = new_diff_base.clone();
|
||||
async move {
|
||||
let mut unreviewed_changes = Patch::default();
|
||||
for hunk in diff_snapshot.hunks_intersecting_range(
|
||||
Anchor::MIN..Anchor::MAX,
|
||||
&buffer_snapshot,
|
||||
) {
|
||||
let old_range = new_base_text_rope
|
||||
let old_range = new_diff_base
|
||||
.offset_to_point(hunk.diff_base_byte_range.start)
|
||||
..new_base_text_rope
|
||||
.offset_to_point(hunk.diff_base_byte_range.end);
|
||||
..new_diff_base.offset_to_point(hunk.diff_base_byte_range.end);
|
||||
let new_range = hunk.range.start..hunk.range.end;
|
||||
unreviewed_changes.push(point_to_row_edit(
|
||||
Edit {
|
||||
old: old_range,
|
||||
new: new_range,
|
||||
},
|
||||
&new_base_text_rope,
|
||||
&new_diff_base,
|
||||
&buffer_snapshot.as_rope(),
|
||||
));
|
||||
}
|
||||
@@ -264,7 +279,7 @@ impl ActionLog {
|
||||
.tracked_buffers
|
||||
.get_mut(&buffer)
|
||||
.context("buffer not tracked")?;
|
||||
tracked_buffer.base_text = new_base_text_rope;
|
||||
tracked_buffer.diff_base = new_diff_base;
|
||||
tracked_buffer.snapshot = buffer_snapshot;
|
||||
tracked_buffer.unreviewed_changes = unreviewed_changes;
|
||||
cx.notify();
|
||||
@@ -283,7 +298,6 @@ impl ActionLog {
|
||||
/// Mark a buffer as edited, so we can refresh it in the context
|
||||
pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
self.tracked_buffers.remove(&buffer);
|
||||
self.track_buffer_internal(buffer.clone(), true, cx);
|
||||
}
|
||||
|
||||
@@ -346,11 +360,11 @@ impl ActionLog {
|
||||
true
|
||||
} else {
|
||||
let old_range = tracked_buffer
|
||||
.base_text
|
||||
.diff_base
|
||||
.point_to_offset(Point::new(edit.old.start, 0))
|
||||
..tracked_buffer.base_text.point_to_offset(cmp::min(
|
||||
..tracked_buffer.diff_base.point_to_offset(cmp::min(
|
||||
Point::new(edit.old.end, 0),
|
||||
tracked_buffer.base_text.max_point(),
|
||||
tracked_buffer.diff_base.max_point(),
|
||||
));
|
||||
let new_range = tracked_buffer
|
||||
.snapshot
|
||||
@@ -359,7 +373,7 @@ impl ActionLog {
|
||||
Point::new(edit.new.end, 0),
|
||||
tracked_buffer.snapshot.max_point(),
|
||||
));
|
||||
tracked_buffer.base_text.replace(
|
||||
tracked_buffer.diff_base.replace(
|
||||
old_range,
|
||||
&tracked_buffer
|
||||
.snapshot
|
||||
@@ -417,7 +431,7 @@ impl ActionLog {
|
||||
}
|
||||
TrackedBufferStatus::Deleted => {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_text(tracked_buffer.base_text.to_string(), cx)
|
||||
buffer.set_text(tracked_buffer.diff_base.to_string(), cx)
|
||||
});
|
||||
let save = self
|
||||
.project
|
||||
@@ -464,14 +478,14 @@ impl ActionLog {
|
||||
|
||||
if revert {
|
||||
let old_range = tracked_buffer
|
||||
.base_text
|
||||
.diff_base
|
||||
.point_to_offset(Point::new(edit.old.start, 0))
|
||||
..tracked_buffer.base_text.point_to_offset(cmp::min(
|
||||
..tracked_buffer.diff_base.point_to_offset(cmp::min(
|
||||
Point::new(edit.old.end, 0),
|
||||
tracked_buffer.base_text.max_point(),
|
||||
tracked_buffer.diff_base.max_point(),
|
||||
));
|
||||
let old_text = tracked_buffer
|
||||
.base_text
|
||||
.diff_base
|
||||
.chunks_in_range(old_range)
|
||||
.collect::<String>();
|
||||
edits_to_revert.push((new_range, old_text));
|
||||
@@ -492,7 +506,7 @@ impl ActionLog {
|
||||
TrackedBufferStatus::Deleted => false,
|
||||
_ => {
|
||||
tracked_buffer.unreviewed_changes.clear();
|
||||
tracked_buffer.base_text = tracked_buffer.snapshot.as_rope().clone();
|
||||
tracked_buffer.diff_base = tracked_buffer.snapshot.as_rope().clone();
|
||||
tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx);
|
||||
true
|
||||
}
|
||||
@@ -655,7 +669,7 @@ enum TrackedBufferStatus {
|
||||
|
||||
struct TrackedBuffer {
|
||||
buffer: Entity<Buffer>,
|
||||
base_text: Rope,
|
||||
diff_base: Rope,
|
||||
unreviewed_changes: Patch<u32>,
|
||||
status: TrackedBufferStatus,
|
||||
version: clock::Global,
|
||||
@@ -1094,6 +1108,86 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_overwriting_previously_edited_files(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/dir"),
|
||||
json!({
|
||||
"file1": "Lorem ipsum dolor"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
|
||||
.unwrap();
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| buffer.append(" sit amet consecteur", cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
});
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
unreviewed_hunks(&action_log, cx),
|
||||
vec![(
|
||||
buffer.clone(),
|
||||
vec![HunkStatus {
|
||||
range: Point::new(0, 0)..Point::new(0, 37),
|
||||
diff_status: DiffHunkStatusKind::Modified,
|
||||
old_text: "Lorem ipsum dolor".into(),
|
||||
}],
|
||||
)]
|
||||
);
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text("rewritten", cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
});
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
unreviewed_hunks(&action_log, cx),
|
||||
vec![(
|
||||
buffer.clone(),
|
||||
vec![HunkStatus {
|
||||
range: Point::new(0, 0)..Point::new(0, 9),
|
||||
diff_status: DiffHunkStatusKind::Added,
|
||||
old_text: "".into(),
|
||||
}],
|
||||
)]
|
||||
);
|
||||
|
||||
action_log
|
||||
.update(cx, |log, cx| {
|
||||
log.reject_edits_in_ranges(buffer.clone(), vec![2..5], cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _cx| buffer.text()),
|
||||
"Lorem ipsum dolor"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_deleting_files(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
@@ -1601,7 +1695,7 @@ mod tests {
|
||||
cx.run_until_parked();
|
||||
action_log.update(cx, |log, cx| {
|
||||
let tracked_buffer = log.tracked_buffers.get(&buffer).unwrap();
|
||||
let mut old_text = tracked_buffer.base_text.clone();
|
||||
let mut old_text = tracked_buffer.diff_base.clone();
|
||||
let new_text = buffer.read(cx).as_rope();
|
||||
for edit in tracked_buffer.unreviewed_changes.edits() {
|
||||
let old_start = old_text.point_to_offset(Point::new(edit.new.start, 0));
|
||||
|
||||
@@ -19,6 +19,7 @@ use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task, WeakEntity};
|
||||
use icons::IconName;
|
||||
use language_model::LanguageModel;
|
||||
use language_model::LanguageModelImage;
|
||||
use language_model::LanguageModelRequest;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use project::Project;
|
||||
@@ -65,21 +66,50 @@ impl ToolUseStatus {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolResultOutput {
|
||||
pub content: String,
|
||||
pub content: ToolResultContent,
|
||||
pub output: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum ToolResultContent {
|
||||
Text(String),
|
||||
Image(LanguageModelImage),
|
||||
}
|
||||
|
||||
impl ToolResultContent {
|
||||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
ToolResultContent::Text(str) => str.len(),
|
||||
ToolResultContent::Image(image) => image.len(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
ToolResultContent::Text(str) => str.is_empty(),
|
||||
ToolResultContent::Image(image) => image.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
ToolResultContent::Text(str) => Some(str),
|
||||
ToolResultContent::Image(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for ToolResultOutput {
|
||||
fn from(value: String) -> Self {
|
||||
ToolResultOutput {
|
||||
content: value,
|
||||
content: ToolResultContent::Text(value),
|
||||
output: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for ToolResultOutput {
|
||||
type Target = String;
|
||||
type Target = ToolResultContent;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.content
|
||||
|
||||
@@ -35,7 +35,6 @@ indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
linkme.workspace = true
|
||||
log.workspace = true
|
||||
markdown.workspace = true
|
||||
open.workspace = true
|
||||
|
||||
@@ -42,7 +42,7 @@ use crate::list_directory_tool::ListDirectoryTool;
|
||||
use crate::now_tool::NowTool;
|
||||
use crate::thinking_tool::ThinkingTool;
|
||||
|
||||
pub use edit_file_tool::EditFileToolInput;
|
||||
pub use edit_file_tool::{EditFileMode, EditFileToolInput};
|
||||
pub use find_path_tool::FindPathToolInput;
|
||||
pub use open_tool::OpenTool;
|
||||
pub use read_file_tool::{ReadFileTool, ReadFileToolInput};
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
use super::*;
|
||||
use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
|
||||
use crate::{
|
||||
ReadFileToolInput,
|
||||
edit_file_tool::{EditFileMode, EditFileToolInput},
|
||||
grep_tool::GrepToolInput,
|
||||
};
|
||||
use Role::*;
|
||||
use anyhow::anyhow;
|
||||
use assistant_tool::ToolRegistry;
|
||||
@@ -10,8 +14,8 @@ use futures::{FutureExt, future::LocalBoxFuture};
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use indoc::{formatdoc, indoc};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
|
||||
LanguageModelToolUseId,
|
||||
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, SelectedModel,
|
||||
};
|
||||
use project::Project;
|
||||
use rand::prelude::*;
|
||||
@@ -21,6 +25,7 @@ use std::{
|
||||
cmp::Reverse,
|
||||
fmt::{self, Display},
|
||||
io::Write as _,
|
||||
str::FromStr,
|
||||
sync::mpsc,
|
||||
};
|
||||
use util::path;
|
||||
@@ -71,7 +76,7 @@ fn eval_extract_handle_command_output() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
)],
|
||||
),
|
||||
@@ -127,7 +132,7 @@ fn eval_delete_run_git_blame() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
)],
|
||||
),
|
||||
@@ -182,7 +187,7 @@ fn eval_translate_doc_comments() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
)],
|
||||
),
|
||||
@@ -297,7 +302,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
)],
|
||||
),
|
||||
@@ -372,7 +377,7 @@ fn eval_disable_cursor_blinking() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
)],
|
||||
),
|
||||
@@ -566,7 +571,7 @@ fn eval_from_pixels_constructor() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
)],
|
||||
),
|
||||
@@ -643,7 +648,7 @@ fn eval_zode() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: true,
|
||||
mode: EditFileMode::Create,
|
||||
},
|
||||
),
|
||||
],
|
||||
@@ -888,7 +893,7 @@ fn eval_add_overwrite_test() {
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
},
|
||||
),
|
||||
],
|
||||
@@ -951,7 +956,7 @@ fn tool_result(
|
||||
tool_use_id: LanguageModelToolUseId::from(id.into()),
|
||||
tool_name: name.into(),
|
||||
is_error: false,
|
||||
content: result.into(),
|
||||
content: LanguageModelToolResultContent::Text(result.into()),
|
||||
output: None,
|
||||
})
|
||||
}
|
||||
@@ -1212,7 +1217,7 @@ fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usiz
|
||||
passed_count as f64 / evaluated_count as f64
|
||||
};
|
||||
print!(
|
||||
"\r\x1b[KEvaluated {}/{} ({:.2}%)",
|
||||
"\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
|
||||
evaluated_count,
|
||||
iterations,
|
||||
passed_ratio * 100.0
|
||||
@@ -1251,13 +1256,21 @@ impl EditAgentTest {
|
||||
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let agent_model = SelectedModel::from_str(
|
||||
&std::env::var("ZED_AGENT_MODEL")
|
||||
.unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
|
||||
)
|
||||
.unwrap();
|
||||
let judge_model = SelectedModel::from_str(
|
||||
&std::env::var("ZED_JUDGE_MODEL")
|
||||
.unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
|
||||
)
|
||||
.unwrap();
|
||||
let (agent_model, judge_model) = cx
|
||||
.update(|cx| {
|
||||
cx.spawn(async move |cx| {
|
||||
let agent_model =
|
||||
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
|
||||
let judge_model =
|
||||
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
|
||||
let agent_model = Self::load_model(&agent_model, cx).await;
|
||||
let judge_model = Self::load_model(&judge_model, cx).await;
|
||||
(agent_model.unwrap(), judge_model.unwrap())
|
||||
})
|
||||
})
|
||||
@@ -1272,15 +1285,17 @@ impl EditAgentTest {
|
||||
}
|
||||
|
||||
async fn load_model(
|
||||
provider: &str,
|
||||
id: &str,
|
||||
selected_model: &SelectedModel,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Arc<dyn LanguageModel>> {
|
||||
let (provider, model) = cx.update(|cx| {
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.provider_id().0 == provider && model.id().0 == id)
|
||||
.find(|model| {
|
||||
model.provider_id() == selected_model.provider
|
||||
&& model.id() == selected_model.model
|
||||
})
|
||||
.unwrap();
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
(provider, model)
|
||||
|
||||
@@ -1249,7 +1249,7 @@ pub struct ActiveDiagnosticGroup {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
|
||||
pub(crate) enum ActiveDiagnostic {
|
||||
None,
|
||||
All,
|
||||
|
||||
@@ -5,7 +5,8 @@ use crate::{
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{
|
||||
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus,
|
||||
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput,
|
||||
ToolUseStatus,
|
||||
};
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
|
||||
@@ -75,12 +76,22 @@ pub struct EditFileToolInput {
|
||||
/// </example>
|
||||
pub path: PathBuf,
|
||||
|
||||
/// If true, this tool will recreate the file from scratch.
|
||||
/// If false, this tool will produce granular edits to an existing file.
|
||||
/// The mode of operation on the file. Possible values:
|
||||
/// - 'edit': Make granular edits to an existing file.
|
||||
/// - 'create': Create a new file if it doesn't exist.
|
||||
/// - 'overwrite': Replace the entire contents of an existing file.
|
||||
///
|
||||
/// When a file already exists or you just created it, always prefer editing
|
||||
/// When a file already exists or you just created it, prefer editing
|
||||
/// it as opposed to recreating it from scratch.
|
||||
pub create_or_overwrite: bool,
|
||||
pub mode: EditFileMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EditFileMode {
|
||||
Edit,
|
||||
Create,
|
||||
Overwrite,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -194,7 +205,11 @@ impl Tool for EditFileTool {
|
||||
.as_ref()
|
||||
.map_or(false, |file| file.disk_state().exists())
|
||||
})?;
|
||||
if !input.create_or_overwrite && !exists {
|
||||
let create_or_overwrite = match input.mode {
|
||||
EditFileMode::Create | EditFileMode::Overwrite => true,
|
||||
_ => false,
|
||||
};
|
||||
if !create_or_overwrite && !exists {
|
||||
return Err(anyhow!("{} not found", input.path.display()));
|
||||
}
|
||||
|
||||
@@ -206,7 +221,7 @@ impl Tool for EditFileTool {
|
||||
})
|
||||
.await;
|
||||
|
||||
let (output, mut events) = if input.create_or_overwrite {
|
||||
let (output, mut events) = if create_or_overwrite {
|
||||
edit_agent.overwrite(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
@@ -292,7 +307,10 @@ impl Tool for EditFileTool {
|
||||
}
|
||||
} else {
|
||||
Ok(ToolResultOutput {
|
||||
content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff),
|
||||
content: ToolResultContent::Text(format!(
|
||||
"Edited {}:\n\n```diff\n{}\n```",
|
||||
input_path, diff
|
||||
)),
|
||||
output: serde_json::to_value(output).ok(),
|
||||
})
|
||||
}
|
||||
@@ -637,7 +655,7 @@ impl ToolCard for EditFileToolCard {
|
||||
.p_3()
|
||||
.gap_1()
|
||||
.border_t_1()
|
||||
.rounded_md()
|
||||
.rounded_b_md()
|
||||
.border_color(border_color)
|
||||
.bg(cx.theme().colors().editor_background);
|
||||
|
||||
@@ -872,7 +890,7 @@ mod tests {
|
||||
let input = serde_json::to_value(EditFileToolInput {
|
||||
display_description: "Some edit".into(),
|
||||
path: "root/nonexistent_file.txt".into(),
|
||||
create_or_overwrite: false,
|
||||
mode: EditFileMode::Edit,
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use assistant_tool::{
|
||||
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
|
||||
};
|
||||
use editor::Editor;
|
||||
use futures::channel::oneshot::{self, Receiver};
|
||||
use gpui::{
|
||||
@@ -38,6 +40,12 @@ pub struct FindPathToolInput {
|
||||
pub offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FindPathToolOutput {
|
||||
glob: String,
|
||||
paths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
|
||||
pub struct FindPathTool;
|
||||
@@ -111,10 +119,18 @@ impl Tool for FindPathTool {
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
let output = FindPathToolOutput {
|
||||
glob,
|
||||
paths: matches.clone(),
|
||||
};
|
||||
|
||||
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
|
||||
write!(&mut message, "\n{}", mat.display()).unwrap();
|
||||
}
|
||||
Ok(message.into())
|
||||
Ok(ToolResultOutput {
|
||||
content: ToolResultContent::Text(message),
|
||||
output: Some(serde_json::to_value(output)?),
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
@@ -123,6 +139,18 @@ impl Tool for FindPathTool {
|
||||
card: Some(card.into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_card(
|
||||
self: Arc<Self>,
|
||||
output: serde_json::Value,
|
||||
_project: Entity<Project>,
|
||||
_window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<assistant_tool::AnyToolCard> {
|
||||
let output = serde_json::from_value::<FindPathToolOutput>(output).ok()?;
|
||||
let card = cx.new(|_| FindPathToolCard::from_output(output));
|
||||
Some(card.into())
|
||||
}
|
||||
}
|
||||
|
||||
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
|
||||
@@ -180,6 +208,15 @@ impl FindPathToolCard {
|
||||
_receiver_task: Some(_receiver_task),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_output(output: FindPathToolOutput) -> Self {
|
||||
Self {
|
||||
glob: output.glob,
|
||||
paths: output.paths,
|
||||
expanded: false,
|
||||
_receiver_task: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCard for FindPathToolCard {
|
||||
|
||||
@@ -752,9 +752,9 @@ mod tests {
|
||||
match task.output.await {
|
||||
Ok(result) => {
|
||||
if cfg!(windows) {
|
||||
result.content.replace("root\\", "root/")
|
||||
result.content.as_str().unwrap().replace("root\\", "root/")
|
||||
} else {
|
||||
result.content
|
||||
result.content.as_str().unwrap().to_string()
|
||||
}
|
||||
}
|
||||
Err(e) => panic!("Failed to run grep tool: {}", e),
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::outline;
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use assistant_tool::{ToolResultContent, outline};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use project::{ImageItem, image_store};
|
||||
|
||||
use assistant_tool::ToolResultOutput;
|
||||
use indoc::formatdoc;
|
||||
use itertools::Itertools;
|
||||
use language::{Anchor, Point};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
|
||||
};
|
||||
use project::{AgentLocation, Project};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -86,7 +90,7 @@ impl Tool for ReadFileTool {
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
@@ -100,6 +104,42 @@ impl Tool for ReadFileTool {
|
||||
};
|
||||
|
||||
let file_path = input.path.clone();
|
||||
|
||||
if image_store::is_image_file(&project, &project_path, cx) {
|
||||
if !model.supports_images() {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Attempted to read an image, but Zed doesn't currently sending images to {}.",
|
||||
model.name().0
|
||||
)))
|
||||
.into();
|
||||
}
|
||||
|
||||
let task = cx.spawn(async move |cx| -> Result<ToolResultOutput> {
|
||||
let image_entity: Entity<ImageItem> = cx
|
||||
.update(|cx| {
|
||||
project.update(cx, |project, cx| {
|
||||
project.open_image(project_path.clone(), cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let image =
|
||||
image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
|
||||
|
||||
let language_model_image = cx
|
||||
.update(|cx| LanguageModelImage::from_image(image, cx))?
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("Failed to process image"))?;
|
||||
|
||||
Ok(ToolResultOutput {
|
||||
content: ToolResultContent::Image(language_model_image),
|
||||
output: None,
|
||||
})
|
||||
});
|
||||
|
||||
return task.into();
|
||||
}
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = cx
|
||||
.update(|cx| {
|
||||
@@ -282,7 +322,10 @@ mod test {
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap().content, "This is a small file content");
|
||||
assert_eq!(
|
||||
result.unwrap().content.as_str(),
|
||||
Some("This is a small file content")
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -322,6 +365,7 @@ mod test {
|
||||
})
|
||||
.await;
|
||||
let content = result.unwrap();
|
||||
let content = content.as_str().unwrap();
|
||||
assert_eq!(
|
||||
content.lines().skip(4).take(6).collect::<Vec<_>>(),
|
||||
vec![
|
||||
@@ -365,6 +409,8 @@ mod test {
|
||||
.collect::<Vec<_>>();
|
||||
pretty_assertions::assert_eq!(
|
||||
content
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.lines()
|
||||
.skip(4)
|
||||
.take(expected_content.len())
|
||||
@@ -408,7 +454,10 @@ mod test {
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4");
|
||||
assert_eq!(
|
||||
result.unwrap().content.as_str(),
|
||||
Some("Line 2\nLine 3\nLine 4")
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -448,7 +497,7 @@ mod test {
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap().content, "Line 1\nLine 2");
|
||||
assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2"));
|
||||
|
||||
// end_line of 0 should result in at least 1 line
|
||||
let result = cx
|
||||
@@ -471,7 +520,7 @@ mod test {
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap().content, "Line 1");
|
||||
assert_eq!(result.unwrap().content.as_str(), Some("Line 1"));
|
||||
|
||||
// when start_line > end_line, should still return at least 1 line
|
||||
let result = cx
|
||||
@@ -494,7 +543,7 @@ mod test {
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
assert_eq!(result.unwrap().content, "Line 3");
|
||||
assert_eq!(result.unwrap().content.as_str(), Some("Line 3"));
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, WeakEntity, Window};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, TextStyleRefinement,
|
||||
WeakEntity, Window,
|
||||
};
|
||||
use language::LineEnding;
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
|
||||
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
|
||||
use project::{Project, terminals::TerminalKind};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
use std::{
|
||||
env,
|
||||
path::{Path, PathBuf},
|
||||
@@ -17,6 +22,7 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use terminal_view::TerminalView;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Disclosure, Tooltip, prelude::*};
|
||||
use util::{
|
||||
get_system_shell, markdown::MarkdownInlineCode, size::format_file_size,
|
||||
@@ -119,14 +125,24 @@ impl Tool for TerminalTool {
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let input_path = Path::new(&input.cd);
|
||||
let working_dir = match working_dir(&input, &project, input_path, cx) {
|
||||
let working_dir = match working_dir(&input, &project, cx) {
|
||||
Ok(dir) => dir,
|
||||
Err(err) => return Task::ready(Err(err)).into(),
|
||||
};
|
||||
let program = self.determine_shell.clone();
|
||||
let command = format!("({}) </dev/null", input.command);
|
||||
let args = vec!["-c".into(), command.clone()];
|
||||
let command = if cfg!(windows) {
|
||||
format!("$null | & {{{}}}", input.command.replace("\"", "'"))
|
||||
} else if let Some(cwd) = working_dir
|
||||
.as_ref()
|
||||
.and_then(|cwd| cwd.as_os_str().to_str())
|
||||
{
|
||||
// Make sure once we're *inside* the shell, we cd into `cwd`
|
||||
format!("(cd {cwd}; {}) </dev/null", input.command)
|
||||
} else {
|
||||
format!("({}) </dev/null", input.command)
|
||||
};
|
||||
let args = vec!["-c".into(), command];
|
||||
|
||||
let cwd = working_dir.clone();
|
||||
let env = match &working_dir {
|
||||
Some(dir) => project.update(cx, |project, cx| {
|
||||
@@ -211,8 +227,21 @@ impl Tool for TerminalTool {
|
||||
}
|
||||
});
|
||||
|
||||
let command_markdown = cx.new(|cx| {
|
||||
Markdown::new(
|
||||
format!("```bash\n{}\n```", input.command).into(),
|
||||
None,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let card = cx.new(|cx| {
|
||||
TerminalToolCard::new(input.command.clone(), working_dir.clone(), cx.entity_id())
|
||||
TerminalToolCard::new(
|
||||
command_markdown.clone(),
|
||||
working_dir.clone(),
|
||||
cx.entity_id(),
|
||||
)
|
||||
});
|
||||
|
||||
let output = cx.spawn({
|
||||
@@ -296,19 +325,13 @@ fn process_content(
|
||||
} else {
|
||||
content
|
||||
};
|
||||
let is_empty = content.trim().is_empty();
|
||||
|
||||
let content = format!(
|
||||
"```\n{}{}```",
|
||||
content,
|
||||
if content.ends_with('\n') { "" } else { "\n" }
|
||||
);
|
||||
|
||||
let content = content.trim();
|
||||
let is_empty = content.is_empty();
|
||||
let content = format!("```\n{content}\n```");
|
||||
let content = if should_truncate {
|
||||
format!(
|
||||
"Command output too long. The first {} bytes:\n\n{}",
|
||||
"Command output too long. The first {} bytes:\n\n{content}",
|
||||
content.len(),
|
||||
content,
|
||||
)
|
||||
} else {
|
||||
content
|
||||
@@ -348,47 +371,52 @@ fn process_content(
|
||||
fn working_dir(
|
||||
input: &TerminalToolInput,
|
||||
project: &Entity<Project>,
|
||||
input_path: &Path,
|
||||
cx: &mut App,
|
||||
) -> Result<Option<PathBuf>> {
|
||||
let project = project.read(cx);
|
||||
let cd = &input.cd;
|
||||
|
||||
if input.cd == "." {
|
||||
// Accept "." as meaning "the one worktree" if we only have one worktree.
|
||||
if cd == "." || cd == "" {
|
||||
// Accept "." or "" as meaning "the one worktree" if we only have one worktree.
|
||||
let mut worktrees = project.worktrees(cx);
|
||||
|
||||
match worktrees.next() {
|
||||
Some(worktree) => {
|
||||
if worktrees.next().is_some() {
|
||||
bail!(
|
||||
if worktrees.next().is_none() {
|
||||
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.",
|
||||
);
|
||||
))
|
||||
}
|
||||
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
} else if input_path.is_absolute() {
|
||||
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
|
||||
if !project
|
||||
.worktrees(cx)
|
||||
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
|
||||
{
|
||||
bail!("The absolute path must be within one of the project's worktrees");
|
||||
} else {
|
||||
let input_path = Path::new(cd);
|
||||
|
||||
if input_path.is_absolute() {
|
||||
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
|
||||
if project
|
||||
.worktrees(cx)
|
||||
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
|
||||
{
|
||||
return Ok(Some(input_path.into()));
|
||||
}
|
||||
} else {
|
||||
if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
|
||||
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(input_path.into()))
|
||||
} else {
|
||||
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
|
||||
bail!("`cd` directory {:?} not found in the project", input.cd);
|
||||
};
|
||||
|
||||
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
|
||||
Err(anyhow!(
|
||||
"`cd` directory {cd:?} was not in any of the project's worktrees."
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
struct TerminalToolCard {
|
||||
input_command: String,
|
||||
input_command: Entity<Markdown>,
|
||||
working_dir: Option<PathBuf>,
|
||||
entity_id: EntityId,
|
||||
exit_status: Option<ExitStatus>,
|
||||
@@ -404,7 +432,11 @@ struct TerminalToolCard {
|
||||
}
|
||||
|
||||
impl TerminalToolCard {
|
||||
pub fn new(input_command: String, working_dir: Option<PathBuf>, entity_id: EntityId) -> Self {
|
||||
pub fn new(
|
||||
input_command: Entity<Markdown>,
|
||||
working_dir: Option<PathBuf>,
|
||||
entity_id: EntityId,
|
||||
) -> Self {
|
||||
Self {
|
||||
input_command,
|
||||
working_dir,
|
||||
@@ -427,7 +459,7 @@ impl ToolCard for TerminalToolCard {
|
||||
fn render(
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
_window: &mut Window,
|
||||
window: &mut Window,
|
||||
_workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
@@ -571,11 +603,25 @@ impl ToolCard for TerminalToolCard {
|
||||
.rounded_lg()
|
||||
.overflow_hidden()
|
||||
.child(
|
||||
v_flex().p_2().gap_0p5().bg(header_bg).child(header).child(
|
||||
Label::new(self.input_command.clone())
|
||||
.buffer_font(cx)
|
||||
.size(LabelSize::Small),
|
||||
),
|
||||
v_flex()
|
||||
.p_2()
|
||||
.gap_0p5()
|
||||
.bg(header_bg)
|
||||
.text_xs()
|
||||
.child(header)
|
||||
.child(
|
||||
MarkdownElement::new(
|
||||
self.input_command.clone(),
|
||||
markdown_style(window, cx),
|
||||
)
|
||||
.code_block_renderer(
|
||||
markdown::CodeBlockRenderer::Default {
|
||||
copy_button: false,
|
||||
copy_button_on_hover: true,
|
||||
border: false,
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
.when(self.preview_expanded && !should_hide_terminal, |this| {
|
||||
this.child(
|
||||
@@ -594,6 +640,27 @@ impl ToolCard for TerminalToolCard {
|
||||
}
|
||||
}
|
||||
|
||||
fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
|
||||
let theme_settings = ThemeSettings::get_global(cx);
|
||||
let buffer_font_size = TextSize::Default.rems(cx);
|
||||
let mut text_style = window.text_style();
|
||||
|
||||
text_style.refine(&TextStyleRefinement {
|
||||
font_family: Some(theme_settings.buffer_font.family.clone()),
|
||||
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
|
||||
font_features: Some(theme_settings.buffer_font.features.clone()),
|
||||
font_size: Some(buffer_font_size.into()),
|
||||
color: Some(cx.theme().colors().text),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
MarkdownStyle {
|
||||
base_text_style: text_style.clone(),
|
||||
selection_background_color: cx.theme().players().local().selection,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use editor::EditorSettings;
|
||||
@@ -665,8 +732,8 @@ mod tests {
|
||||
)
|
||||
});
|
||||
|
||||
let output = result.output.await.log_err().map(|output| output.content);
|
||||
assert_eq!(output, Some("Command executed successfully.".into()));
|
||||
let output = result.output.await.log_err().unwrap().content;
|
||||
assert_eq!(output.as_str().unwrap(), "Command executed successfully.");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
@@ -699,12 +766,13 @@ mod tests {
|
||||
cx,
|
||||
);
|
||||
cx.spawn(async move |_| {
|
||||
let output = headless_result
|
||||
.output
|
||||
.await
|
||||
.log_err()
|
||||
.map(|output| output.content);
|
||||
assert_eq!(output, expected);
|
||||
let output = headless_result.output.await.map(|output| output.content);
|
||||
assert_eq!(
|
||||
output
|
||||
.ok()
|
||||
.and_then(|content| content.as_str().map(ToString::to_string)),
|
||||
expected
|
||||
);
|
||||
})
|
||||
};
|
||||
|
||||
@@ -712,7 +780,7 @@ mod tests {
|
||||
check(
|
||||
TerminalToolInput {
|
||||
command: "pwd".into(),
|
||||
cd: "project".into(),
|
||||
cd: ".".into(),
|
||||
},
|
||||
Some(format!(
|
||||
"```\n{}\n```",
|
||||
@@ -727,12 +795,9 @@ mod tests {
|
||||
check(
|
||||
TerminalToolInput {
|
||||
command: "pwd".into(),
|
||||
cd: ".".into(),
|
||||
cd: "other-project".into(),
|
||||
},
|
||||
Some(format!(
|
||||
"```\n{}\n```",
|
||||
tree.path().join("project").display()
|
||||
)),
|
||||
None, // other-project is a dir, but *not* a worktree (yet)
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
||||
@@ -3,7 +3,9 @@ use std::{sync::Arc, time::Duration};
|
||||
use crate::schema::json_schema_for;
|
||||
use crate::ui::ToolCallCardHeader;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use assistant_tool::{
|
||||
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
|
||||
};
|
||||
use futures::{Future, FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
@@ -73,9 +75,13 @@ impl Tool for WebSearchTool {
|
||||
let search_task = search_task.clone();
|
||||
async move {
|
||||
let response = search_task.await.map_err(|err| anyhow!(err))?;
|
||||
serde_json::to_string(&response)
|
||||
.context("Failed to serialize search results")
|
||||
.map(Into::into)
|
||||
Ok(ToolResultOutput {
|
||||
content: ToolResultContent::Text(
|
||||
serde_json::to_string(&response)
|
||||
.context("Failed to serialize search results")?,
|
||||
),
|
||||
output: Some(serde_json::to_value(response)?),
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
@@ -84,6 +90,18 @@ impl Tool for WebSearchTool {
|
||||
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn deserialize_card(
|
||||
self: Arc<Self>,
|
||||
output: serde_json::Value,
|
||||
_project: Entity<Project>,
|
||||
_window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<assistant_tool::AnyToolCard> {
|
||||
let output = serde_json::from_value::<WebSearchResponse>(output).ok()?;
|
||||
let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx));
|
||||
Some(card.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(RegisterComponent)]
|
||||
|
||||
@@ -38,6 +38,7 @@ pub enum Model {
|
||||
AmazonNovaLite,
|
||||
AmazonNovaMicro,
|
||||
AmazonNovaPro,
|
||||
AmazonNovaPremier,
|
||||
// AI21 models
|
||||
AI21J2GrandeInstruct,
|
||||
AI21J2JumboInstruct,
|
||||
@@ -72,6 +73,10 @@ pub enum Model {
|
||||
MistralMixtral8x7BInstructV0,
|
||||
MistralMistralLarge2402V1,
|
||||
MistralMistralSmall2402V1,
|
||||
MistralPixtralLarge2502V1,
|
||||
// Writer models
|
||||
PalmyraWriterX5,
|
||||
PalmyraWriterX4,
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
@@ -120,6 +125,7 @@ impl Model {
|
||||
Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
|
||||
Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
|
||||
Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
|
||||
Model::AmazonNovaPremier => "amazon.nova-premier-v1:0",
|
||||
Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
|
||||
Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
|
||||
Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
|
||||
@@ -149,6 +155,9 @@ impl Model {
|
||||
Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
|
||||
Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
|
||||
Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
|
||||
Model::MistralPixtralLarge2502V1 => "mistral.pixtral-large-2502-v1:0",
|
||||
Model::PalmyraWriterX4 => "writer.palmyra-x4-v1:0",
|
||||
Model::PalmyraWriterX5 => "writer.palmyra-x5-v1:0",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
@@ -166,6 +175,7 @@ impl Model {
|
||||
Self::AmazonNovaLite => "Amazon Nova Lite",
|
||||
Self::AmazonNovaMicro => "Amazon Nova Micro",
|
||||
Self::AmazonNovaPro => "Amazon Nova Pro",
|
||||
Self::AmazonNovaPremier => "Amazon Nova Premier",
|
||||
Self::DeepSeekR1 => "DeepSeek R1",
|
||||
Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
|
||||
Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
|
||||
@@ -195,6 +205,9 @@ impl Model {
|
||||
Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
|
||||
Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
|
||||
Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
|
||||
Self::MistralPixtralLarge2502V1 => "Pixtral Large 25.02 V1",
|
||||
Self::PalmyraWriterX5 => "Writer Palmyra X5",
|
||||
Self::PalmyraWriterX4 => "Writer Palmyra X4",
|
||||
Self::Custom {
|
||||
display_name, name, ..
|
||||
} => display_name.as_deref().unwrap_or(name),
|
||||
@@ -208,8 +221,11 @@ impl Model {
|
||||
| Self::Claude3Sonnet
|
||||
| Self::Claude3_5Haiku
|
||||
| Self::Claude3_7Sonnet => 200_000,
|
||||
Self::AmazonNovaPremier => 1_000_000,
|
||||
Self::PalmyraWriterX5 => 1_000_000,
|
||||
Self::PalmyraWriterX4 => 128_000,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
_ => 200_000,
|
||||
_ => 128_000,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,7 +233,7 @@ impl Model {
|
||||
match self {
|
||||
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
|
||||
Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
|
||||
Self::Claude3_5SonnetV2 => 8_192,
|
||||
Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192,
|
||||
Self::Custom {
|
||||
max_output_tokens, ..
|
||||
} => max_output_tokens.unwrap_or(4_096),
|
||||
@@ -252,7 +268,10 @@ impl Model {
|
||||
| Self::Claude3_5Haiku => true,
|
||||
|
||||
// Amazon Nova models (all support tool use)
|
||||
Self::AmazonNovaPro | Self::AmazonNovaLite | Self::AmazonNovaMicro => true,
|
||||
Self::AmazonNovaPremier
|
||||
| Self::AmazonNovaPro
|
||||
| Self::AmazonNovaLite
|
||||
| Self::AmazonNovaMicro => true,
|
||||
|
||||
// AI21 Jamba 1.5 models support tool use
|
||||
Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
|
||||
@@ -305,8 +324,11 @@ impl Model {
|
||||
|
||||
// Models available only in US
|
||||
(Model::Claude3Opus, "us")
|
||||
| (Model::Claude3_5Haiku, "us")
|
||||
| (Model::Claude3_7Sonnet, "us")
|
||||
| (Model::Claude3_7SonnetThinking, "us") => {
|
||||
| (Model::Claude3_7SonnetThinking, "us")
|
||||
| (Model::AmazonNovaPremier, "us")
|
||||
| (Model::MistralPixtralLarge2502V1, "us") => {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
@@ -340,6 +362,12 @@ impl Model {
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
// Writer models only available in the US
|
||||
(Model::PalmyraWriterX4, "us") | (Model::PalmyraWriterX5, "us") => {
|
||||
// They have some goofiness
|
||||
Ok(format!("{}.{}", region_group, model_id))
|
||||
}
|
||||
|
||||
// Any other combination is not supported
|
||||
_ => Ok(self.id().into()),
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ pub struct CallSettings {
|
||||
|
||||
/// Configuration of voice calls in Zed.
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct CallSettingsContent {
|
||||
/// Whether the microphone should be muted when joining a channel or a call.
|
||||
///
|
||||
|
||||
@@ -19,6 +19,7 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
|
||||
anyhow.workspace = true
|
||||
async-recursion = "0.3"
|
||||
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
|
||||
base64.workspace = true
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
@@ -29,6 +30,7 @@ gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
http_client.workspace = true
|
||||
http_client_tls.workspace = true
|
||||
httparse = "1.10"
|
||||
log.workspace = true
|
||||
paths.workspace = true
|
||||
parking_lot.workspace = true
|
||||
@@ -47,6 +49,7 @@ text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tiny_http = "0.8"
|
||||
tokio-native-tls = "0.3"
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
mod socks;
|
||||
mod proxy;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
pub mod zed_urls;
|
||||
@@ -24,13 +24,13 @@ use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use parking_lot::RwLock;
|
||||
use postage::watch;
|
||||
use proxy::connect_proxy_stream;
|
||||
use rand::prelude::*;
|
||||
use release_channel::{AppVersion, ReleaseChannel};
|
||||
use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use socks::connect_socks_proxy_stream;
|
||||
use std::pin::Pin;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
@@ -49,7 +49,7 @@ use telemetry::Telemetry;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use url::Url;
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
use util::{ConnectionResult, ResultExt};
|
||||
|
||||
pub use rpc::*;
|
||||
pub use telemetry_events::Event;
|
||||
@@ -151,9 +151,19 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
|
||||
let client = client.clone();
|
||||
move |_: &SignIn, cx| {
|
||||
if let Some(client) = client.upgrade() {
|
||||
cx.spawn(async move |cx| {
|
||||
client.authenticate_and_connect(true, &cx).log_err().await
|
||||
})
|
||||
cx.spawn(
|
||||
async move |cx| match client.authenticate_and_connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("Initial authentication timed out");
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::error!("Initial authentication connection reset");
|
||||
}
|
||||
ConnectionResult::Result(r) => {
|
||||
r.log_err();
|
||||
}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
@@ -658,7 +668,7 @@ impl Client {
|
||||
state._reconnect_task = None;
|
||||
}
|
||||
Status::ConnectionLost => {
|
||||
let this = self.clone();
|
||||
let client = self.clone();
|
||||
state._reconnect_task = Some(cx.spawn(async move |cx| {
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
@@ -666,10 +676,25 @@ impl Client {
|
||||
let mut rng = StdRng::from_entropy();
|
||||
|
||||
let mut delay = INITIAL_RECONNECTION_DELAY;
|
||||
while let Err(error) = this.authenticate_and_connect(true, &cx).await {
|
||||
log::error!("failed to connect {}", error);
|
||||
if matches!(*this.status().borrow(), Status::ConnectionError) {
|
||||
this.set_status(
|
||||
loop {
|
||||
match client.authenticate_and_connect(true, &cx).await {
|
||||
ConnectionResult::Timeout => {
|
||||
log::error!("client connect attempt timed out")
|
||||
}
|
||||
ConnectionResult::ConnectionReset => {
|
||||
log::error!("client connect attempt reset")
|
||||
}
|
||||
ConnectionResult::Result(r) => {
|
||||
if let Err(error) = r {
|
||||
log::error!("failed to connect: {error}");
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(*client.status().borrow(), Status::ConnectionError) {
|
||||
client.set_status(
|
||||
Status::ReconnectionError {
|
||||
next_reconnection: Instant::now() + delay,
|
||||
},
|
||||
@@ -827,7 +852,7 @@ impl Client {
|
||||
self: &Arc<Self>,
|
||||
try_provider: bool,
|
||||
cx: &AsyncApp,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> ConnectionResult<()> {
|
||||
let was_disconnected = match *self.status().borrow() {
|
||||
Status::SignedOut => true,
|
||||
Status::ConnectionError
|
||||
@@ -836,9 +861,14 @@ impl Client {
|
||||
| Status::Reauthenticating { .. }
|
||||
| Status::ReconnectionError { .. } => false,
|
||||
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
|
||||
return Ok(());
|
||||
return ConnectionResult::Result(Ok(()));
|
||||
}
|
||||
Status::UpgradeRequired => {
|
||||
return ConnectionResult::Result(
|
||||
Err(EstablishConnectionError::UpgradeRequired)
|
||||
.context("client auth and connect"),
|
||||
);
|
||||
}
|
||||
Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
|
||||
};
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Authenticating, cx);
|
||||
@@ -862,12 +892,12 @@ impl Client {
|
||||
Ok(creds) => credentials = Some(creds),
|
||||
Err(err) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
return Err(err);
|
||||
return ConnectionResult::Result(Err(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = status_rx.next().fuse() => {
|
||||
return Err(anyhow!("authentication canceled"));
|
||||
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -892,10 +922,10 @@ impl Client {
|
||||
}
|
||||
|
||||
futures::select_biased! {
|
||||
result = self.set_connection(conn, cx).fuse() => result,
|
||||
result = self.set_connection(conn, cx).fuse() => ConnectionResult::Result(result.context("client auth and connect")),
|
||||
_ = timeout => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
Err(anyhow!("timed out waiting on hello message from server"))
|
||||
ConnectionResult::Timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -907,22 +937,22 @@ impl Client {
|
||||
self.authenticate_and_connect(false, cx).await
|
||||
} else {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
Err(EstablishConnectionError::Unauthorized)?
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
|
||||
}
|
||||
}
|
||||
Err(EstablishConnectionError::UpgradeRequired) => {
|
||||
self.set_status(Status::UpgradeRequired, cx);
|
||||
Err(EstablishConnectionError::UpgradeRequired)?
|
||||
ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
|
||||
}
|
||||
Err(error) => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
Err(error)?
|
||||
ConnectionResult::Result(Err(error).context("client auth and connect"))
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = &mut timeout => {
|
||||
self.set_status(Status::ConnectionError, cx);
|
||||
Err(anyhow!("timed out trying to establish connection"))
|
||||
ConnectionResult::Timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -938,10 +968,7 @@ impl Client {
|
||||
|
||||
let peer_id = async {
|
||||
log::debug!("waiting for server hello");
|
||||
let message = incoming
|
||||
.next()
|
||||
.await
|
||||
.ok_or_else(|| anyhow!("no hello message received"))?;
|
||||
let message = incoming.next().await.context("no hello message received")?;
|
||||
log::debug!("got server hello");
|
||||
let hello_message_type_name = message.payload_type_name().to_string();
|
||||
let hello = message
|
||||
@@ -1129,7 +1156,7 @@ impl Client {
|
||||
let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap();
|
||||
let _guard = handle.enter();
|
||||
match proxy {
|
||||
Some(proxy) => connect_socks_proxy_stream(&proxy, rpc_host).await?,
|
||||
Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
|
||||
None => Box::new(TcpStream::connect(rpc_host).await?),
|
||||
}
|
||||
};
|
||||
@@ -1743,7 +1770,7 @@ mod tests {
|
||||
status.next().await,
|
||||
Some(Status::ConnectionError { .. })
|
||||
));
|
||||
auth_and_connect.await.unwrap_err();
|
||||
auth_and_connect.await.into_response().unwrap_err();
|
||||
|
||||
// Allow the connection to be established.
|
||||
let server = FakeServer::for_client(user_id, &client, cx).await;
|
||||
|
||||
66
crates/client/src/proxy.rs
Normal file
66
crates/client/src/proxy.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
//! client proxy
|
||||
|
||||
mod http_proxy;
|
||||
mod socks_proxy;
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use http_client::Url;
|
||||
use http_proxy::{HttpProxyType, connect_http_proxy_stream, parse_http_proxy};
|
||||
use socks_proxy::{SocksVersion, connect_socks_proxy_stream, parse_socks_proxy};
|
||||
|
||||
pub(crate) async fn connect_proxy_stream(
|
||||
proxy: &Url,
|
||||
rpc_host: (&str, u16),
|
||||
) -> Result<Box<dyn AsyncReadWrite>> {
|
||||
let Some(((proxy_domain, proxy_port), proxy_type)) = parse_proxy_type(proxy) else {
|
||||
// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
|
||||
// SOCKS proxies are often used in contexts where security and privacy are critical,
|
||||
// so any fallback could expose users to significant risks.
|
||||
return Err(anyhow!("Parsing proxy url failed"));
|
||||
};
|
||||
|
||||
// Connect to proxy and wrap protocol later
|
||||
let stream = tokio::net::TcpStream::connect((proxy_domain.as_str(), proxy_port))
|
||||
.await
|
||||
.context("Failed to connect to proxy")?;
|
||||
|
||||
let proxy_stream = match proxy_type {
|
||||
ProxyType::SocksProxy(proxy) => connect_socks_proxy_stream(stream, proxy, rpc_host).await?,
|
||||
ProxyType::HttpProxy(proxy) => {
|
||||
connect_http_proxy_stream(stream, proxy, rpc_host, &proxy_domain).await?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(proxy_stream)
|
||||
}
|
||||
|
||||
enum ProxyType<'t> {
|
||||
SocksProxy(SocksVersion<'t>),
|
||||
HttpProxy(HttpProxyType<'t>),
|
||||
}
|
||||
|
||||
fn parse_proxy_type<'t>(proxy: &'t Url) -> Option<((String, u16), ProxyType<'t>)> {
|
||||
let scheme = proxy.scheme();
|
||||
let host = proxy.host()?.to_string();
|
||||
let port = proxy.port_or_known_default()?;
|
||||
let proxy_type = match scheme {
|
||||
scheme if scheme.starts_with("socks") => {
|
||||
Some(ProxyType::SocksProxy(parse_socks_proxy(scheme, proxy)))
|
||||
}
|
||||
scheme if scheme.starts_with("http") => {
|
||||
Some(ProxyType::HttpProxy(parse_http_proxy(scheme, proxy)))
|
||||
}
|
||||
_ => None,
|
||||
}?;
|
||||
|
||||
Some(((host, port), proxy_type))
|
||||
}
|
||||
|
||||
pub(crate) trait AsyncReadWrite:
|
||||
tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static
|
||||
{
|
||||
}
|
||||
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static> AsyncReadWrite
|
||||
for T
|
||||
{
|
||||
}
|
||||
171
crates/client/src/proxy/http_proxy.rs
Normal file
171
crates/client/src/proxy/http_proxy.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use anyhow::{Context, Result};
|
||||
use base64::Engine;
|
||||
use httparse::{EMPTY_HEADER, Response};
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
|
||||
net::TcpStream,
|
||||
};
|
||||
use tokio_native_tls::{TlsConnector, native_tls};
|
||||
use url::Url;
|
||||
|
||||
use super::AsyncReadWrite;
|
||||
|
||||
pub(super) enum HttpProxyType<'t> {
|
||||
HTTP(Option<HttpProxyAuthorization<'t>>),
|
||||
HTTPS(Option<HttpProxyAuthorization<'t>>),
|
||||
}
|
||||
|
||||
pub(super) struct HttpProxyAuthorization<'t> {
|
||||
username: &'t str,
|
||||
password: &'t str,
|
||||
}
|
||||
|
||||
pub(super) fn parse_http_proxy<'t>(scheme: &str, proxy: &'t Url) -> HttpProxyType<'t> {
|
||||
let auth = proxy.password().map(|password| HttpProxyAuthorization {
|
||||
username: proxy.username(),
|
||||
password,
|
||||
});
|
||||
if scheme.starts_with("https") {
|
||||
HttpProxyType::HTTPS(auth)
|
||||
} else {
|
||||
HttpProxyType::HTTP(auth)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn connect_http_proxy_stream(
|
||||
stream: TcpStream,
|
||||
http_proxy: HttpProxyType<'_>,
|
||||
rpc_host: (&str, u16),
|
||||
proxy_domain: &str,
|
||||
) -> Result<Box<dyn AsyncReadWrite>> {
|
||||
match http_proxy {
|
||||
HttpProxyType::HTTP(auth) => http_connect(stream, rpc_host, auth).await,
|
||||
HttpProxyType::HTTPS(auth) => https_connect(stream, rpc_host, auth, proxy_domain).await,
|
||||
}
|
||||
.context("error connecting to http/https proxy")
|
||||
}
|
||||
|
||||
async fn http_connect<T>(
|
||||
stream: T,
|
||||
target: (&str, u16),
|
||||
auth: Option<HttpProxyAuthorization<'_>>,
|
||||
) -> Result<Box<dyn AsyncReadWrite>>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let mut stream = BufStream::new(stream);
|
||||
let request = make_request(target, auth);
|
||||
stream.write_all(request.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
check_response(&mut stream).await?;
|
||||
Ok(Box::new(stream))
|
||||
}
|
||||
|
||||
async fn https_connect<T>(
|
||||
stream: T,
|
||||
target: (&str, u16),
|
||||
auth: Option<HttpProxyAuthorization<'_>>,
|
||||
proxy_domain: &str,
|
||||
) -> Result<Box<dyn AsyncReadWrite>>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
|
||||
let stream = tls_connector.connect(proxy_domain, stream).await?;
|
||||
http_connect(stream, target, auth).await
|
||||
}
|
||||
|
||||
fn make_request(target: (&str, u16), auth: Option<HttpProxyAuthorization<'_>>) -> String {
|
||||
let (host, port) = target;
|
||||
let mut request = format!(
|
||||
"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\nProxy-Connection: Keep-Alive\r\n"
|
||||
);
|
||||
if let Some(HttpProxyAuthorization { username, password }) = auth {
|
||||
let auth =
|
||||
base64::prelude::BASE64_STANDARD.encode(format!("{username}:{password}").as_bytes());
|
||||
let auth = format!("Proxy-Authorization: Basic {auth}\r\n");
|
||||
request.push_str(&auth);
|
||||
}
|
||||
request.push_str("\r\n");
|
||||
request
|
||||
}
|
||||
|
||||
async fn check_response<T>(stream: &mut BufStream<T>) -> Result<()>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let response = recv_response(stream).await?;
|
||||
let mut dummy_headers = [EMPTY_HEADER; MAX_RESPONSE_HEADERS];
|
||||
let mut parser = Response::new(&mut dummy_headers);
|
||||
parser.parse(response.as_bytes())?;
|
||||
|
||||
match parser.code {
|
||||
Some(code) => {
|
||||
if code == 200 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"Proxy connection failed with HTTP code: {code}"
|
||||
))
|
||||
}
|
||||
}
|
||||
None => Err(anyhow::anyhow!(
|
||||
"Proxy connection failed with no HTTP code: {}",
|
||||
parser.reason.unwrap_or("Unknown reason")
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_RESPONSE_HEADER_LENGTH: usize = 4096;
|
||||
const MAX_RESPONSE_HEADERS: usize = 16;
|
||||
|
||||
async fn recv_response<T>(stream: &mut BufStream<T>) -> Result<String>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let mut response = String::new();
|
||||
loop {
|
||||
if stream.read_line(&mut response).await? == 0 {
|
||||
return Err(anyhow::anyhow!("End of stream"));
|
||||
}
|
||||
|
||||
if MAX_RESPONSE_HEADER_LENGTH < response.len() {
|
||||
return Err(anyhow::anyhow!("Maximum response header length exceeded"));
|
||||
}
|
||||
|
||||
if response.ends_with("\r\n\r\n") {
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use url::Url;
|
||||
|
||||
use super::{HttpProxyAuthorization, HttpProxyType, parse_http_proxy};
|
||||
|
||||
#[test]
|
||||
fn test_parse_http_proxy() {
|
||||
let proxy = Url::parse("http://proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_http_proxy(scheme, &proxy);
|
||||
assert!(matches!(version, HttpProxyType::HTTP(None)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_http_proxy_with_auth() {
|
||||
let proxy = Url::parse("http://username:password@proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_http_proxy(scheme, &proxy);
|
||||
assert!(matches!(
|
||||
version,
|
||||
HttpProxyType::HTTP(Some(HttpProxyAuthorization {
|
||||
username: "username",
|
||||
password: "password"
|
||||
}))
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,19 @@
|
||||
//! socks proxy
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use http_client::Url;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_socks::tcp::{Socks4Stream, Socks5Stream};
|
||||
use url::Url;
|
||||
|
||||
use super::AsyncReadWrite;
|
||||
|
||||
/// Identification to a Socks V4 Proxy
|
||||
struct Socks4Identification<'a> {
|
||||
pub(super) struct Socks4Identification<'a> {
|
||||
user_id: &'a str,
|
||||
}
|
||||
|
||||
/// Authorization to a Socks V5 Proxy
|
||||
struct Socks5Authorization<'a> {
|
||||
pub(super) struct Socks5Authorization<'a> {
|
||||
username: &'a str,
|
||||
password: &'a str,
|
||||
}
|
||||
@@ -18,45 +22,50 @@ struct Socks5Authorization<'a> {
|
||||
///
|
||||
/// V4 allows idenfication using a user_id
|
||||
/// V5 allows authorization using a username and password
|
||||
enum SocksVersion<'a> {
|
||||
pub(super) enum SocksVersion<'a> {
|
||||
V4(Option<Socks4Identification<'a>>),
|
||||
V5(Option<Socks5Authorization<'a>>),
|
||||
}
|
||||
|
||||
pub(crate) async fn connect_socks_proxy_stream(
|
||||
proxy: &Url,
|
||||
pub(super) fn parse_socks_proxy<'t>(scheme: &str, proxy: &'t Url) -> SocksVersion<'t> {
|
||||
if scheme.starts_with("socks4") {
|
||||
let identification = match proxy.username() {
|
||||
"" => None,
|
||||
username => Some(Socks4Identification { user_id: username }),
|
||||
};
|
||||
SocksVersion::V4(identification)
|
||||
} else {
|
||||
let authorization = proxy.password().map(|password| Socks5Authorization {
|
||||
username: proxy.username(),
|
||||
password,
|
||||
});
|
||||
SocksVersion::V5(authorization)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn connect_socks_proxy_stream(
|
||||
stream: TcpStream,
|
||||
socks_version: SocksVersion<'_>,
|
||||
rpc_host: (&str, u16),
|
||||
) -> Result<Box<dyn AsyncReadWrite>> {
|
||||
let Some((socks_proxy, version)) = parse_socks_proxy(proxy) else {
|
||||
// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
|
||||
// SOCKS proxies are often used in contexts where security and privacy are critical,
|
||||
// so any fallback could expose users to significant risks.
|
||||
return Err(anyhow!("Parsing proxy url failed"));
|
||||
};
|
||||
|
||||
// Connect to proxy and wrap protocol later
|
||||
let stream = tokio::net::TcpStream::connect(socks_proxy)
|
||||
.await
|
||||
.context("Failed to connect to socks proxy")?;
|
||||
|
||||
let socks: Box<dyn AsyncReadWrite> = match version {
|
||||
match socks_version {
|
||||
SocksVersion::V4(None) => {
|
||||
let socks = Socks4Stream::connect_with_socket(stream, rpc_host)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Box::new(socks)
|
||||
Ok(Box::new(socks))
|
||||
}
|
||||
SocksVersion::V4(Some(Socks4Identification { user_id })) => {
|
||||
let socks = Socks4Stream::connect_with_userid_and_socket(stream, rpc_host, user_id)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Box::new(socks)
|
||||
Ok(Box::new(socks))
|
||||
}
|
||||
SocksVersion::V5(None) => {
|
||||
let socks = Socks5Stream::connect_with_socket(stream, rpc_host)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Box::new(socks)
|
||||
Ok(Box::new(socks))
|
||||
}
|
||||
SocksVersion::V5(Some(Socks5Authorization { username, password })) => {
|
||||
let socks = Socks5Stream::connect_with_password_and_socket(
|
||||
@@ -64,44 +73,9 @@ pub(crate) async fn connect_socks_proxy_stream(
|
||||
)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Box::new(socks)
|
||||
Ok(Box::new(socks))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(socks)
|
||||
}
|
||||
|
||||
fn parse_socks_proxy(proxy: &Url) -> Option<((String, u16), SocksVersion<'_>)> {
|
||||
let scheme = proxy.scheme();
|
||||
let socks_version = if scheme.starts_with("socks4") {
|
||||
let identification = match proxy.username() {
|
||||
"" => None,
|
||||
username => Some(Socks4Identification { user_id: username }),
|
||||
};
|
||||
SocksVersion::V4(identification)
|
||||
} else if scheme.starts_with("socks") {
|
||||
let authorization = proxy.password().map(|password| Socks5Authorization {
|
||||
username: proxy.username(),
|
||||
password,
|
||||
});
|
||||
SocksVersion::V5(authorization)
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let host = proxy.host()?.to_string();
|
||||
let port = proxy.port_or_known_default()?;
|
||||
|
||||
Some(((host, port), socks_version))
|
||||
}
|
||||
|
||||
pub(crate) trait AsyncReadWrite:
|
||||
tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static
|
||||
{
|
||||
}
|
||||
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static> AsyncReadWrite
|
||||
for T
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -113,20 +87,18 @@ mod tests {
|
||||
#[test]
|
||||
fn parse_socks4() {
|
||||
let proxy = Url::parse("socks4://proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
assert!(matches!(version, SocksVersion::V4(None)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_socks4_with_identification() {
|
||||
let proxy = Url::parse("socks4://userid@proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
assert!(matches!(
|
||||
version,
|
||||
SocksVersion::V4(Some(Socks4Identification { user_id: "userid" }))
|
||||
@@ -136,20 +108,18 @@ mod tests {
|
||||
#[test]
|
||||
fn parse_socks5() {
|
||||
let proxy = Url::parse("socks5://proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
assert!(matches!(version, SocksVersion::V5(None)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_socks5_with_authorization() {
|
||||
let proxy = Url::parse("socks5://username:password@proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
assert!(matches!(
|
||||
version,
|
||||
SocksVersion::V5(Some(Socks5Authorization {
|
||||
@@ -158,19 +128,4 @@ mod tests {
|
||||
}))
|
||||
))
|
||||
}
|
||||
|
||||
/// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
|
||||
/// SOCKS proxies are often used in contexts where security and privacy are critical,
|
||||
/// so any fallback could expose users to significant risks.
|
||||
#[tokio::test]
|
||||
async fn fails_on_bad_proxy() {
|
||||
// Should fail connecting because http is not a valid Socks proxy scheme
|
||||
let proxy = Url::parse("http://localhost:2313").unwrap();
|
||||
|
||||
let result = connect_socks_proxy_stream(&proxy, ("test", 1080)).await;
|
||||
match result {
|
||||
Err(e) => assert_eq!(e.to_string(), "Parsing proxy url failed"),
|
||||
Ok(_) => panic!("Connecting on bad proxy should fail"),
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -137,18 +137,14 @@ pub fn os_version() -> String {
|
||||
log::error!("Failed to load /etc/os-release, /usr/lib/os-release");
|
||||
"".to_string()
|
||||
};
|
||||
let mut name = "unknown".to_string();
|
||||
let mut version = "unknown".to_string();
|
||||
let mut name = "unknown";
|
||||
let mut version = "unknown";
|
||||
|
||||
for line in content.lines() {
|
||||
if line.starts_with("ID=") {
|
||||
name = line.trim_start_matches("ID=").trim_matches('"').to_string();
|
||||
}
|
||||
if line.starts_with("VERSION_ID=") {
|
||||
version = line
|
||||
.trim_start_matches("VERSION_ID=")
|
||||
.trim_matches('"')
|
||||
.to_string();
|
||||
match line.split_once('=') {
|
||||
Some(("ID", val)) => name = val.trim_matches('"'),
|
||||
Some(("VERSION_ID", val)) => version = val.trim_matches('"'),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,7 +218,7 @@ impl Telemetry {
|
||||
cx.background_spawn({
|
||||
let state = state.clone();
|
||||
let os_version = os_version();
|
||||
state.lock().os_version = Some(os_version.clone());
|
||||
state.lock().os_version = Some(os_version);
|
||||
async move {
|
||||
if let Some(tempfile) = File::create(Self::log_file_path()).log_err() {
|
||||
state.lock().log_file = Some(tempfile);
|
||||
@@ -369,7 +365,7 @@ impl Telemetry {
|
||||
telemetry::event!(
|
||||
"Editor Edited",
|
||||
duration = duration,
|
||||
environment = environment.to_string(),
|
||||
environment = environment,
|
||||
is_via_ssh = is_via_ssh
|
||||
);
|
||||
}
|
||||
@@ -431,9 +427,8 @@ impl Telemetry {
|
||||
|
||||
if state.flush_events_task.is_none() {
|
||||
let this = self.clone();
|
||||
let executor = self.executor.clone();
|
||||
state.flush_events_task = Some(self.executor.spawn(async move {
|
||||
executor.timer(FLUSH_INTERVAL).await;
|
||||
this.executor.timer(FLUSH_INTERVAL).await;
|
||||
this.flush_events().detach();
|
||||
}));
|
||||
}
|
||||
@@ -484,12 +479,12 @@ impl Telemetry {
|
||||
self: &Arc<Self>,
|
||||
// We take in the JSON bytes buffer so we can reuse the existing allocation.
|
||||
mut json_bytes: Vec<u8>,
|
||||
event_request: EventRequestBody,
|
||||
event_request: &EventRequestBody,
|
||||
) -> Result<Request<AsyncBody>> {
|
||||
json_bytes.clear();
|
||||
serde_json::to_writer(&mut json_bytes, &event_request)?;
|
||||
serde_json::to_writer(&mut json_bytes, event_request)?;
|
||||
|
||||
let checksum = calculate_json_checksum(&json_bytes).unwrap_or("".to_string());
|
||||
let checksum = calculate_json_checksum(&json_bytes).unwrap_or_default();
|
||||
|
||||
Ok(Request::builder()
|
||||
.method(Method::POST)
|
||||
@@ -506,7 +501,7 @@ impl Telemetry {
|
||||
pub fn flush_events(self: &Arc<Self>) -> Task<()> {
|
||||
let mut state = self.state.lock();
|
||||
state.first_event_date_time = None;
|
||||
let mut events = mem::take(&mut state.events_queue);
|
||||
let events = mem::take(&mut state.events_queue);
|
||||
state.flush_events_task.take();
|
||||
drop(state);
|
||||
if events.is_empty() {
|
||||
@@ -519,7 +514,7 @@ impl Telemetry {
|
||||
let mut json_bytes = Vec::new();
|
||||
|
||||
if let Some(file) = &mut this.state.lock().log_file {
|
||||
for event in &mut events {
|
||||
for event in &events {
|
||||
json_bytes.clear();
|
||||
serde_json::to_writer(&mut json_bytes, event)?;
|
||||
file.write_all(&json_bytes)?;
|
||||
@@ -546,7 +541,7 @@ impl Telemetry {
|
||||
}
|
||||
};
|
||||
|
||||
let request = this.build_request(json_bytes, request_body)?;
|
||||
let request = this.build_request(json_bytes, &request_body)?;
|
||||
let response = this.http_client.send(request).await?;
|
||||
if response.status() != 200 {
|
||||
log::error!("Failed to send events: HTTP {:?}", response.status());
|
||||
|
||||
@@ -107,6 +107,7 @@ impl FakeServer {
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
|
||||
server
|
||||
|
||||
@@ -1137,6 +1137,12 @@ async fn handle_customer_subscription_event(
|
||||
.await?;
|
||||
}
|
||||
|
||||
// When the user's subscription changes, push down any changes to their plan.
|
||||
rpc_server
|
||||
.update_plan_for_user(billing_customer.user_id)
|
||||
.await
|
||||
.trace_err();
|
||||
|
||||
// When the user's subscription changes, we want to refresh their LLM tokens
|
||||
// to either grant/revoke access.
|
||||
rpc_server
|
||||
@@ -1274,7 +1280,7 @@ async fn get_current_usage(
|
||||
subscription
|
||||
.kind
|
||||
.map(Into::into)
|
||||
.unwrap_or(zed_llm_client::Plan::Free)
|
||||
.unwrap_or(zed_llm_client::Plan::ZedFree)
|
||||
});
|
||||
|
||||
let model_requests_limit = match plan.model_requests_limit() {
|
||||
|
||||
@@ -543,7 +543,7 @@ pub struct MembershipUpdated {
|
||||
|
||||
/// The result of setting a member's role.
|
||||
#[derive(Debug)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
|
||||
pub enum SetMemberRoleResult {
|
||||
InviteUpdated(Channel),
|
||||
MembershipUpdated(MembershipUpdated),
|
||||
|
||||
@@ -99,7 +99,7 @@ impl From<SubscriptionKind> for zed_llm_client::Plan {
|
||||
match value {
|
||||
SubscriptionKind::ZedPro => Self::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Self::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => Self::Free,
|
||||
SubscriptionKind::ZedFree => Self::ZedFree,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,18 +25,12 @@ pub struct LlmTokenClaims {
|
||||
pub is_staff: bool,
|
||||
pub has_llm_closed_beta_feature_flag: bool,
|
||||
pub bypass_account_age_check: bool,
|
||||
#[serde(default)]
|
||||
pub use_llm_request_queue: bool,
|
||||
pub plan: Plan,
|
||||
#[serde(default)]
|
||||
pub has_extended_trial: bool,
|
||||
#[serde(default)]
|
||||
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
|
||||
#[serde(default)]
|
||||
pub subscription_period: (NaiveDateTime, NaiveDateTime),
|
||||
pub enable_model_request_overages: bool,
|
||||
#[serde(default)]
|
||||
pub model_request_overages_spend_limit_in_cents: u32,
|
||||
#[serde(default)]
|
||||
pub can_use_web_search_tool: bool,
|
||||
}
|
||||
|
||||
@@ -57,6 +51,23 @@ impl LlmTokenClaims {
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("no LLM API secret"))?;
|
||||
|
||||
let plan = if is_staff {
|
||||
Plan::ZedPro
|
||||
} else {
|
||||
subscription
|
||||
.as_ref()
|
||||
.and_then(|subscription| subscription.kind)
|
||||
.map_or(Plan::ZedFree, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
};
|
||||
let subscription_period =
|
||||
billing_subscription::Model::current_period(subscription, is_staff)
|
||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
||||
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
|
||||
|
||||
let now = Utc::now();
|
||||
let claims = Self {
|
||||
iat: now.timestamp() as u64,
|
||||
@@ -76,26 +87,11 @@ impl LlmTokenClaims {
|
||||
.any(|flag| flag == "bypass-account-age-check"),
|
||||
can_use_web_search_tool: true,
|
||||
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
|
||||
plan: if is_staff {
|
||||
Plan::ZedPro
|
||||
} else {
|
||||
subscription
|
||||
.as_ref()
|
||||
.and_then(|subscription| subscription.kind)
|
||||
.map_or(Plan::Free, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::Free,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
},
|
||||
plan,
|
||||
has_extended_trial: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
|
||||
subscription_period: billing_subscription::Model::current_period(
|
||||
subscription,
|
||||
is_staff,
|
||||
)
|
||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc())),
|
||||
subscription_period,
|
||||
enable_model_request_overages: billing_preferences
|
||||
.as_ref()
|
||||
.map_or(false, |preferences| {
|
||||
|
||||
@@ -36,6 +36,7 @@ use util::{ResultExt as _, maybe};
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
|
||||
|
||||
#[expect(clippy::result_large_err)]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
if let Err(error) = env::load_dotenv() {
|
||||
|
||||
@@ -2,6 +2,7 @@ mod connection_pool;
|
||||
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
|
||||
use crate::{
|
||||
AppState, Error, Result, auth,
|
||||
@@ -67,7 +68,7 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use time::OffsetDateTime;
|
||||
use tokio::sync::{MutexGuard, Semaphore, watch};
|
||||
use tokio::sync::{Semaphore, watch};
|
||||
use tower::ServiceBuilder;
|
||||
use tracing::{
|
||||
Instrument,
|
||||
@@ -166,29 +167,6 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
|
||||
if self.is_staff() {
|
||||
return Ok(proto::Plan::ZedPro);
|
||||
}
|
||||
|
||||
let user_id = self.user_id();
|
||||
|
||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
|
||||
|
||||
let plan = if let Some(subscription_kind) = subscription_kind {
|
||||
match subscription_kind {
|
||||
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => proto::Plan::Free,
|
||||
}
|
||||
} else {
|
||||
proto::Plan::Free
|
||||
};
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
fn user_id(&self) -> UserId {
|
||||
match &self.principal {
|
||||
Principal::User(user) => user.id,
|
||||
@@ -953,6 +931,32 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
|
||||
let user = self
|
||||
.app_state
|
||||
.db
|
||||
.get_user_by_id(user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let update_user_plan = make_update_user_plan_message(
|
||||
&self.app_state.db,
|
||||
self.app_state.llm_db.clone(),
|
||||
user_id,
|
||||
user.admin,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let pool = self.connection_pool.lock();
|
||||
for connection_id in pool.user_connection_ids(user_id) {
|
||||
self.peer
|
||||
.send(connection_id, update_user_plan.clone())
|
||||
.trace_err();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
|
||||
let pool = self.connection_pool.lock();
|
||||
for connection_id in pool.user_connection_ids(user_id) {
|
||||
@@ -2688,21 +2692,43 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
|
||||
version.0.minor() < 139
|
||||
}
|
||||
|
||||
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
|
||||
if is_staff {
|
||||
return Ok(proto::Plan::ZedPro);
|
||||
}
|
||||
|
||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
|
||||
|
||||
let plan = if let Some(subscription_kind) = subscription_kind {
|
||||
match subscription_kind {
|
||||
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => proto::Plan::Free,
|
||||
}
|
||||
} else {
|
||||
proto::Plan::Free
|
||||
};
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
async fn make_update_user_plan_message(
|
||||
db: &Arc<Database>,
|
||||
llm_db: Option<Arc<LlmDatabase>>,
|
||||
user_id: UserId,
|
||||
is_staff: bool,
|
||||
) -> Result<proto::UpdateUserPlan> {
|
||||
let feature_flags = db.get_user_flags(user_id).await?;
|
||||
let plan = session.current_plan(&db).await?;
|
||||
let plan = current_plan(db, user_id, is_staff).await?;
|
||||
let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
|
||||
let billing_preferences = db.get_billing_preferences(user_id).await?;
|
||||
|
||||
let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
|
||||
let (subscription_period, usage) = if let Some(llm_db) = llm_db {
|
||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||
|
||||
let subscription_period = crate::db::billing_subscription::Model::current_period(
|
||||
subscription,
|
||||
session.is_staff(),
|
||||
);
|
||||
let subscription_period =
|
||||
crate::db::billing_subscription::Model::current_period(subscription, is_staff);
|
||||
|
||||
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
|
||||
llm_db
|
||||
@@ -2717,92 +2743,92 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
session
|
||||
.peer
|
||||
.send(
|
||||
session.connection_id,
|
||||
proto::UpdateUserPlan {
|
||||
plan: plan.into(),
|
||||
trial_started_at: billing_customer
|
||||
.and_then(|billing_customer| billing_customer.trial_started_at)
|
||||
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
|
||||
is_usage_based_billing_enabled: if session.is_staff() {
|
||||
Some(true)
|
||||
} else {
|
||||
billing_preferences
|
||||
.map(|preferences| preferences.model_request_overages_enabled)
|
||||
},
|
||||
subscription_period: subscription_period.map(|(started_at, ended_at)| {
|
||||
proto::SubscriptionPeriod {
|
||||
started_at: started_at.timestamp() as u64,
|
||||
ended_at: ended_at.timestamp() as u64,
|
||||
}
|
||||
}),
|
||||
usage: usage.map(|usage| {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => zed_llm_client::Plan::Free,
|
||||
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
Ok(proto::UpdateUserPlan {
|
||||
plan: plan.into(),
|
||||
trial_started_at: billing_customer
|
||||
.and_then(|billing_customer| billing_customer.trial_started_at)
|
||||
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
|
||||
is_usage_based_billing_enabled: if is_staff {
|
||||
Some(true)
|
||||
} else {
|
||||
billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
|
||||
},
|
||||
subscription_period: subscription_period.map(|(started_at, ended_at)| {
|
||||
proto::SubscriptionPeriod {
|
||||
started_at: started_at.timestamp() as u64,
|
||||
ended_at: ended_at.timestamp() as u64,
|
||||
}
|
||||
}),
|
||||
usage: usage.map(|usage| {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
let model_requests_limit = match plan.model_requests_limit() {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == zed_llm_client::Plan::ZedProTrial
|
||||
&& feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
|
||||
{
|
||||
1_000
|
||||
} else {
|
||||
limit
|
||||
};
|
||||
|
||||
let model_requests_limit = match plan.model_requests_limit() {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == zed_llm_client::Plan::ZedProTrial
|
||||
&& feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
|
||||
{
|
||||
1_000
|
||||
} else {
|
||||
limit
|
||||
};
|
||||
zed_llm_client::UsageLimit::Limited(limit)
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
|
||||
};
|
||||
|
||||
zed_llm_client::UsageLimit::Limited(limit)
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: usage.model_requests as u32,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
};
|
||||
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: usage.model_requests as u32,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(
|
||||
proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
},
|
||||
)
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(
|
||||
proto::usage_limit::Unlimited {},
|
||||
)
|
||||
}
|
||||
}),
|
||||
}),
|
||||
edit_predictions_usage_amount: usage.edit_predictions as u32,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(
|
||||
proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
},
|
||||
)
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(
|
||||
proto::usage_limit::Unlimited {},
|
||||
)
|
||||
}
|
||||
}),
|
||||
}),
|
||||
}
|
||||
}),
|
||||
}),
|
||||
},
|
||||
)
|
||||
edit_predictions_usage_amount: usage.edit_predictions as u32,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
}),
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
let db = session.db().await;
|
||||
|
||||
let update_user_plan = make_update_user_plan_message(
|
||||
&db.0,
|
||||
session.app_state.llm_db.clone(),
|
||||
user_id,
|
||||
session.is_staff(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
session
|
||||
.peer
|
||||
.send(session.connection_id, update_user_plan)
|
||||
.trace_err();
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -248,6 +248,8 @@ impl StripeBilling {
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.payment_method_collection =
|
||||
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
||||
|
||||
@@ -36,8 +36,8 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
|
||||
room.read_with(cx, |room, _| {
|
||||
let mut remote = room
|
||||
.remote_participants()
|
||||
.iter()
|
||||
.map(|(_, participant)| participant.user.github_login.clone())
|
||||
.values()
|
||||
.map(|participant| participant.user.github_login.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let mut pending = room
|
||||
.pending_participants()
|
||||
|
||||
@@ -1740,6 +1740,7 @@ async fn test_mutual_editor_inlay_hint_cache_update(
|
||||
fake_language_server
|
||||
.request::<lsp::request::InlayHintRefreshRequest>(())
|
||||
.await
|
||||
.into_response()
|
||||
.expect("inlay refresh request failed");
|
||||
|
||||
executor.run_until_parked();
|
||||
@@ -1930,6 +1931,7 @@ async fn test_inlay_hint_refresh_is_forwarded(
|
||||
fake_language_server
|
||||
.request::<lsp::request::InlayHintRefreshRequest>(())
|
||||
.await
|
||||
.into_response()
|
||||
.expect("inlay refresh request failed");
|
||||
executor.run_until_parked();
|
||||
editor_a.update(cx_a, |editor, _| {
|
||||
|
||||
@@ -1253,6 +1253,7 @@ async fn test_calls_on_multiple_connections(
|
||||
client_b1
|
||||
.authenticate_and_connect(false, &cx_b1.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
|
||||
// User B hangs up, and user A calls them again.
|
||||
@@ -1633,6 +1634,7 @@ async fn test_project_reconnect(
|
||||
client_a
|
||||
.authenticate_and_connect(false, &cx_a.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
executor.run_until_parked();
|
||||
|
||||
@@ -1761,6 +1763,7 @@ async fn test_project_reconnect(
|
||||
client_b
|
||||
.authenticate_and_connect(false, &cx_b.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
executor.run_until_parked();
|
||||
|
||||
@@ -4317,6 +4320,7 @@ async fn test_collaborating_with_lsp_progress_updates_and_diagnostics_ordering(
|
||||
token: lsp::NumberOrString::String("the-disk-based-token".to_string()),
|
||||
})
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
fake_language_server.notify::<lsp::notification::Progress>(&lsp::ProgressParams {
|
||||
token: lsp::NumberOrString::String("the-disk-based-token".to_string()),
|
||||
@@ -5699,6 +5703,7 @@ async fn test_contacts(
|
||||
client_c
|
||||
.authenticate_and_connect(false, &cx_c.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
|
||||
executor.run_until_parked();
|
||||
@@ -6229,6 +6234,7 @@ async fn test_contact_requests(
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -313,6 +313,7 @@ impl TestServer {
|
||||
client
|
||||
.authenticate_and_connect(false, &cx.to_async())
|
||||
.await
|
||||
.into_response()
|
||||
.unwrap();
|
||||
|
||||
let client = TestClient {
|
||||
|
||||
@@ -42,6 +42,7 @@ futures.workspace = true
|
||||
fuzzy.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
menu.workspace = true
|
||||
notifications.workspace = true
|
||||
picker.workspace = true
|
||||
|
||||
@@ -1059,7 +1059,7 @@ impl Render for ChatPanel {
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"@{}",
|
||||
user_being_replied_to.github_login.clone()
|
||||
user_being_replied_to.github_login
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.weight(FontWeight::BOLD),
|
||||
|
||||
@@ -378,16 +378,27 @@ impl CollabPanel {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
mut cx: AsyncWindowContext,
|
||||
) -> anyhow::Result<Entity<Self>> {
|
||||
let serialized_panel = cx
|
||||
.background_spawn(async move { KEY_VALUE_STORE.read_kvp(COLLABORATION_PANEL_KEY) })
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Failed to read collaboration panel from key value store"))
|
||||
.log_err()
|
||||
let serialized_panel = match workspace
|
||||
.read_with(&cx, |workspace, _| {
|
||||
CollabPanel::serialization_key(workspace)
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|panel| serde_json::from_str::<SerializedCollabPanel>(&panel))
|
||||
.transpose()
|
||||
.log_err()
|
||||
.flatten();
|
||||
{
|
||||
Some(serialization_key) => cx
|
||||
.background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) })
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!("Failed to read collaboration panel from key value store")
|
||||
})
|
||||
.log_err()
|
||||
.flatten()
|
||||
.map(|panel| serde_json::from_str::<SerializedCollabPanel>(&panel))
|
||||
.transpose()
|
||||
.log_err()
|
||||
.flatten(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
workspace.update_in(&mut cx, |workspace, window, cx| {
|
||||
let panel = CollabPanel::new(workspace, window, cx);
|
||||
@@ -407,14 +418,30 @@ impl CollabPanel {
|
||||
})
|
||||
}
|
||||
|
||||
fn serialization_key(workspace: &Workspace) -> Option<String> {
|
||||
workspace
|
||||
.database_id()
|
||||
.map(|id| i64::from(id).to_string())
|
||||
.or(workspace.session_id())
|
||||
.map(|id| format!("{}-{:?}", COLLABORATION_PANEL_KEY, id))
|
||||
}
|
||||
|
||||
fn serialize(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(serialization_key) = self
|
||||
.workspace
|
||||
.update(cx, |workspace, _| CollabPanel::serialization_key(workspace))
|
||||
.ok()
|
||||
.flatten()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let width = self.width;
|
||||
let collapsed_channels = self.collapsed_channels.clone();
|
||||
self.pending_serialization = cx.background_spawn(
|
||||
async move {
|
||||
KEY_VALUE_STORE
|
||||
.write_kvp(
|
||||
COLLABORATION_PANEL_KEY.into(),
|
||||
serialization_key,
|
||||
serde_json::to_string(&SerializedCollabPanel {
|
||||
width,
|
||||
collapsed_channels: Some(
|
||||
@@ -2227,6 +2254,7 @@ impl CollabPanel {
|
||||
client
|
||||
.authenticate_and_connect(true, &cx)
|
||||
.await
|
||||
.into_response()
|
||||
.notify_async_err(cx);
|
||||
})
|
||||
.detach()
|
||||
@@ -2998,10 +3026,12 @@ impl Panel for CollabPanel {
|
||||
.unwrap_or_else(|| CollaborationPanelSettings::get_global(cx).default_width)
|
||||
}
|
||||
|
||||
fn set_size(&mut self, size: Option<Pixels>, _: &mut Window, cx: &mut Context<Self>) {
|
||||
fn set_size(&mut self, size: Option<Pixels>, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.width = size;
|
||||
self.serialize(cx);
|
||||
cx.notify();
|
||||
cx.defer_in(window, |this, _, cx| {
|
||||
this.serialize(cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn icon(&self, _window: &Window, cx: &App) -> Option<ui::IconName> {
|
||||
|
||||
@@ -646,10 +646,20 @@ impl Render for NotificationPanel {
|
||||
let client = client.clone();
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
client
|
||||
match client
|
||||
.authenticate_and_connect(true, &cx)
|
||||
.log_err()
|
||||
.await;
|
||||
.await
|
||||
{
|
||||
util::ConnectionResult::Timeout => {
|
||||
log::error!("Connection timeout");
|
||||
}
|
||||
util::ConnectionResult::ConnectionReset => {
|
||||
log::error!("Connection reset");
|
||||
}
|
||||
util::ConnectionResult::Result(r) => {
|
||||
r.log_err();
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach()
|
||||
}
|
||||
|
||||
@@ -7,11 +7,11 @@ use crate::notifications::collab_notification::CollabNotification;
|
||||
pub struct CollabNotificationStory;
|
||||
|
||||
impl Render for CollabNotificationStory {
|
||||
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let window_container = |width, height| div().w(px(width)).h(px(height));
|
||||
|
||||
Story::container()
|
||||
.child(Story::title_for::<CollabNotification>())
|
||||
Story::container(cx)
|
||||
.child(Story::title_for::<CollabNotification>(cx))
|
||||
.child(
|
||||
StorySection::new().child(StoryItem::new(
|
||||
"Incoming Call Notification",
|
||||
|
||||
@@ -28,6 +28,7 @@ pub struct ChatPanelSettings {
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct ChatPanelSettingsContent {
|
||||
/// When to show the panel button in the status bar.
|
||||
///
|
||||
@@ -51,6 +52,7 @@ pub struct NotificationPanelSettings {
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct PanelSettingsContent {
|
||||
/// Whether to show the panel button in the status bar.
|
||||
///
|
||||
@@ -67,6 +69,7 @@ pub struct PanelSettingsContent {
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[schemars(deny_unknown_fields)]
|
||||
pub struct MessageEditorSettings {
|
||||
/// Whether to automatically replace emoji shortcodes with emoji characters.
|
||||
/// For example: typing `:wave:` gets replaced with `👋`.
|
||||
|
||||
@@ -14,7 +14,7 @@ path = "src/component.rs"
|
||||
[dependencies]
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
linkme.workspace = true
|
||||
inventory.workspace = true
|
||||
parking_lot.workspace = true
|
||||
strum.workspace = true
|
||||
theme.workspace = true
|
||||
|
||||
@@ -9,13 +9,12 @@
|
||||
|
||||
mod component_layout;
|
||||
|
||||
pub use component_layout::*;
|
||||
|
||||
use std::sync::LazyLock;
|
||||
|
||||
pub use component_layout::*;
|
||||
|
||||
use collections::HashMap;
|
||||
use gpui::{AnyElement, App, SharedString, Window};
|
||||
use linkme::distributed_slice;
|
||||
use parking_lot::RwLock;
|
||||
use strum::{Display, EnumString};
|
||||
|
||||
@@ -24,12 +23,27 @@ pub fn components() -> ComponentRegistry {
|
||||
}
|
||||
|
||||
pub fn init() {
|
||||
let component_fns: Vec<_> = __ALL_COMPONENTS.iter().cloned().collect();
|
||||
for f in component_fns {
|
||||
f();
|
||||
for f in inventory::iter::<ComponentFn>() {
|
||||
(f.0)();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ComponentFn(fn());
|
||||
|
||||
impl ComponentFn {
|
||||
pub const fn new(f: fn()) -> Self {
|
||||
Self(f)
|
||||
}
|
||||
}
|
||||
|
||||
inventory::collect!(ComponentFn);
|
||||
|
||||
/// Private internals for macros.
|
||||
#[doc(hidden)]
|
||||
pub mod __private {
|
||||
pub use inventory;
|
||||
}
|
||||
|
||||
pub fn register_component<T: Component>() {
|
||||
let id = T::id();
|
||||
let metadata = ComponentMetadata {
|
||||
@@ -46,9 +60,6 @@ pub fn register_component<T: Component>() {
|
||||
data.components.insert(id, metadata);
|
||||
}
|
||||
|
||||
#[distributed_slice]
|
||||
pub static __ALL_COMPONENTS: [fn()] = [..];
|
||||
|
||||
pub static COMPONENT_DATA: LazyLock<RwLock<ComponentRegistry>> =
|
||||
LazyLock::new(|| RwLock::new(ComponentRegistry::default()));
|
||||
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
[package]
|
||||
name = "component_preview"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/component_preview.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
[dependencies]
|
||||
agent.workspace = true
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
db.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
languages.workspace = true
|
||||
log.workspace = true
|
||||
notifications.workspace = true
|
||||
project.workspace = true
|
||||
prompt_store.workspace = true
|
||||
serde.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
@@ -14,7 +14,6 @@ doctest = false
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
test-support = [
|
||||
"collections/test-support",
|
||||
"gpui/test-support",
|
||||
@@ -43,16 +42,15 @@ node_runtime.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
project.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
strum.workspace = true
|
||||
task.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
itertools.workspace = true
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
async-std = { version = "1.12.0", features = ["unstable"] }
|
||||
|
||||
@@ -5,7 +5,7 @@ mod sign_in;
|
||||
|
||||
use crate::sign_in::initiate_sign_in_within_workspace;
|
||||
use ::fs::Fs;
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use futures::{Future, FutureExt, TryFutureExt, channel::oneshot, future::Shared};
|
||||
@@ -531,11 +531,15 @@ impl Copilot {
|
||||
.request::<request::CheckStatus>(request::CheckStatusParams {
|
||||
local_checks_only: false,
|
||||
})
|
||||
.await?;
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot: check status")?;
|
||||
|
||||
server
|
||||
.request::<request::SetEditorInfo>(editor_info)
|
||||
.await?;
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot: set editor info")?;
|
||||
|
||||
anyhow::Ok((server, status))
|
||||
};
|
||||
@@ -581,7 +585,9 @@ impl Copilot {
|
||||
.request::<request::SignInInitiate>(
|
||||
request::SignInInitiateParams {},
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot sign-in")?;
|
||||
match sign_in {
|
||||
request::SignInInitiateResult::AlreadySignedIn { user } => {
|
||||
Ok(request::SignInStatus::Ok { user: Some(user) })
|
||||
@@ -609,7 +615,9 @@ impl Copilot {
|
||||
user_code: flow.user_code,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot: sign in confirm")?;
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
@@ -656,7 +664,9 @@ impl Copilot {
|
||||
cx.background_spawn(async move {
|
||||
server
|
||||
.request::<request::SignOut>(request::SignOutParams {})
|
||||
.await?;
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot: sign in confirm")?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
@@ -873,7 +883,10 @@ impl Copilot {
|
||||
uuid: completion.uuid.clone(),
|
||||
});
|
||||
cx.background_spawn(async move {
|
||||
request.await?;
|
||||
request
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot: notify accepted")?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
@@ -897,7 +910,10 @@ impl Copilot {
|
||||
.collect(),
|
||||
});
|
||||
cx.background_spawn(async move {
|
||||
request.await?;
|
||||
request
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot: notify rejected")?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
@@ -957,7 +973,9 @@ impl Copilot {
|
||||
version: version.try_into().unwrap(),
|
||||
},
|
||||
})
|
||||
.await?;
|
||||
.await
|
||||
.into_response()
|
||||
.context("copilot: get completions")?;
|
||||
let completions = result
|
||||
.completions
|
||||
.into_iter()
|
||||
|
||||
@@ -9,13 +9,20 @@ use fs::Fs;
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use gpui::{App, AsyncApp, Global, prelude::*};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use itertools::Itertools;
|
||||
use paths::home_dir;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::watch_config_dir;
|
||||
use strum::EnumIter;
|
||||
|
||||
pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
|
||||
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
|
||||
pub const COPILOT_CHAT_MODELS_URL: &str = "https://api.githubcopilot.com/models";
|
||||
|
||||
// Copilot's base model; defined by Microsoft in premium requests table
|
||||
// This will be moved to the front of the Copilot model list, and will be used for
|
||||
// 'fast' requests (e.g. title generation)
|
||||
// https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests
|
||||
const DEFAULT_MODEL_ID: &str = "gpt-4.1";
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
@@ -25,132 +32,130 @@ pub enum Role {
|
||||
System,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
#[default]
|
||||
#[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
|
||||
Gpt4o,
|
||||
#[serde(alias = "gpt-4", rename = "gpt-4")]
|
||||
Gpt4,
|
||||
#[serde(alias = "gpt-4.1", rename = "gpt-4.1")]
|
||||
Gpt4_1,
|
||||
#[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
|
||||
Gpt3_5Turbo,
|
||||
#[serde(alias = "o1", rename = "o1")]
|
||||
O1,
|
||||
#[serde(alias = "o1-mini", rename = "o3-mini")]
|
||||
O3Mini,
|
||||
#[serde(alias = "o3", rename = "o3")]
|
||||
O3,
|
||||
#[serde(alias = "o4-mini", rename = "o4-mini")]
|
||||
O4Mini,
|
||||
#[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
|
||||
Claude3_5Sonnet,
|
||||
#[serde(alias = "claude-3-7-sonnet", rename = "claude-3.7-sonnet")]
|
||||
Claude3_7Sonnet,
|
||||
#[serde(
|
||||
alias = "claude-3.7-sonnet-thought",
|
||||
rename = "claude-3.7-sonnet-thought"
|
||||
)]
|
||||
Claude3_7SonnetThinking,
|
||||
#[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
|
||||
Gemini20Flash,
|
||||
#[serde(alias = "gemini-2.5-pro", rename = "gemini-2.5-pro")]
|
||||
Gemini25Pro,
|
||||
#[derive(Deserialize)]
|
||||
struct ModelSchema {
|
||||
#[serde(deserialize_with = "deserialize_models_skip_errors")]
|
||||
data: Vec<Model>,
|
||||
}
|
||||
|
||||
fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
|
||||
let models = raw_values
|
||||
.into_iter()
|
||||
.filter_map(|value| match serde_json::from_value::<Model>(value) {
|
||||
Ok(model) => Some(model),
|
||||
Err(err) => {
|
||||
log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct Model {
|
||||
capabilities: ModelCapabilities,
|
||||
id: String,
|
||||
name: String,
|
||||
policy: Option<ModelPolicy>,
|
||||
vendor: ModelVendor,
|
||||
model_picker_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelCapabilities {
|
||||
family: String,
|
||||
#[serde(default)]
|
||||
limits: ModelLimits,
|
||||
supports: ModelSupportedFeatures,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelLimits {
|
||||
#[serde(default)]
|
||||
max_context_window_tokens: usize,
|
||||
#[serde(default)]
|
||||
max_output_tokens: usize,
|
||||
#[serde(default)]
|
||||
max_prompt_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelPolicy {
|
||||
state: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
struct ModelSupportedFeatures {
|
||||
#[serde(default)]
|
||||
streaming: bool,
|
||||
#[serde(default)]
|
||||
tool_calls: bool,
|
||||
#[serde(default)]
|
||||
parallel_tool_calls: bool,
|
||||
#[serde(default)]
|
||||
vision: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub enum ModelVendor {
|
||||
// Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
|
||||
#[serde(alias = "Azure OpenAI")]
|
||||
OpenAI,
|
||||
Google,
|
||||
Anthropic,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ChatMessagePart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
Image { image_url: ImageUrl },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn default_fast() -> Self {
|
||||
Self::Claude3_7Sonnet
|
||||
}
|
||||
|
||||
pub fn uses_streaming(&self) -> bool {
|
||||
match self {
|
||||
Self::Gpt4o
|
||||
| Self::Gpt4
|
||||
| Self::Gpt4_1
|
||||
| Self::Gpt3_5Turbo
|
||||
| Self::O3
|
||||
| Self::O4Mini
|
||||
| Self::Claude3_5Sonnet
|
||||
| Self::Claude3_7Sonnet
|
||||
| Self::Claude3_7SonnetThinking => true,
|
||||
Self::O3Mini | Self::O1 | Self::Gemini20Flash | Self::Gemini25Pro => false,
|
||||
}
|
||||
self.capabilities.supports.streaming
|
||||
}
|
||||
|
||||
pub fn from_id(id: &str) -> Result<Self> {
|
||||
match id {
|
||||
"gpt-4o" => Ok(Self::Gpt4o),
|
||||
"gpt-4" => Ok(Self::Gpt4),
|
||||
"gpt-4.1" => Ok(Self::Gpt4_1),
|
||||
"gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
|
||||
"o1" => Ok(Self::O1),
|
||||
"o3-mini" => Ok(Self::O3Mini),
|
||||
"o3" => Ok(Self::O3),
|
||||
"o4-mini" => Ok(Self::O4Mini),
|
||||
"claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
|
||||
"claude-3-7-sonnet" => Ok(Self::Claude3_7Sonnet),
|
||||
"claude-3.7-sonnet-thought" => Ok(Self::Claude3_7SonnetThinking),
|
||||
"gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
|
||||
"gemini-2.5-pro" => Ok(Self::Gemini25Pro),
|
||||
_ => Err(anyhow!("Invalid model id: {}", id)),
|
||||
}
|
||||
pub fn id(&self) -> &str {
|
||||
self.id.as_str()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Gpt3_5Turbo => "gpt-3.5-turbo",
|
||||
Self::Gpt4 => "gpt-4",
|
||||
Self::Gpt4_1 => "gpt-4.1",
|
||||
Self::Gpt4o => "gpt-4o",
|
||||
Self::O3Mini => "o3-mini",
|
||||
Self::O1 => "o1",
|
||||
Self::O3 => "o3",
|
||||
Self::O4Mini => "o4-mini",
|
||||
Self::Claude3_5Sonnet => "claude-3-5-sonnet",
|
||||
Self::Claude3_7Sonnet => "claude-3-7-sonnet",
|
||||
Self::Claude3_7SonnetThinking => "claude-3.7-sonnet-thought",
|
||||
Self::Gemini20Flash => "gemini-2.0-flash-001",
|
||||
Self::Gemini25Pro => "gemini-2.5-pro",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Gpt3_5Turbo => "GPT-3.5",
|
||||
Self::Gpt4 => "GPT-4",
|
||||
Self::Gpt4_1 => "GPT-4.1",
|
||||
Self::Gpt4o => "GPT-4o",
|
||||
Self::O3Mini => "o3-mini",
|
||||
Self::O1 => "o1",
|
||||
Self::O3 => "o3",
|
||||
Self::O4Mini => "o4-mini",
|
||||
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
|
||||
Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
|
||||
Self::Claude3_7SonnetThinking => "Claude 3.7 Sonnet Thinking",
|
||||
Self::Gemini20Flash => "Gemini 2.0 Flash",
|
||||
Self::Gemini25Pro => "Gemini 2.5 Pro",
|
||||
}
|
||||
pub fn display_name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Self::Gpt4o => 64_000,
|
||||
Self::Gpt4 => 32_768,
|
||||
Self::Gpt4_1 => 128_000,
|
||||
Self::Gpt3_5Turbo => 12_288,
|
||||
Self::O3Mini => 64_000,
|
||||
Self::O1 => 20_000,
|
||||
Self::O3 => 128_000,
|
||||
Self::O4Mini => 128_000,
|
||||
Self::Claude3_5Sonnet => 200_000,
|
||||
Self::Claude3_7Sonnet => 90_000,
|
||||
Self::Claude3_7SonnetThinking => 90_000,
|
||||
Self::Gemini20Flash => 128_000,
|
||||
Self::Gemini25Pro => 128_000,
|
||||
}
|
||||
self.capabilities.limits.max_prompt_tokens
|
||||
}
|
||||
|
||||
pub fn supports_tools(&self) -> bool {
|
||||
self.capabilities.supports.tool_calls
|
||||
}
|
||||
|
||||
pub fn vendor(&self) -> ModelVendor {
|
||||
self.vendor
|
||||
}
|
||||
|
||||
pub fn supports_vision(&self) -> bool {
|
||||
self.capabilities.supports.vision
|
||||
}
|
||||
|
||||
pub fn supports_parallel_tool_calls(&self) -> bool {
|
||||
self.capabilities.supports.parallel_tool_calls
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,7 +165,7 @@ pub struct Request {
|
||||
pub n: usize,
|
||||
pub stream: bool,
|
||||
pub temperature: f32,
|
||||
pub model: Model,
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<Tool>,
|
||||
@@ -189,26 +194,55 @@ pub enum ToolChoice {
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum ChatMessage {
|
||||
Assistant {
|
||||
content: Option<String>,
|
||||
content: ChatMessageContent,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
content: ChatMessageContent,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
},
|
||||
Tool {
|
||||
content: String,
|
||||
content: ChatMessageContent,
|
||||
tool_call_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatMessageContent {
|
||||
Plain(String),
|
||||
Multipart(Vec<ChatMessagePart>),
|
||||
}
|
||||
|
||||
impl ChatMessageContent {
|
||||
pub fn empty() -> Self {
|
||||
ChatMessageContent::Multipart(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<ChatMessagePart>> for ChatMessageContent {
|
||||
fn from(mut parts: Vec<ChatMessagePart>) -> Self {
|
||||
if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
|
||||
ChatMessageContent::Plain(std::mem::take(text))
|
||||
} else {
|
||||
ChatMessageContent::Multipart(parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for ChatMessageContent {
|
||||
fn from(text: String) -> Self {
|
||||
ChatMessageContent::Plain(text)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
@@ -232,7 +266,6 @@ pub struct FunctionContent {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub struct ResponseEvent {
|
||||
pub choices: Vec<ResponseChoice>,
|
||||
pub created: u64,
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
@@ -306,6 +339,7 @@ impl Global for GlobalCopilotChat {}
|
||||
pub struct CopilotChat {
|
||||
oauth_token: Option<String>,
|
||||
api_token: Option<ApiToken>,
|
||||
models: Option<Vec<Model>>,
|
||||
client: Arc<dyn HttpClient>,
|
||||
}
|
||||
|
||||
@@ -342,31 +376,56 @@ impl CopilotChat {
|
||||
let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
|
||||
let dir_path = copilot_chat_config_dir();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let mut parent_watch_rx = watch_config_dir(
|
||||
cx.background_executor(),
|
||||
fs.clone(),
|
||||
dir_path.clone(),
|
||||
config_paths,
|
||||
);
|
||||
while let Some(contents) = parent_watch_rx.next().await {
|
||||
let oauth_token = extract_oauth_token(contents);
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.oauth_token = oauth_token;
|
||||
cx.notify();
|
||||
});
|
||||
cx.spawn({
|
||||
let client = client.clone();
|
||||
async move |cx| {
|
||||
let mut parent_watch_rx = watch_config_dir(
|
||||
cx.background_executor(),
|
||||
fs.clone(),
|
||||
dir_path.clone(),
|
||||
config_paths,
|
||||
);
|
||||
while let Some(contents) = parent_watch_rx.next().await {
|
||||
let oauth_token = extract_oauth_token(contents);
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.oauth_token = oauth_token.clone();
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})?;
|
||||
|
||||
if let Some(ref oauth_token) = oauth_token {
|
||||
let api_token = request_api_token(oauth_token, client.clone()).await?;
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_token = Some(api_token.clone());
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})?;
|
||||
let models = get_models(api_token.api_key, client.clone()).await?;
|
||||
cx.update(|cx| {
|
||||
if let Some(this) = Self::global(cx).as_ref() {
|
||||
this.update(cx, |this, cx| {
|
||||
this.models = Some(models);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
})?;
|
||||
}
|
||||
})?;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
Self {
|
||||
oauth_token: None,
|
||||
api_token: None,
|
||||
models: None,
|
||||
client,
|
||||
}
|
||||
}
|
||||
@@ -375,6 +434,10 @@ impl CopilotChat {
|
||||
self.oauth_token.is_some()
|
||||
}
|
||||
|
||||
pub fn models(&self) -> Option<&[Model]> {
|
||||
self.models.as_deref()
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
request: Request,
|
||||
mut cx: AsyncApp,
|
||||
@@ -409,6 +472,61 @@ impl CopilotChat {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
|
||||
let all_models = request_models(api_token, client).await?;
|
||||
|
||||
let mut models: Vec<Model> = all_models
|
||||
.into_iter()
|
||||
.filter(|model| {
|
||||
// Ensure user has access to the model; Policy is present only for models that must be
|
||||
// enabled in the GitHub dashboard
|
||||
model.model_picker_enabled
|
||||
&& model
|
||||
.policy
|
||||
.as_ref()
|
||||
.is_none_or(|policy| policy.state == "enabled")
|
||||
})
|
||||
// The first model from the API response, in any given family, appear to be the non-tagged
|
||||
// models, which are likely the best choice (e.g. gpt-4o rather than gpt-4o-2024-11-20)
|
||||
.dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
|
||||
.collect();
|
||||
|
||||
if let Some(default_model_position) =
|
||||
models.iter().position(|model| model.id == DEFAULT_MODEL_ID)
|
||||
{
|
||||
let default_model = models.remove(default_model_position);
|
||||
models.insert(0, default_model);
|
||||
}
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
async fn request_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
.uri(COPILOT_CHAT_MODELS_URL)
|
||||
.header("Authorization", format!("Bearer {}", api_token))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Copilot-Integration-Id", "vscode-chat");
|
||||
|
||||
let request = request_builder.body(AsyncBody::empty())?;
|
||||
|
||||
let mut response = client.send(request).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
|
||||
let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
|
||||
|
||||
Ok(models)
|
||||
} else {
|
||||
Err(anyhow!("Failed to request models: {}", response.status()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
@@ -472,7 +590,8 @@ async fn stream_completion(
|
||||
)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Copilot-Integration-Id", "vscode-chat");
|
||||
.header("Copilot-Integration-Id", "vscode-chat")
|
||||
.header("Copilot-Vision-Request", "true");
|
||||
|
||||
let is_streaming = request.stream;
|
||||
|
||||
@@ -527,3 +646,82 @@ async fn stream_completion(
|
||||
Ok(futures::stream::once(async move { Ok(response) }).boxed())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_resilient_model_schema_deserialize() {
|
||||
let json = r#"{
|
||||
"data": [
|
||||
{
|
||||
"capabilities": {
|
||||
"family": "gpt-4",
|
||||
"limits": {
|
||||
"max_context_window_tokens": 32768,
|
||||
"max_output_tokens": 4096,
|
||||
"max_prompt_tokens": 32768
|
||||
},
|
||||
"object": "model_capabilities",
|
||||
"supports": { "streaming": true, "tool_calls": true },
|
||||
"tokenizer": "cl100k_base",
|
||||
"type": "chat"
|
||||
},
|
||||
"id": "gpt-4",
|
||||
"model_picker_enabled": false,
|
||||
"name": "GPT 4",
|
||||
"object": "model",
|
||||
"preview": false,
|
||||
"vendor": "Azure OpenAI",
|
||||
"version": "gpt-4-0613"
|
||||
},
|
||||
{
|
||||
"some-unknown-field": 123
|
||||
},
|
||||
{
|
||||
"capabilities": {
|
||||
"family": "claude-3.7-sonnet",
|
||||
"limits": {
|
||||
"max_context_window_tokens": 200000,
|
||||
"max_output_tokens": 16384,
|
||||
"max_prompt_tokens": 90000,
|
||||
"vision": {
|
||||
"max_prompt_image_size": 3145728,
|
||||
"max_prompt_images": 1,
|
||||
"supported_media_types": ["image/jpeg", "image/png", "image/webp"]
|
||||
}
|
||||
},
|
||||
"object": "model_capabilities",
|
||||
"supports": {
|
||||
"parallel_tool_calls": true,
|
||||
"streaming": true,
|
||||
"tool_calls": true,
|
||||
"vision": true
|
||||
},
|
||||
"tokenizer": "o200k_base",
|
||||
"type": "chat"
|
||||
},
|
||||
"id": "claude-3.7-sonnet",
|
||||
"model_picker_enabled": true,
|
||||
"name": "Claude 3.7 Sonnet",
|
||||
"object": "model",
|
||||
"policy": {
|
||||
"state": "enabled",
|
||||
"terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
|
||||
},
|
||||
"preview": false,
|
||||
"vendor": "Anthropic",
|
||||
"version": "claude-3.7-sonnet"
|
||||
}
|
||||
],
|
||||
"object": "list"
|
||||
}"#;
|
||||
|
||||
let schema: ModelSchema = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(schema.data.len(), 2);
|
||||
assert_eq!(schema.data[0].id, "gpt-4");
|
||||
assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ use async_compression::futures::bufread::GzipDecoder;
|
||||
use async_tar::Archive;
|
||||
use async_trait::async_trait;
|
||||
use collections::HashMap;
|
||||
use dap_types::{StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest};
|
||||
pub use dap_types::{StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest};
|
||||
use futures::io::BufReader;
|
||||
use gpui::{AsyncApp, SharedString};
|
||||
pub use http_client::{HttpClient, github::latest_github_release};
|
||||
use language::LanguageToolchainStore;
|
||||
use language::{LanguageName, LanguageToolchainStore};
|
||||
use node_runtime::NodeRuntime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::WorktreeId;
|
||||
@@ -418,6 +418,11 @@ pub trait DebugAdapter: 'static + Send + Sync {
|
||||
user_installed_path: Option<PathBuf>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<DebugAdapterBinary>;
|
||||
|
||||
/// Returns the language name of an adapter if it only supports one language
|
||||
fn adapter_language_name(&self) -> Option<LanguageName> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
|
||||
@@ -7,21 +7,14 @@ use dap_types::{
|
||||
messages::{Message, Response},
|
||||
requests::Request,
|
||||
};
|
||||
use futures::{FutureExt as _, channel::oneshot, select};
|
||||
use gpui::{AppContext, AsyncApp, BackgroundExecutor};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{AppContext, AsyncApp};
|
||||
use smol::channel::{Receiver, Sender};
|
||||
use std::{
|
||||
hash::Hash,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);
|
||||
|
||||
#[cfg(not(any(test, feature = "test-support")))]
|
||||
const DAP_REQUEST_TIMEOUT: Duration = Duration::from_secs(12);
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[repr(transparent)]
|
||||
pub struct SessionId(pub u32);
|
||||
@@ -41,7 +34,6 @@ pub struct DebugAdapterClient {
|
||||
id: SessionId,
|
||||
sequence_count: AtomicU64,
|
||||
binary: DebugAdapterBinary,
|
||||
executor: BackgroundExecutor,
|
||||
transport_delegate: TransportDelegate,
|
||||
}
|
||||
|
||||
@@ -61,7 +53,6 @@ impl DebugAdapterClient {
|
||||
binary,
|
||||
transport_delegate,
|
||||
sequence_count: AtomicU64::new(1),
|
||||
executor: cx.background_executor().clone(),
|
||||
};
|
||||
log::info!("Successfully connected to debug adapter");
|
||||
|
||||
@@ -173,40 +164,33 @@ impl DebugAdapterClient {
|
||||
|
||||
self.send_message(Message::Request(request)).await?;
|
||||
|
||||
let mut timeout = self.executor.timer(DAP_REQUEST_TIMEOUT).fuse();
|
||||
let command = R::COMMAND.to_string();
|
||||
|
||||
select! {
|
||||
response = callback_rx.fuse() => {
|
||||
log::debug!(
|
||||
"Client {} received response for: `{}` sequence_id: {}",
|
||||
self.id.0,
|
||||
command,
|
||||
sequence_id
|
||||
);
|
||||
|
||||
let response = response??;
|
||||
match response.success {
|
||||
true => {
|
||||
if let Some(json) = response.body {
|
||||
Ok(serde_json::from_value(json)?)
|
||||
// Note: dap types configure themselves to return `None` when an empty object is received,
|
||||
// which then fails here...
|
||||
} else if let Ok(result) = serde_json::from_value(serde_json::Value::Object(Default::default())) {
|
||||
Ok(result)
|
||||
} else {
|
||||
Ok(serde_json::from_value(Default::default())?)
|
||||
}
|
||||
}
|
||||
false => Err(anyhow!("Request failed: {}", response.message.unwrap_or_default())),
|
||||
let response = callback_rx.await??;
|
||||
log::debug!(
|
||||
"Client {} received response for: `{}` sequence_id: {}",
|
||||
self.id.0,
|
||||
command,
|
||||
sequence_id
|
||||
);
|
||||
match response.success {
|
||||
true => {
|
||||
if let Some(json) = response.body {
|
||||
Ok(serde_json::from_value(json)?)
|
||||
// Note: dap types configure themselves to return `None` when an empty object is received,
|
||||
// which then fails here...
|
||||
} else if let Ok(result) =
|
||||
serde_json::from_value(serde_json::Value::Object(Default::default()))
|
||||
{
|
||||
Ok(result)
|
||||
} else {
|
||||
Ok(serde_json::from_value(Default::default())?)
|
||||
}
|
||||
}
|
||||
|
||||
_ = timeout => {
|
||||
self.transport_delegate.cancel_pending_request(&sequence_id).await;
|
||||
log::error!("Cancelled DAP request for {command:?} id {sequence_id} which took over {DAP_REQUEST_TIMEOUT:?}");
|
||||
anyhow::bail!("DAP request timeout");
|
||||
}
|
||||
false => Err(anyhow!(
|
||||
"Request failed: {}",
|
||||
response.message.unwrap_or_default()
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ pub struct InlineValueLocation {
|
||||
/// during debugging sessions. Implementors must also handle variable scoping
|
||||
/// themselves by traversing the syntax tree upwards to determine whether a
|
||||
/// variable is local or global.
|
||||
pub trait InlineValueProvider {
|
||||
pub trait InlineValueProvider: 'static + Send + Sync {
|
||||
/// Provides a list of inline value locations based on the given node and source code.
|
||||
///
|
||||
/// # Parameters
|
||||
|
||||
@@ -2,6 +2,7 @@ use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use collections::FxHashMap;
|
||||
use gpui::{App, Global, SharedString};
|
||||
use language::LanguageName;
|
||||
use parking_lot::RwLock;
|
||||
use task::{DebugRequest, DebugScenario, SpawnInTerminal, TaskTemplate};
|
||||
|
||||
@@ -59,6 +60,11 @@ impl DapRegistry {
|
||||
);
|
||||
}
|
||||
|
||||
pub fn adapter_language(&self, adapter_name: &str) -> Option<LanguageName> {
|
||||
self.adapter(adapter_name)
|
||||
.and_then(|adapter| adapter.adapter_language_name())
|
||||
}
|
||||
|
||||
pub fn add_locator(&self, locator: Arc<dyn DapLocator>) {
|
||||
let _previous_value = self.0.write().locators.insert(locator.name(), locator);
|
||||
debug_assert!(
|
||||
|
||||
@@ -224,11 +224,6 @@ impl TransportDelegate {
|
||||
pending_requests.insert(sequence_id, request);
|
||||
}
|
||||
|
||||
pub(crate) async fn cancel_pending_request(&self, sequence_id: &u64) {
|
||||
let mut pending_requests = self.pending_requests.lock().await;
|
||||
pending_requests.remove(sequence_id);
|
||||
}
|
||||
|
||||
pub(crate) async fn send_message(&self, message: Message) -> Result<()> {
|
||||
if let Some(server_tx) = self.server_tx.lock().await.as_ref() {
|
||||
server_tx
|
||||
|
||||
@@ -42,7 +42,9 @@ impl CodeLldbDebugAdapter {
|
||||
if !launch.args.is_empty() {
|
||||
map.insert("args".into(), launch.args.clone().into());
|
||||
}
|
||||
|
||||
if !launch.env.is_empty() {
|
||||
map.insert("env".into(), launch.env_json());
|
||||
}
|
||||
if let Some(stop_on_entry) = config.stop_on_entry {
|
||||
map.insert("stopOnEntry".into(), stop_on_entry.into());
|
||||
}
|
||||
|
||||
@@ -35,6 +35,10 @@ impl GdbDebugAdapter {
|
||||
map.insert("args".into(), launch.args.clone().into());
|
||||
}
|
||||
|
||||
if !launch.env.is_empty() {
|
||||
map.insert("env".into(), launch.env_json());
|
||||
}
|
||||
|
||||
if let Some(stop_on_entry) = config.stop_on_entry {
|
||||
map.insert(
|
||||
"stopAtBeginningOfMainSubprogram".into(),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use dap::{StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
|
||||
use gpui::AsyncApp;
|
||||
use gpui::{AsyncApp, SharedString};
|
||||
use language::LanguageName;
|
||||
use std::{collections::HashMap, ffi::OsStr, path::PathBuf};
|
||||
|
||||
use crate::*;
|
||||
@@ -19,7 +20,8 @@ impl GoDebugAdapter {
|
||||
dap::DebugRequest::Launch(launch_config) => json!({
|
||||
"program": launch_config.program,
|
||||
"cwd": launch_config.cwd,
|
||||
"args": launch_config.args
|
||||
"args": launch_config.args,
|
||||
"env": launch_config.env_json()
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -42,6 +44,10 @@ impl DebugAdapter for GoDebugAdapter {
|
||||
DebugAdapterName(Self::ADAPTER_NAME.into())
|
||||
}
|
||||
|
||||
fn adapter_language_name(&self) -> Option<LanguageName> {
|
||||
Some(SharedString::new_static("Go").into())
|
||||
}
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user