Compare commits

..

1 Commits

Author SHA1 Message Date
Richard Feldman
9b31a6f4d2 wip 2025-05-19 15:55:30 -04:00
419 changed files with 9878 additions and 16016 deletions

View File

@@ -29,8 +29,8 @@ body:
id: environment
attributes:
label: Zed Version and System Specs
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
placeholder: |
Output of "zed: copy system specs into clipboard"
Output of "zed: Copy System Specs Into Clipboard"
validations:
required: true

View File

@@ -29,8 +29,8 @@ body:
id: environment
attributes:
label: Zed Version and System Specs
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
placeholder: |
Output of "zed: copy system specs into clipboard"
Output of "zed: Copy System Specs Into Clipboard"
validations:
required: true

View File

@@ -28,8 +28,8 @@ body:
id: environment
attributes:
label: Zed Version and System Specs
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
placeholder: |
Output of "zed: copy system specs into clipboard"
Output of "zed: Copy System Specs Into Clipboard"
validations:
required: true

View File

@@ -1,35 +0,0 @@
name: Bug Report (Debugger)
description: Zed Debugger-Related Bugs
type: "Bug"
labels: ["debugger"]
title: "Debugger: <a short description of the Debugger bug>"
body:
- type: textarea
attributes:
label: Summary
description: Describe the bug with a one line summary, and provide detailed reproduction steps
value: |
<!-- Please insert a one line summary of the issue below -->
SUMMARY_SENTENCE_HERE
### Description
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
Steps to trigger the problem:
1.
2.
3.
Actual Behavior:
Expected Behavior:
validations:
required: true
- type: textarea
id: environment
attributes:
label: Zed Version and System Specs
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
placeholder: |
Output of "zed: copy system specs into clipboard"
validations:
required: true

View File

@@ -49,8 +49,8 @@ body:
attributes:
label: Zed Version and System Specs
description: |
Open Zed, from the command palette select "zed: copy system specs into clipboard"
Open Zed, from the command palette select "zed: Copy System Specs Into Clipboard"
placeholder: |
Output of "zed: copy system specs into clipboard"
Output of "zed: Copy System Specs Into Clipboard"
validations:
required: true

View File

@@ -26,9 +26,9 @@ body:
id: environment
attributes:
label: Zed Version and System Specs
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
description: 'Open Zed, and in the command palette select "zed: Copy System Specs Into Clipboard"'
placeholder: |
Output of "zed: copy system specs into clipboard"
Output of "zed: Copy System Specs Into Clipboard"
validations:
required: true
- type: textarea

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
**/cargo-target
**/target
**/venv
**/.direnv
*.wasm
*.xcodeproj
.DS_Store

View File

@@ -2,14 +2,16 @@
{
"label": "Debug Zed (CodeLLDB)",
"adapter": "CodeLLDB",
"program": "target/debug/zed",
"request": "launch"
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT"
},
{
"label": "Debug Zed (GDB)",
"adapter": "GDB",
"program": "target/debug/zed",
"program": "$ZED_WORKTREE_ROOT/target/debug/zed",
"request": "launch",
"cwd": "$ZED_WORKTREE_ROOT",
"initialize_args": {
"stopAtBeginningOfMainSubprogram": true
}

341
Cargo.lock generated
View File

@@ -81,12 +81,12 @@ dependencies = [
"http_client",
"indexed_docs",
"indoc",
"inventory",
"itertools 0.14.0",
"jsonschema",
"language",
"language_model",
"language_model_selector",
"linkme",
"log",
"lsp",
"markdown",
@@ -513,6 +513,7 @@ dependencies = [
"settings",
"smallvec",
"smol",
"strum 0.27.1",
"telemetry_events",
"text",
"theme",
@@ -674,8 +675,8 @@ dependencies = [
"language",
"language_model",
"language_models",
"linkme",
"log",
"markdown",
"open",
"paths",
"portable-pty",
@@ -1269,9 +1270,9 @@ dependencies = [
[[package]]
name = "aws-lc-rs"
version = "1.13.1"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7"
checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878"
dependencies = [
"aws-lc-sys",
"zeroize",
@@ -1279,9 +1280,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.29.0"
version = "0.28.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079"
checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1"
dependencies = [
"bindgen 0.69.5",
"cc",
@@ -2028,7 +2029,7 @@ dependencies = [
[[package]]
name = "blade-graphics"
version = "0.6.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=b16f5c7bd873c7126f48c82c39e7ae64602ae74f#b16f5c7bd873c7126f48c82c39e7ae64602ae74f"
dependencies = [
"ash",
"ash-window",
@@ -2047,7 +2048,6 @@ dependencies = [
"naga",
"objc2",
"objc2-app-kit",
"objc2-core-foundation",
"objc2-foundation",
"objc2-metal",
"objc2-quartz-core",
@@ -2061,7 +2061,7 @@ dependencies = [
[[package]]
name = "blade-macros"
version = "0.3.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=b16f5c7bd873c7126f48c82c39e7ae64602ae74f#b16f5c7bd873c7126f48c82c39e7ae64602ae74f"
dependencies = [
"proc-macro2",
"quote",
@@ -2071,7 +2071,7 @@ dependencies = [
[[package]]
name = "blade-util"
version = "0.2.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=b16f5c7bd873c7126f48c82c39e7ae64602ae74f#b16f5c7bd873c7126f48c82c39e7ae64602ae74f"
dependencies = [
"blade-graphics",
"bytemuck",
@@ -2118,9 +2118,9 @@ dependencies = [
[[package]]
name = "block2"
version = "0.6.1"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "340d2f0bdb2a43c1d3cd40513185b2bd7def0aa1052f956455114bc98f82dcf2"
checksum = "2c132eebf10f5cad5289222520a4a058514204aed6d791f1cf4fe8088b82d15f"
dependencies = [
"objc2",
]
@@ -2792,7 +2792,6 @@ dependencies = [
"anyhow",
"async-recursion 0.3.2",
"async-tungstenite",
"base64 0.22.1",
"chrono",
"clock",
"cocoa 0.26.0",
@@ -2804,7 +2803,6 @@ dependencies = [
"gpui_tokio",
"http_client",
"http_client_tls",
"httparse",
"log",
"parking_lot",
"paths",
@@ -2825,7 +2823,6 @@ dependencies = [
"time",
"tiny_http",
"tokio",
"tokio-native-tls",
"tokio-socks",
"url",
"util",
@@ -3070,7 +3067,6 @@ dependencies = [
"gpui",
"http_client",
"language",
"log",
"menu",
"notifications",
"picker",
@@ -3179,13 +3175,39 @@ version = "0.1.0"
dependencies = [
"collections",
"gpui",
"inventory",
"linkme",
"parking_lot",
"strum 0.27.1",
"theme",
"workspace-hack",
]
[[package]]
name = "component_preview"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_tool",
"client",
"collections",
"component",
"db",
"futures 0.3.31",
"gpui",
"languages",
"log",
"notifications",
"project",
"prompt_store",
"serde",
"ui",
"ui_input",
"util",
"workspace",
"workspace-hack",
]
[[package]]
name = "concurrent-queue"
version = "2.5.0"
@@ -3311,7 +3333,6 @@ dependencies = [
"http_client",
"indoc",
"inline_completion",
"itertools 0.14.0",
"language",
"log",
"lsp",
@@ -3321,9 +3342,11 @@ dependencies = [
"paths",
"project",
"rpc",
"schemars",
"serde",
"serde_json",
"settings",
"strum 0.27.1",
"task",
"theme",
"ui",
@@ -3509,9 +3532,9 @@ dependencies = [
[[package]]
name = "cosmic-text"
version = "0.14.0"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e1ecbb5db9a4c2ee642df67bcfa8f044dd867dbbaa21bfab139cbc204ffbf67"
checksum = "e418dd4f5128c3e93eab12246391c54a20c496811131f85754dc8152ee207892"
dependencies = [
"bitflags 2.9.0",
"fontdb 0.16.2",
@@ -4136,18 +4159,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "debug_adapter_extension"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"dap",
"extension",
"gpui",
"workspace-hack",
]
[[package]]
name = "debugger_tools"
version = "0.1.0"
@@ -4181,7 +4192,6 @@ dependencies = [
"editor",
"env_logger 0.11.8",
"feature_flags",
"file_icons",
"futures 0.3.31",
"fuzzy",
"gpui",
@@ -4189,7 +4199,6 @@ dependencies = [
"log",
"menu",
"parking_lot",
"paths",
"picker",
"pretty_assertions",
"project",
@@ -4197,7 +4206,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
"shlex",
"sysinfo",
"task",
"tasks_ui",
@@ -4329,6 +4337,7 @@ dependencies = [
"gpui",
"indoc",
"language",
"linkme",
"log",
"lsp",
"markdown",
@@ -4447,16 +4456,6 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b"
[[package]]
name = "dispatch2"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec"
dependencies = [
"bitflags 2.9.0",
"objc2",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
@@ -5052,7 +5051,6 @@ dependencies = [
"async-tar",
"async-trait",
"collections",
"dap",
"fs",
"futures 0.3.31",
"gpui",
@@ -5065,10 +5063,8 @@ dependencies = [
"semantic_version",
"serde",
"serde_json",
"task",
"toml 0.8.20",
"util",
"wasi-preview1-component-adapter-provider",
"wasm-encoder 0.221.3",
"wasmparser 0.221.3",
"wit-component 0.221.3",
@@ -5110,7 +5106,6 @@ dependencies = [
"client",
"collections",
"ctor",
"dap",
"env_logger 0.11.8",
"extension",
"fs",
@@ -5163,7 +5158,6 @@ dependencies = [
"fuzzy",
"gpui",
"language",
"log",
"num-format",
"picker",
"project",
@@ -6028,6 +6022,7 @@ dependencies = [
"language",
"language_model",
"linkify",
"linkme",
"log",
"markdown",
"menu",
@@ -6351,7 +6346,6 @@ checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9"
dependencies = [
"cfg-if",
"crunchy",
"num-traits",
]
[[package]]
@@ -7243,6 +7237,7 @@ dependencies = [
"lsp",
"paths",
"project",
"proto",
"regex",
"serde_json",
"settings",
@@ -7250,6 +7245,7 @@ dependencies = [
"telemetry",
"theme",
"ui",
"util",
"workspace",
"workspace-hack",
"zed_actions",
@@ -7804,12 +7800,9 @@ version = "0.1.0"
dependencies = [
"collections",
"feature_flags",
"futures 0.3.31",
"fuzzy",
"gpui",
"language_model",
"log",
"ordered-float 2.10.1",
"picker",
"proto",
"ui",
@@ -8164,6 +8157,26 @@ dependencies = [
"memchr",
]
[[package]]
name = "linkme"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22d227772b5999ddc0690e733f734f95ca05387e329c4084fe65678c51198ffe"
dependencies = [
"linkme-impl",
]
[[package]]
name = "linkme-impl"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71a98813fa0073a317ed6a8055dcd4722a49d9b862af828ee68449adb799b6be"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
@@ -8503,7 +8516,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"assets",
"base64 0.22.1",
"env_logger 0.11.8",
"gpui",
"language",
@@ -8892,27 +8904,23 @@ checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "naga"
version = "25.0.1"
version = "23.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b977c445f26e49757f9aca3631c3b8b836942cb278d69a92e7b80d3b24da632"
checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f"
dependencies = [
"arrayvec",
"bit-set 0.8.0",
"bitflags 2.9.0",
"cfg_aliases 0.2.1",
"codespan-reporting 0.12.0",
"half",
"hashbrown 0.15.2",
"cfg_aliases 0.1.1",
"codespan-reporting 0.11.1",
"hexf-parse",
"indexmap",
"log",
"num-traits",
"once_cell",
"rustc-hash 1.1.0",
"spirv",
"strum 0.26.3",
"thiserror 2.0.12",
"unicode-ident",
"termcolor",
"thiserror 1.0.69",
"unicode-xid",
]
[[package]]
@@ -9086,6 +9094,7 @@ dependencies = [
"component",
"db",
"gpui",
"linkme",
"rpc",
"settings",
"sum_tree",
@@ -9366,36 +9375,95 @@ dependencies = [
]
[[package]]
name = "objc2"
version = "0.6.1"
name = "objc-sys"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551"
checksum = "cdb91bdd390c7ce1a8607f35f3ca7151b65afc0ff5ff3b34fa350f7d7c7e4310"
[[package]]
name = "objc2"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46a785d4eeff09c14c487497c162e92766fbb3e4059a71840cecc03d9a50b804"
dependencies = [
"objc-sys",
"objc2-encode",
]
[[package]]
name = "objc2-app-kit"
version = "0.3.1"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6f29f568bec459b0ddff777cec4fe3fd8666d82d5a40ebd0ff7e66134f89bcc"
checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff"
dependencies = [
"bitflags 2.9.0",
"block2",
"libc",
"objc2",
"objc2-core-foundation",
"objc2-core-data",
"objc2-core-image",
"objc2-foundation",
"objc2-quartz-core",
]
[[package]]
name = "objc2-core-foundation"
version = "0.3.1"
name = "objc2-cloud-kit"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166"
checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009"
dependencies = [
"bitflags 2.9.0",
"dispatch2",
"block2",
"objc2",
"objc2-core-location",
"objc2-foundation",
]
[[package]]
name = "objc2-contacts"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5ff520e9c33812fd374d8deecef01d4a840e7b41862d849513de77e44aa4889"
dependencies = [
"block2",
"objc2",
"objc2-foundation",
]
[[package]]
name = "objc2-core-data"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef"
dependencies = [
"bitflags 2.9.0",
"block2",
"objc2",
"objc2-foundation",
]
[[package]]
name = "objc2-core-image"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80"
dependencies = [
"block2",
"objc2",
"objc2-foundation",
"objc2-metal",
]
[[package]]
name = "objc2-core-location"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "000cfee34e683244f284252ee206a27953279d370e309649dc3ee317b37e5781"
dependencies = [
"block2",
"objc2",
"objc2-contacts",
"objc2-foundation",
]
[[package]]
@@ -9406,53 +9474,106 @@ checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33"
[[package]]
name = "objc2-foundation"
version = "0.3.1"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c"
checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8"
dependencies = [
"bitflags 2.9.0",
"block2",
"libc",
"objc2",
"objc2-core-foundation",
]
[[package]]
name = "objc2-link-presentation"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1a1ae721c5e35be65f01a03b6d2ac13a54cb4fa70d8a5da293d7b0020261398"
dependencies = [
"block2",
"objc2",
"objc2-app-kit",
"objc2-foundation",
]
[[package]]
name = "objc2-metal"
version = "0.3.1"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f246c183239540aab1782457b35ab2040d4259175bd1d0c58e46ada7b47a874"
checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6"
dependencies = [
"bitflags 2.9.0",
"block2",
"dispatch2",
"objc2",
"objc2-core-foundation",
"objc2-foundation",
]
[[package]]
name = "objc2-quartz-core"
version = "0.3.1"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ffb6a0cd5f182dc964334388560b12a57f7b74b3e2dec5e2722aa2dfb2ccd5"
checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a"
dependencies = [
"bitflags 2.9.0",
"block2",
"objc2",
"objc2-core-foundation",
"objc2-foundation",
"objc2-metal",
]
[[package]]
name = "objc2-ui-kit"
version = "0.3.1"
name = "objc2-symbols"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25b1312ad7bc8a0e92adae17aa10f90aae1fb618832f9b993b022b591027daed"
checksum = "0a684efe3dec1b305badae1a28f6555f6ddd3bb2c2267896782858d5a78404dc"
dependencies = [
"objc2",
"objc2-foundation",
]
[[package]]
name = "objc2-ui-kit"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f"
dependencies = [
"bitflags 2.9.0",
"block2",
"objc2",
"objc2-core-foundation",
"objc2-cloud-kit",
"objc2-core-data",
"objc2-core-image",
"objc2-core-location",
"objc2-foundation",
"objc2-link-presentation",
"objc2-quartz-core",
"objc2-symbols",
"objc2-uniform-type-identifiers",
"objc2-user-notifications",
]
[[package]]
name = "objc2-uniform-type-identifiers"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44fa5f9748dbfe1ca6c0b79ad20725a11eca7c2218bceb4b005cb1be26273bfe"
dependencies = [
"block2",
"objc2",
"objc2-foundation",
]
[[package]]
name = "objc2-user-notifications"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3"
dependencies = [
"bitflags 2.9.0",
"block2",
"objc2",
"objc2-core-location",
"objc2-foundation",
]
[[package]]
@@ -10034,7 +10155,7 @@ name = "perplexity"
version = "0.1.0"
dependencies = [
"serde",
"zed_extension_api 0.6.0",
"zed_extension_api 0.5.0",
]
[[package]]
@@ -11831,8 +11952,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"askpass",
"assistant_tool",
"assistant_tools",
"async-watch",
"backtrace",
"cargo_toml",
@@ -11855,7 +11974,6 @@ dependencies = [
"http_client",
"language",
"language_extension",
"language_model",
"languages",
"libc",
"log",
@@ -15666,6 +15784,7 @@ dependencies = [
"gpui",
"icons",
"itertools 0.14.0",
"linkme",
"menu",
"serde",
"settings",
@@ -15686,6 +15805,7 @@ dependencies = [
"component",
"editor",
"gpui",
"linkme",
"settings",
"theme",
"ui",
@@ -15697,6 +15817,7 @@ name = "ui_macros"
version = "0.1.0"
dependencies = [
"convert_case 0.8.0",
"linkme",
"proc-macro2",
"quote",
"syn 1.0.109",
@@ -16215,12 +16336,6 @@ dependencies = [
"wit-bindgen-rt 0.39.0",
]
[[package]]
name = "wasi-preview1-component-adapter-provider"
version = "29.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcd9f21bbde82ba59e415a8725e6ad0d0d7e9e460b1a3ccbca5bdee952c1a324"
[[package]]
name = "wasite"
version = "0.1.0"
@@ -16923,6 +17038,7 @@ dependencies = [
"gpui",
"install_cli",
"language",
"linkme",
"picker",
"project",
"schemars",
@@ -18016,6 +18132,7 @@ dependencies = [
"aho-corasick",
"anstream",
"arrayvec",
"async-compression",
"async-std",
"async-tungstenite",
"aws-config",
@@ -18031,7 +18148,6 @@ dependencies = [
"base64ct",
"bigdecimal",
"bit-set 0.8.0",
"bit-vec 0.8.0",
"bitflags 2.9.0",
"bstr",
"bytemuck",
@@ -18043,7 +18159,6 @@ dependencies = [
"clang-sys",
"clap",
"clap_builder",
"codespan-reporting 0.12.0",
"concurrent-queue",
"core-foundation 0.9.4",
"core-foundation-sys",
@@ -18072,7 +18187,6 @@ dependencies = [
"getrandom 0.2.15",
"getrandom 0.3.2",
"gimli",
"half",
"handlebars 4.5.0",
"hashbrown 0.14.5",
"hashbrown 0.15.2",
@@ -18106,9 +18220,6 @@ dependencies = [
"num-iter",
"num-rational",
"num-traits",
"objc2",
"objc2-foundation",
"objc2-metal",
"object",
"once_cell",
"percent-encoding",
@@ -18127,7 +18238,6 @@ dependencies = [
"regex-syntax 0.8.5",
"ring",
"rust_decimal",
"rustc-hash 1.1.0",
"rustix 0.38.44",
"rustix 1.0.5",
"rustls 0.23.26",
@@ -18523,7 +18633,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.188.0"
version = "0.187.0"
dependencies = [
"activity_indicator",
"agent",
@@ -18533,7 +18643,6 @@ dependencies = [
"assets",
"assistant_context_editor",
"assistant_settings",
"assistant_tool",
"assistant_tools",
"async-watch",
"audio",
@@ -18550,7 +18659,7 @@ dependencies = [
"collab_ui",
"collections",
"command_palette",
"component",
"component_preview",
"copilot",
"dap",
"dap_adapters",
@@ -18576,7 +18685,6 @@ dependencies = [
"gpui_tokio",
"http_client",
"image_viewer",
"indoc",
"inline_completion_button",
"install_cli",
"journal",
@@ -18589,7 +18697,6 @@ dependencies = [
"languages",
"libc",
"log",
"markdown",
"markdown_preview",
"menu",
"migrator",
@@ -18641,7 +18748,6 @@ dependencies = [
"tree-sitter-md",
"tree-sitter-rust",
"ui",
"ui_input",
"ui_prompt",
"url",
"urlencoding",
@@ -18693,7 +18799,7 @@ dependencies = [
[[package]]
name = "zed_extension_api"
version = "0.6.0"
version = "0.5.0"
dependencies = [
"serde",
"serde_json",
@@ -18716,9 +18822,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.8.1"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16d993fc42f9ec43ab76fa46c6eb579a66e116bb08cd2bc9a67f3afcaa05d39d"
checksum = "a23b2fd00776b0c55072f389654910ceb501eb0083d7f78905ab0e5cc86949ec"
dependencies = [
"anyhow",
"serde",
@@ -18753,7 +18859,7 @@ dependencies = [
name = "zed_test_extension"
version = "0.1.0"
dependencies = [
"zed_extension_api 0.6.0",
"zed_extension_api 0.5.0",
]
[[package]]
@@ -18926,7 +19032,6 @@ dependencies = [
"paths",
"postage",
"project",
"proto",
"regex",
"release_channel",
"reqwest_client",

View File

@@ -31,13 +31,13 @@ members = [
"crates/command_palette",
"crates/command_palette_hooks",
"crates/component",
"crates/component_preview",
"crates/context_server",
"crates/copilot",
"crates/credentials_provider",
"crates/dap",
"crates/dap_adapters",
"crates/db",
"crates/debug_adapter_extension",
"crates/debugger_tools",
"crates/debugger_ui",
"crates/deepseek",
@@ -238,13 +238,13 @@ collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
component = { path = "crates/component" }
component_preview = { path = "crates/component_preview" }
context_server = { path = "crates/context_server" }
copilot = { path = "crates/copilot" }
credentials_provider = { path = "crates/credentials_provider" }
dap = { path = "crates/dap" }
dap_adapters = { path = "crates/dap_adapters" }
db = { path = "crates/db" }
debug_adapter_extension = { path = "crates/debug_adapter_extension" }
debugger_tools = { path = "crates/debugger_tools" }
debugger_ui = { path = "crates/debugger_ui" }
deepseek = { path = "crates/deepseek" }
@@ -410,9 +410,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
base64 = "0.22"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-graphics = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
blade-util = { git = "https://github.com/kvark/blade", rev = "b16f5c7bd873c7126f48c82c39e7ae64602ae74f" }
blake3 = "1.5.3"
bytes = "1.0"
cargo_metadata = "0.19"
@@ -465,12 +465,13 @@ jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,r
libc = "0.2"
libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
linkify = "0.10.0"
linkme = "0.3.31"
log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] }
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" }
markup5ever_rcdom = "0.3.0"
metal = "0.29"
mlua = { version = "0.10", features = ["lua54", "vendored", "async", "send"] }
naga = { version = "25.0", features = ["wgsl-in"] }
naga = { version = "23.1.0", features = ["wgsl-in"] }
nanoid = "0.4"
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
nix = "0.29"
@@ -594,7 +595,6 @@ url = "2.2"
urlencoding = "2.1.2"
uuid = { version = "1.1.2", features = ["v4", "v5", "v7", "serde"] }
walkdir = "2.3"
wasi-preview1-component-adapter-provider = "29"
wasm-encoder = "0.221"
wasmparser = "0.221"
wasmtime = { version = "29", default-features = false, features = [
@@ -608,7 +608,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.8.1"
zed_llm_client = "0.8.0"
zstd = "0.11"
[workspace.dependencies.async-stripe]
@@ -788,9 +788,6 @@ let_underscore_future = "allow"
# running afoul of the borrow checker.
too_many_arguments = "allow"
# We often have large enum variants yet we rarely actually bother with splitting them up.
large_enum_variant = "allow"
[workspace.metadata.cargo-machete]
ignored = [
"bindgen",
@@ -798,6 +795,7 @@ ignored = [
"prost_build",
"serde",
"component",
"linkme",
"documented",
"workspace-hack",
]

View File

@@ -1,6 +1,6 @@
# syntax = docker/dockerfile:1.2
FROM rust:1.87-bookworm as builder
FROM rust:1.86-bookworm as builder
WORKDIR app
COPY . .

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-loader-circle-icon lucide-loader-circle"><path d="M21 12a9 9 0 1 1-6.219-8.56"/></svg>

Before

Width:  |  Height:  |  Size: 289 B

View File

@@ -217,6 +217,7 @@
"context": "ContextEditor > Editor",
"bindings": {
"ctrl-enter": "assistant::Assist",
"ctrl-shift-enter": "assistant::Edit",
"ctrl-s": "workspace::Save",
"save": "workspace::Save",
"ctrl->": "assistant::QuoteSelection",
@@ -244,7 +245,7 @@
"ctrl-shift-a": "agent::ToggleContextPicker",
"ctrl-shift-o": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"shift-escape": "agent::ExpandMessageEditor",
"ctrl-alt-e": "agent::RemoveAllContext",
"ctrl-shift-e": "project_panel::ToggleFocus"
}
@@ -766,7 +767,7 @@
"alt-ctrl-r": "project_panel::RevealInFileManager",
"ctrl-shift-enter": "project_panel::OpenWithSystem",
"shift-find": "project_panel::NewSearchInDirectory",
"ctrl-alt-shift-f": "project_panel::NewSearchInDirectory",
"ctrl-shift-f": "project_panel::NewSearchInDirectory",
"shift-down": "menu::SelectNext",
"shift-up": "menu::SelectPrevious",
"escape": "menu::Cancel"
@@ -934,7 +935,6 @@
"ctrl-e": ["terminal::SendKeystroke", "ctrl-e"],
"ctrl-o": ["terminal::SendKeystroke", "ctrl-o"],
"ctrl-w": ["terminal::SendKeystroke", "ctrl-w"],
"ctrl-backspace": ["terminal::SendKeystroke", "ctrl-w"],
"ctrl-shift-a": "editor::SelectAll",
"find": "buffer_search::Deploy",
"ctrl-shift-f": "buffer_search::Deploy",
@@ -952,10 +952,7 @@
"shift-down": "terminal::ScrollLineDown",
"shift-home": "terminal::ScrollToTop",
"shift-end": "terminal::ScrollToBottom",
"ctrl-shift-space": "terminal::ToggleViMode",
"ctrl-shift-r": "terminal::RerunTask",
"ctrl-alt-r": "terminal::RerunTask",
"alt-t": "terminal::RerunTask"
"ctrl-shift-space": "terminal::ToggleViMode"
}
},
{
@@ -978,12 +975,5 @@
"bindings": {
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
}
},
{
"context": "DebugConsole > Editor",
"use_key_equivalents": true,
"bindings": {
"enter": "menu::Confirm"
}
}
]

View File

@@ -263,6 +263,7 @@
"use_key_equivalents": true,
"bindings": {
"cmd-enter": "assistant::Assist",
"cmd-shift-enter": "assistant::Edit",
"cmd-s": "workspace::Save",
"cmd->": "assistant::QuoteSelection",
"cmd-<": "assistant::InsertIntoEditor",
@@ -290,7 +291,7 @@
"cmd-shift-a": "agent::ToggleContextPicker",
"cmd-shift-o": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"shift-escape": "agent::ExpandMessageEditor",
"cmd-alt-e": "agent::RemoveAllContext",
"cmd-shift-e": "project_panel::ToggleFocus"
}
@@ -825,7 +826,7 @@
"alt-cmd-r": "project_panel::RevealInFileManager",
"ctrl-shift-enter": "project_panel::OpenWithSystem",
"cmd-alt-backspace": ["project_panel::Delete", { "skip_prompt": false }],
"cmd-alt-shift-f": "project_panel::NewSearchInDirectory",
"cmd-shift-f": "project_panel::NewSearchInDirectory",
"shift-down": "menu::SelectNext",
"shift-up": "menu::SelectPrevious",
"escape": "menu::Cancel"
@@ -1021,7 +1022,6 @@
"escape": ["terminal::SendKeystroke", "escape"],
"enter": ["terminal::SendKeystroke", "enter"],
"ctrl-c": ["terminal::SendKeystroke", "ctrl-c"],
"ctrl-backspace": ["terminal::SendKeystroke", "ctrl-w"],
"shift-pageup": "terminal::ScrollPageUp",
"cmd-up": "terminal::ScrollPageUp",
"shift-pagedown": "terminal::ScrollPageDown",
@@ -1038,8 +1038,7 @@
"ctrl-alt-up": "pane::SplitUp",
"ctrl-alt-down": "pane::SplitDown",
"ctrl-alt-left": "pane::SplitLeft",
"ctrl-alt-right": "pane::SplitRight",
"cmd-alt-r": "terminal::RerunTask"
"ctrl-alt-right": "pane::SplitRight"
}
},
{
@@ -1084,12 +1083,5 @@
"bindings": {
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
}
},
{
"context": "DebugConsole > Editor",
"use_key_equivalents": true,
"bindings": {
"enter": "menu::Confirm"
}
}
]

View File

@@ -49,9 +49,10 @@ And here's the section to rewrite based on that prompt again for reference:
</rewrite_this>
{{#if diagnostic_errors}}
{{#each diagnostic_errors}}
Below are the diagnostic errors visible to the user. If the user requests problems to be fixed, use this information, but do not try to fix these errors if the user hasn't asked you to.
{{#each diagnostic_errors}}
<diagnostic_error>
<line_number>{{line_number}}</line_number>
<error_message>{{error_message}}</error_message>

View File

@@ -0,0 +1,206 @@
<task_description>
The user of a code editor wants to make a change to their codebase.
You must describe the change using the following XML structure:
- <patch> - A group of related code changes.
Child tags:
- <title> (required) - A high-level description of the changes. This should be as short
as possible, possibly using common abbreviations.
- <edit> (1 or more) - An edit to make at a particular range within a file.
Includes the following child tags:
- <path> (required) - The path to the file that will be changed.
- <description> (optional) - An arbitrarily-long comment that describes the purpose
of this edit.
- <old_text> (optional) - An excerpt from the file's current contents that uniquely
identifies a range within the file where the edit should occur. Required for all operations
except `create`.
- <new_text> (required) - The new text to insert into the file.
- <operation> (required) - The type of change that should occur at the given range
of the file. Must be one of the following values:
- `update`: Replaces the entire range with the new text.
- `insert_before`: Inserts the new text before the range.
- `insert_after`: Inserts new text after the range.
- `create`: Creates or overwrites a file with the given path and the new text.
- `delete`: Deletes the specified range from the file.
<guidelines>
- Never provide multiple edits whose ranges intersect each other. Instead, merge them into one edit.
- Prefer multiple edits to smaller, disjoint ranges, rather than one edit to a larger range.
- There's no need to escape angle brackets within XML tags.
- Always ensure imports are added if you're referencing symbols that are not in scope.
</guidelines>
Here are some concrete examples.
<example>
<message role="user">
```rs src/shapes.rs
pub mod rectangle;
pub mod circle;
```
```rs src/shapes/rectangle.rs
pub struct Rectangle {
width: f64,
height: f64,
}
impl Rectangle {
pub fn new(width: f64, height: f64) -> Self {
Rectangle { width, height }
}
}
```
```rs src/shapes/circle.rs
pub struct Circle {
radius: f64,
}
impl Circle {
pub fn new(radius: f64) -> Self {
Circle { radius }
}
}
```
Update all shapes to store their origin as an (x, y) tuple and implement Display.
</message>
<message role="assistant">
We'll need to update both the rectangle and circle modules.
<patch>
<title>Add origins and display impls to shapes</title>
<edit>
<path>src/shapes/rectangle.rs</path>
<description>Add the origin field to Rectangle struct</description>
<operation>insert_after</operation>
<old_text>
pub struct Rectangle {
</old_text>
<new_text>
origin: (f64, f64),
</new_text>
</edit>
<edit>
<path>src/shapes/rectangle.rs</path>
<operation>update</operation>
<old_text>
fn new(width: f64, height: f64) -> Self {
Rectangle { width, height }
}
</old_text>
<new_text>
fn new(origin: (f64, f64), width: f64, height: f64) -> Self {
Rectangle { origin, width, height }
}
</new_text>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<description>Add the origin field to Circle struct</description>
<operation>insert_after</operation>
<old_text>
pub struct Circle {
radius: f64,
</old_text>
<new_text>
origin: (f64, f64),
</new_text>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<operation>update</operation>
<old_text>
fn new(radius: f64) -> Self {
Circle { radius }
}
</old_text>
<new_text>
fn new(origin: (f64, f64), radius: f64) -> Self {
Circle { origin, radius }
}
</new_text>
</edit>
</step>
<edit>
<path>src/shapes/rectangle.rs</path>
<operation>insert_before</operation>
<old_text>
struct Rectangle {
</old_text>
<new_text>
use std::fmt;
</new_text>
</edit>
<edit>
<path>src/shapes/rectangle.rs</path>
<description>
Add a manual Display implementation for Rectangle.
Currently, this is the same as a derived Display implementation.
</description>
<operation>insert_after</operation>
<old_text>
Rectangle { width, height }
}
}
</old_text>
<new_text>
impl fmt::Display for Rectangle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.format_struct(f, "Rectangle")
.field("origin", &self.origin)
.field("width", &self.width)
.field("height", &self.height)
.finish()
}
}
</new_text>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<operation>insert_before</operation>
<old_text>
struct Circle {
</old_text>
<new_text>
use std::fmt;
</new_text>
</edit>
<edit>
<path>src/shapes/circle.rs</path>
<operation>insert_after</operation>
<old_text>
Circle { radius }
}
}
</old_text>
<new_text>
impl fmt::Display for Rectangle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.format_struct(f, "Rectangle")
.field("origin", &self.origin)
.field("width", &self.width)
.field("height", &self.height)
.finish()
}
}
</new_text>
</edit>
</patch>
</message>
</example>
</task_description>

View File

@@ -113,8 +113,8 @@
// Whether to show the informational hover box when moving the mouse
// over symbols in the editor.
"hover_popover_enabled": true,
// Time to wait in milliseconds before showing the informational hover box.
"hover_popover_delay": 300,
// Time to wait before showing the informational hover box
"hover_popover_delay": 350,
// Whether to confirm before quitting Zed.
"confirm_quit": false,
// Whether to restore last closed project when fresh Zed instance is opened.
@@ -218,23 +218,6 @@
// 1. Do nothing: `none`
// 2. Find references for the same symbol: `find_all_references` (default)
"go_to_definition_fallback": "find_all_references",
// Which level to use to filter out diagnostics displayed in the editor.
//
// Affects the editor rendering only, and does not interrupt
// the functionality of diagnostics fetching and project diagnostics editor.
// Which files containing diagnostic errors/warnings to mark in the tabs.
// Diagnostics are only shown when file icons are also active.
// This setting only works when can take the following three values:
//
// Which diagnostic indicators to show in the scrollbar, their level should be more or equal to the specified severity level.
// Possible values:
// - "off" — no diagnostics are allowed
// - "error"
// - "warning" (default)
// - "info"
// - "hint"
// - null — allow all diagnostics
"diagnostics_max_severity": "warning",
// Whether to show wrap guides (vertical rulers) in the editor.
// Setting this to true will show a guide at the 'preferred_line_length' value
// if 'soft_wrap' is set to 'preferred_line_length', and will show any
@@ -328,16 +311,10 @@
"title_bar": {
// Whether to show the branch icon beside branch switcher in the titlebar.
"show_branch_icon": false,
// Whether to show the branch name button in the titlebar.
"show_branch_name": true,
// Whether to show the project host and name in the titlebar.
"show_project_items": true,
// Whether to show onboarding banners in the titlebar.
"show_onboarding_banner": true,
// Whether to show user picture in the titlebar.
"show_user_picture": true,
// Whether to show the sign in button in the titlebar.
"show_sign_in": true
"show_user_picture": true
},
// Scrollbar related settings
"scrollbar": {
@@ -416,7 +393,11 @@
// 1. `null` to inherit the editor `current_line_highlight` setting (default)
// 2. "line" or "all" to highlight the current line in the minimap.
// 3. "gutter" or "none" to not highlight the current line in the minimap.
"current_line_highlight": null
"current_line_highlight": null,
// The width of the minimap in pixels.
"width": 100,
// The font size of the minimap in pixels.
"font_size": 2
},
// Enable middle-click paste on Linux.
"middle_click_paste": true,
@@ -476,8 +457,6 @@
"search_wrap": true,
// Search options to enable by default when opening new project and buffer searches.
"search": {
// Whether to show the project search button in the status bar.
"button": true,
"whole_word": false,
"case_sensitive": false,
"include_ignored": false,
@@ -758,8 +737,6 @@
"stream_edits": false,
// When enabled, agent edits will be displayed in single-file editors for review
"single_file_review": true,
// When enabled, show voting thumbs for feedback on agent edits.
"enable_feedback": true,
"default_profile": "write",
"profiles": {
"write": {
@@ -1012,8 +989,6 @@
"auto_update": true,
// Diagnostics configuration.
"diagnostics": {
// Whether to show the project diagnostics button in the status bar.
"button": true,
// Whether to show warnings or not by default.
"include_warnings": true,
// Settings for inline diagnostics
@@ -1031,7 +1006,7 @@
// longer than this value will still push diagnostics further to the right.
"min_column": 0,
// The minimum severity of the diagnostics to show inline.
// Inherits editor's diagnostics' max severity settings when `null`.
// Shows all diagnostics when not specified.
"max_severity": null
},
"cargo": {
@@ -1309,22 +1284,21 @@
"JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json"],
"Shell Script": [".env.*"]
},
// Settings for which version of Node.js and NPM to use when installing
// language servers and Copilot.
//
// Note: changing this setting currently requires restarting Zed.
"node": {
// By default, Zed will look for `node` and `npm` on your `$PATH`, and use the
// existing executables if their version is recent enough. Set this to `true`
// to prevent this, and force Zed to always download and install its own
// version of Node.
"ignore_system_version": false,
// You can also specify alternative paths to Node and NPM. If you specify
// `path`, but not `npm_path`, Zed will assume that `npm` is located at
// `${path}/../npm`.
"path": null,
"npm_path": null
},
// By default use a recent system version of node, or install our own.
// You can override this to use a version of node that is not in $PATH with:
// {
// "node": {
// "path": "/path/to/node"
// "npm_path": "/path/to/npm" (defaults to node_path/../npm)
// }
// }
// or to ensure Zed always downloads and installs an isolated version of node:
// {
// "node": {
// "ignore_system_version": true,
// }
// NOTE: changing this setting currently requires restarting Zed.
"node": {},
// The extensions that Zed should automatically install on startup.
//
// If you don't want any of these extensions, add this field to your settings

View File

@@ -47,12 +47,12 @@ heed.workspace = true
html_to_markdown.workspace = true
http_client.workspace = true
indexed_docs.workspace = true
inventory.workspace = true
itertools.workspace = true
jsonschema.workspace = true
language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
linkme.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true

View File

@@ -3,10 +3,9 @@ use crate::context::{AgentContextHandle, RULES_ICON};
use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::insert_message_creases;
use crate::thread::{
LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
ThreadEvent, ThreadFeedback, ThreadSummary,
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, TextThreadStore, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -33,9 +32,7 @@ use language_model::{
LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{
HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, PathWithRange,
};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use project::{ProjectEntryId, ProjectItem as _};
use rope::Point;
use settings::{Settings as _, SettingsStore, update_settings_file};
@@ -185,14 +182,12 @@ pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = TextSize::Small.rems(cx);
let mut text_style = window.text_style();
let line_height = buffer_font_size * 1.75;
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(ui_font_size.into()),
line_height: Some(line_height.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
@@ -332,7 +327,6 @@ fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
}
}
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 10;
fn render_markdown_code_block(
@@ -385,25 +379,18 @@ fn render_markdown_code_block(
)
} else {
let content = if let Some(parent) = path_range.path.parent() {
let file_name = file_name.to_string_lossy().to_string();
let path = parent.to_string_lossy().to_string();
let path_and_file = format!("{}/{}", path, file_name);
h_flex()
.id(("code-block-header-label", ix))
.ml_1()
.gap_1()
.child(Label::new(file_name).size(LabelSize::Small))
.child(Label::new(path).color(Color::Muted).size(LabelSize::Small))
.tooltip(move |window, cx| {
Tooltip::with_meta(
"Jump to File",
None,
path_and_file.clone(),
window,
cx,
)
})
.child(
Label::new(file_name.to_string_lossy().to_string())
.size(LabelSize::Small),
)
.child(
Label::new(parent.to_string_lossy().to_string())
.color(Color::Muted)
.size(LabelSize::Small),
)
.into_any_element()
} else {
Label::new(path_range.path.to_string_lossy().to_string())
@@ -413,7 +400,7 @@ fn render_markdown_code_block(
};
h_flex()
.id(("code-block-header-button", ix))
.id(("code-block-header-label", ix))
.w_full()
.max_w_full()
.px_1()
@@ -421,6 +408,7 @@ fn render_markdown_code_block(
.cursor_pointer()
.rounded_sm()
.hover(|item| item.bg(cx.theme().colors().element_hover.opacity(0.5)))
.tooltip(Tooltip::text("Jump to File"))
.child(
h_flex()
.gap_0p5()
@@ -440,8 +428,49 @@ fn render_markdown_code_block(
let path_range = path_range.clone();
move |_, window, cx| {
workspace
.update(cx, |workspace, cx| {
open_path(&path_range, window, workspace, cx)
.update(cx, {
|workspace, cx| {
let Some(project_path) = workspace
.project()
.read(cx)
.find_project_path(&path_range.path, cx)
else {
return;
};
let Some(target) = path_range.range.as_ref().map(|range| {
Point::new(
// Line number is 1-based
range.start.line.saturating_sub(1),
range.start.col.unwrap_or(0),
)
}) else {
return;
};
let open_task = workspace.open_path(
project_path,
None,
true,
window,
cx,
);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) =
item.downcast::<Editor>()
{
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(
target, window, cx,
);
})
.ok();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
})
.ok();
}
@@ -456,13 +485,12 @@ fn render_markdown_code_block(
.copied_code_block_ids
.contains(&(message_id, ix));
let can_expand = metadata.line_count >= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK;
let is_expanded = if can_expand {
active_thread.read(cx).is_codeblock_expanded(message_id, ix)
} else {
false
};
let is_expanded = active_thread
.read(cx)
.expanded_code_blocks
.get(&(message_id, ix))
.copied()
.unwrap_or(true);
let codeblock_header_bg = cx
.theme()
@@ -470,87 +498,10 @@ fn render_markdown_code_block(
.element_background
.blend(cx.theme().colors().editor_foreground.opacity(0.01));
let control_buttons = h_flex()
.visible_on_hover(CODEBLOCK_CONTAINER_GROUP)
.absolute()
.top_0()
.right_0()
.h_full()
.bg(codeblock_header_bg)
.rounded_tr_md()
.px_1()
.gap_1()
.child(
IconButton::new(
("copy-markdown-code", ix),
if codeblock_was_copied {
IconName::Check
} else {
IconName::Copy
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Code"))
.on_click({
let active_thread = active_thread.clone();
let parsed_markdown = parsed_markdown.clone();
let code_block_range = metadata.content_range.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.copied_code_block_ids.insert((message_id, ix));
let code = parsed_markdown.source()[code_block_range.clone()].to_string();
cx.write_to_clipboard(ClipboardItem::new_string(code));
cx.spawn(async move |this, cx| {
cx.background_executor().timer(Duration::from_secs(2)).await;
cx.update(|cx| {
this.update(cx, |this, cx| {
this.copied_code_block_ids.remove(&(message_id, ix));
cx.notify();
})
})
.ok();
})
.detach();
});
}
}),
)
.when(can_expand, |header| {
header.child(
IconButton::new(
("expand-collapse-code", ix),
if is_expanded {
IconName::ChevronUp
} else {
IconName::ChevronDown
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text(if is_expanded {
"Collapse Code"
} else {
"Expand Code"
}))
.on_click({
let active_thread = active_thread.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.toggle_codeblock_expanded(message_id, ix);
cx.notify();
});
}
}),
)
});
let codeblock_header = h_flex()
.relative()
.p_1()
.py_1()
.pl_1p5()
.pr_1()
.gap_1()
.justify_between()
.border_b_1()
@@ -558,10 +509,89 @@ fn render_markdown_code_block(
.bg(codeblock_header_bg)
.rounded_t_md()
.children(label)
.child(control_buttons);
.child(
h_flex()
.visible_on_hover("codeblock_container")
.gap_1()
.child(
IconButton::new(
("copy-markdown-code", ix),
if codeblock_was_copied {
IconName::Check
} else {
IconName::Copy
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Code"))
.on_click({
let active_thread = active_thread.clone();
let parsed_markdown = parsed_markdown.clone();
let code_block_range = metadata.content_range.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.copied_code_block_ids.insert((message_id, ix));
let code =
parsed_markdown.source()[code_block_range.clone()].to_string();
cx.write_to_clipboard(ClipboardItem::new_string(code));
cx.spawn(async move |this, cx| {
cx.background_executor().timer(Duration::from_secs(2)).await;
cx.update(|cx| {
this.update(cx, |this, cx| {
this.copied_code_block_ids.remove(&(message_id, ix));
cx.notify();
})
})
.ok();
})
.detach();
});
}
}),
)
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|header| {
header.child(
IconButton::new(
("expand-collapse-code", ix),
if is_expanded {
IconName::ChevronUp
} else {
IconName::ChevronDown
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text(if is_expanded {
"Collapse Code"
} else {
"Expand Code"
}))
.on_click({
let active_thread = active_thread.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
let is_expanded = this
.expanded_code_blocks
.entry((message_id, ix))
.or_insert(true);
*is_expanded = !*is_expanded;
cx.notify();
});
}
}),
)
},
),
);
v_flex()
.group(CODEBLOCK_CONTAINER_GROUP)
.group("codeblock_container")
.my_2()
.overflow_hidden()
.rounded_lg()
@@ -569,46 +599,16 @@ fn render_markdown_code_block(
.border_color(cx.theme().colors().border.opacity(0.6))
.bg(cx.theme().colors().editor_background)
.child(codeblock_header)
.when(can_expand && !is_expanded, |this| this.max_h_80())
}
fn open_path(
path_range: &PathWithRange,
window: &mut Window,
workspace: &mut Workspace,
cx: &mut Context<'_, Workspace>,
) {
let Some(project_path) = workspace
.project()
.read(cx)
.find_project_path(&path_range.path, cx)
else {
return; // TODO instead of just bailing out, open that path in a buffer.
};
let Some(target) = path_range.range.as_ref().map(|range| {
Point::new(
// Line number is 1-based
range.start.line.saturating_sub(1),
range.start.col.unwrap_or(0),
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|this| {
if is_expanded {
this.h_full()
} else {
this.max_h_80()
}
},
)
}) else {
return;
};
let open_task = workspace.open_path(project_path, None, true, window, cx);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) = item.downcast::<Editor>() {
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(target, window, cx);
})
.ok();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
fn render_code_language(
@@ -827,12 +827,12 @@ impl ActiveThread {
self.messages.is_empty()
}
pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary {
pub fn summary(&self, cx: &App) -> Option<SharedString> {
self.thread.read(cx).summary()
}
pub fn regenerate_summary(&self, cx: &mut App) {
self.thread.update(cx, |thread, cx| thread.summarize(cx))
pub fn summary_or_default(&self, cx: &App) -> SharedString {
self.thread.read(cx).summary_or_default()
}
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
@@ -1138,7 +1138,11 @@ impl ActiveThread {
return;
}
let title = self.thread.read(cx).summary().unwrap_or("Agent Panel");
let title = self
.thread
.read(cx)
.summary()
.unwrap_or("Agent Panel".into());
match AssistantSettings::get_global(cx).notify_when_agent_waiting {
NotifyWhenAgentWaiting::PrimaryScreen => {
@@ -1268,7 +1272,6 @@ impl ActiveThread {
&mut self,
message_id: MessageId,
message_segments: &[MessageSegment],
message_creases: &[MessageCrease],
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -1288,7 +1291,6 @@ impl ActiveThread {
);
editor.update(cx, |editor, cx| {
editor.set_text(message_text.clone(), window, cx);
insert_message_creases(editor, message_creases, &self.context_store, window, cx);
editor.focus_handle(cx).focus(window);
editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
});
@@ -1722,11 +1724,10 @@ impl ActiveThread {
.on_action(cx.listener(Self::confirm_editing_message))
.capture_action(cx.listener(Self::paste))
.min_h_6()
.w_full()
.flex_grow()
.w_full()
.gap_2()
.child(state.context_strip.clone())
.child(div().pt(px(-3.)).px_neg_0p5().child(EditorElement::new(
.child(EditorElement::new(
&state.editor,
EditorStyle {
background: colors.editor_background,
@@ -1735,7 +1736,8 @@ impl ActiveThread {
syntax: cx.theme().syntax().clone(),
..Default::default()
},
)))
))
.child(state.context_strip.clone())
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
@@ -1743,7 +1745,6 @@ impl ActiveThread {
let Some(message) = self.thread.read(cx).message(message_id) else {
return Empty.into_any();
};
let message_creases = message.creases.clone();
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
@@ -1863,8 +1864,7 @@ impl ActiveThread {
.child(open_as_markdown),
)
.into_any_element(),
None if AssistantSettings::get_global(cx).enable_feedback =>
feedback_container
None => feedback_container
.child(
div().visible_on_hover("feedback_container").child(
Label::new(
@@ -1907,9 +1907,6 @@ impl ActiveThread {
.child(open_as_markdown),
)
.into_any_element(),
None => feedback_container
.child(h_flex().child(open_as_markdown))
.into_any_element(),
};
let message_is_empty = message.should_display_content();
@@ -1923,6 +1920,16 @@ impl ActiveThread {
v_flex()
.w_full()
.gap_1()
.when(!message_is_empty, |parent| {
parent.child(div().min_h_6().child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
)))
})
.when(!added_context.is_empty(), |parent| {
parent.child(h_flex().flex_wrap().gap_1().children(
added_context.into_iter().map(|added_context| {
@@ -1941,16 +1948,6 @@ impl ActiveThread {
}),
))
})
.when(!message_is_empty, |parent| {
parent.child(div().pt_0p5().min_h_6().child(self.render_message_content(
message_id,
rendered_message,
has_tool_uses,
workspace.clone(),
window,
cx,
)))
})
.into_any_element()
}
});
@@ -1976,7 +1973,6 @@ impl ActiveThread {
h_flex()
.p_2p5()
.gap_1()
.items_end()
.children(message_content)
.when_some(editing_message_state, |this, state| {
let focus_handle = state.editor.focus_handle(cx).clone();
@@ -1990,7 +1986,6 @@ impl ActiveThread {
)
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Error)
.icon_size(IconSize::Small)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
@@ -2008,12 +2003,11 @@ impl ActiveThread {
.child(
IconButton::new(
"confirm-edit-message",
IconName::Return,
IconName::Check,
)
.disabled(state.editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
.icon_color(Color::Success)
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
@@ -2033,13 +2027,15 @@ impl ActiveThread {
)
}),
)
.when(editing_message_state.is_none(), |this| {
this.tooltip(Tooltip::text("Click To Edit"))
})
.on_click(cx.listener({
let message_segments = message.segments.clone();
move |this, _, window, cx| {
this.start_editing_message(
message_id,
&message_segments,
&message_creases,
window,
cx,
);
@@ -2073,16 +2069,6 @@ impl ActiveThread {
let panel_background = cx.theme().colors().panel_background;
let backdrop = div()
.id("backdrop")
.stop_mouse_events_except_scroll()
.absolute()
.inset_0()
.size_full()
.bg(panel_background)
.opacity(0.8)
.on_click(cx.listener(Self::handle_cancel_click));
v_flex()
.w_full()
.map(|parent| {
@@ -2252,7 +2238,15 @@ impl ActiveThread {
})
.when(after_editing_message, |parent| {
// Backdrop to dim out the whole thread below the editing user message
parent.relative().child(backdrop)
parent.relative().child(
div()
.occlude()
.absolute()
.inset_0()
.size_full()
.bg(panel_background)
.opacity(0.8),
)
})
.into_any()
}
@@ -2365,21 +2359,19 @@ impl ActiveThread {
let editor_bg = cx.theme().colors().editor_background;
move |el, range, metadata, _, cx| {
let can_expand = metadata.line_count
>= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK;
if !can_expand {
return el;
}
let is_expanded = active_thread
.read(cx)
.is_codeblock_expanded(message_id, range.start);
.expanded_code_blocks
.get(&(message_id, range.start))
.copied()
.unwrap_or(true);
if is_expanded {
if is_expanded
|| metadata.line_count
<= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK
{
return el;
}
el.child(
div()
.absolute()
@@ -2405,7 +2397,6 @@ impl ActiveThread {
markdown_element.code_block_renderer(
markdown::CodeBlockRenderer::Default {
copy_button: false,
copy_button_on_hover: false,
border: true,
},
)
@@ -2725,7 +2716,6 @@ impl ActiveThread {
)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
copy_button_on_hover: false,
border: false,
})
.on_url_click({
@@ -2756,7 +2746,6 @@ impl ActiveThread {
)
.code_block_renderer(markdown::CodeBlockRenderer::Default {
copy_button: false,
copy_button_on_hover: false,
border: false,
})
.on_url_click({
@@ -3393,21 +3382,6 @@ impl ActiveThread {
.log_err();
}))
}
pub fn is_codeblock_expanded(&self, message_id: MessageId, ix: usize) -> bool {
self.expanded_code_blocks
.get(&(message_id, ix))
.copied()
.unwrap_or(true)
}
pub fn toggle_codeblock_expanded(&mut self, message_id: MessageId, ix: usize) {
let is_expanded = self
.expanded_code_blocks
.entry((message_id, ix))
.or_insert(true);
*is_expanded = !*is_expanded;
}
}
pub enum ActiveThreadEvent {
@@ -3421,7 +3395,6 @@ impl Render for ActiveThread {
v_flex()
.size_full()
.relative()
.bg(cx.theme().colors().panel_background)
.on_mouse_move(cx.listener(|this, _, _, cx| {
this.show_scrollbar = true;
this.hide_scrollbar_later(cx);
@@ -3463,7 +3436,10 @@ pub(crate) fn open_active_thread_as_markdown(
workspace.update_in(cx, |workspace, window, cx| {
let thread = thread.read(cx);
let markdown = thread.to_markdown(cx)?;
let thread_summary = thread.summary().or_default().to_string();
let thread_summary = thread
.summary()
.map(|summary| summary.to_string())
.unwrap_or_else(|| "Thread".to_string());
let project = workspace.project().clone();

View File

@@ -49,7 +49,7 @@ pub use crate::context::{ContextLoadResult, LoadedContext};
pub use crate::inline_assistant::InlineAssistant;
use crate::slash_command_settings::SlashCommandSettings;
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
pub use crate::thread_store::{SerializedThread, TextThreadStore, ThreadStore};
pub use crate::thread_store::{TextThreadStore, ThreadStore};
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
pub use context_store::ContextStore;
pub use ui::preview::{all_agent_previews, get_agent_preview};

View File

@@ -18,8 +18,8 @@ use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageMod
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use settings::{Settings, update_settings_file};
use ui::{
Disclosure, ElevationIndex, Indicator, Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip,
prelude::*,
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
Switch, SwitchColor, Tooltip, prelude::*,
};
use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
@@ -36,7 +36,6 @@ pub struct AgentConfiguration {
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
context_server_store: Entity<ContextServerStore>,
expanded_context_server_tools: HashMap<ContextServerId, bool>,
expanded_provider_configurations: HashMap<LanguageModelProviderId, bool>,
tools: Entity<ToolWorkingSet>,
_registry_subscription: Subscription,
scroll_handle: ScrollHandle,
@@ -79,7 +78,6 @@ impl AgentConfiguration {
configuration_views_by_provider: HashMap::default(),
context_server_store,
expanded_context_server_tools: HashMap::default(),
expanded_provider_configurations: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
scroll_handle,
@@ -98,7 +96,6 @@ impl AgentConfiguration {
fn remove_provider_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
self.configuration_views_by_provider.remove(provider_id);
self.expanded_provider_configurations.remove(provider_id);
}
fn add_provider_configuration_view(
@@ -138,14 +135,9 @@ impl AgentConfiguration {
.get(&provider.id())
.cloned();
let is_expanded = self
.expanded_provider_configurations
.get(&provider.id())
.copied()
.unwrap_or(false);
v_flex()
.pt_3()
.pb_1()
.gap_1p5()
.border_t_1()
.border_color(cx.theme().colors().border.opacity(0.6))
@@ -160,63 +152,36 @@ impl AgentConfiguration {
.size(IconSize::Small)
.color(Color::Muted),
)
.child(Label::new(provider_name.clone()).size(LabelSize::Large))
.when(provider.is_authenticated(cx) && !is_expanded, |parent| {
parent.child(Icon::new(IconName::Check).color(Color::Success))
}),
.child(Label::new(provider_name.clone()).size(LabelSize::Large)),
)
.child(
h_flex()
.gap_1()
.when(provider.is_authenticated(cx), |parent| {
parent.child(
Button::new(
SharedString::from(format!("new-thread-{provider_id}")),
"Start New Thread",
)
.icon_position(IconPosition::Start)
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.layer(ElevationIndex::ModalSurface)
.label_size(LabelSize::Small)
.on_click(cx.listener({
let provider = provider.clone();
move |_this, _event, _window, cx| {
cx.emit(AssistantConfigurationEvent::NewThread(
provider.clone(),
))
}
})),
)
})
.child(
Disclosure::new(
SharedString::from(format!(
"provider-disclosure-{provider_id}"
)),
is_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
let provider_id = provider.id().clone();
move |this, _event, _window, _cx| {
let is_expanded = this
.expanded_provider_configurations
.entry(provider_id.clone())
.or_insert(false);
*is_expanded = !*is_expanded;
}
})),
),
),
.when(provider.is_authenticated(cx), |parent| {
parent.child(
Button::new(
SharedString::from(format!("new-thread-{provider_id}")),
"Start New Thread",
)
.icon_position(IconPosition::Start)
.icon(IconName::Plus)
.icon_size(IconSize::Small)
.style(ButtonStyle::Filled)
.layer(ElevationIndex::ModalSurface)
.label_size(LabelSize::Small)
.on_click(cx.listener({
let provider = provider.clone();
move |_this, _event, _window, cx| {
cx.emit(AssistantConfigurationEvent::NewThread(
provider.clone(),
))
}
})),
)
}),
)
.when(is_expanded, |parent| match configuration_view {
.map(|parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
None => parent.child(div().child(Label::new(format!(
"No configuration view for {provider_name}",
))),
)))),
})
}
@@ -230,8 +195,7 @@ impl AgentConfiguration {
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_4()
.border_b_1()
.border_color(cx.theme().colors().border)
.flex_1()
.child(
v_flex()
.gap_0p5()
@@ -332,8 +296,7 @@ impl AgentConfiguration {
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2p5()
.border_b_1()
.border_color(cx.theme().colors().border)
.flex_1()
.child(Headline::new("General Settings"))
.child(self.render_command_permission(cx))
.child(self.render_single_file_review(cx))
@@ -346,17 +309,18 @@ impl AgentConfiguration {
) -> impl IntoElement {
let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.gap_2()
.border_b_1()
.border_color(cx.theme().colors().border)
.flex_1()
.child(
v_flex()
.gap_0p5()
.child(Headline::new("Model Context Protocol (MCP) Servers"))
.child(Label::new("Connect to context servers via the Model Context Protocol either via Zed extensions or directly.").color(Color::Muted)),
.child(Label::new(SUBHEADING).color(Color::Muted)),
)
.children(
context_server_ids.into_iter().map(|context_server_id| {
@@ -423,7 +387,6 @@ impl AgentConfiguration {
.unwrap_or(ContextServerStatus::Stopped);
let is_running = matches!(server_status, ContextServerStatus::Running);
let item_id = SharedString::from(context_server_id.0.clone());
let error = if let ContextServerStatus::Error(error) = server_status.clone() {
Some(error)
@@ -445,38 +408,9 @@ impl AgentConfiguration {
let tool_count = tools.len();
let border_color = cx.theme().colors().border.opacity(0.6);
let success_color = Color::Success.color(cx);
let (status_indicator, tooltip_text) = match server_status {
ContextServerStatus::Starting => (
Indicator::dot()
.color(Color::Success)
.with_animation(
SharedString::from(format!("{}-starting", context_server_id.0.clone(),)),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 1.)),
move |this, delta| this.color(success_color.alpha(delta).into()),
)
.into_any_element(),
"Server is starting.",
),
ContextServerStatus::Running => (
Indicator::dot().color(Color::Success).into_any_element(),
"Server is running.",
),
ContextServerStatus::Error(_) => (
Indicator::dot().color(Color::Error).into_any_element(),
"Server has an error.",
),
ContextServerStatus::Stopped => (
Indicator::dot().color(Color::Muted).into_any_element(),
"Server is stopped.",
),
};
v_flex()
.id(item_id.clone())
.id(SharedString::from(context_server_id.0.clone()))
.border_1()
.rounded_md()
.border_color(border_color)
@@ -511,12 +445,35 @@ impl AgentConfiguration {
}
})),
)
.child(
div()
.id(item_id.clone())
.tooltip(Tooltip::text(tooltip_text))
.child(status_indicator),
)
.child(match server_status {
ContextServerStatus::Starting => {
let color = Color::Success.color(cx);
Indicator::dot()
.color(Color::Success)
.with_animation(
SharedString::from(format!(
"{}-starting",
context_server_id.0.clone(),
)),
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 1.)),
move |this, delta| {
this.color(color.alpha(delta).into())
},
)
.into_any_element()
}
ContextServerStatus::Running => {
Indicator::dot().color(Color::Success).into_any_element()
}
ContextServerStatus::Error(_) => {
Indicator::dot().color(Color::Error).into_any_element()
}
ContextServerStatus::Stopped => {
Indicator::dot().color(Color::Muted).into_any_element()
}
})
.child(Label::new(context_server_id.0.clone()).ml_0p5())
.when(is_running, |this| {
this.child(
@@ -631,7 +588,9 @@ impl Render for AgentConfiguration {
.size_full()
.overflow_y_scroll()
.child(self.render_general_settings_section(cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_context_servers_section(window, cx))
.child(Divider::horizontal().color(DividerColor::Border))
.child(self.render_provider_configuration_section(cx)),
)
.child(

View File

@@ -30,6 +30,7 @@ pub(crate) struct ConfigureContextServerModal {
context_server_store: Entity<ContextServerStore>,
}
#[allow(clippy::large_enum_variant)]
enum Configuration {
NotAvailable,
Required(ConfigurationRequiredState),

View File

@@ -1,4 +1,6 @@
use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll, Thread, ThreadEvent};
use crate::{
Keep, KeepAll, OpenAgentDiff, Reject, RejectAll, Thread, ThreadEvent, ui::AnimatedLabel,
};
use anyhow::Result;
use assistant_settings::AssistantSettings;
use buffer_diff::DiffHunkStatus;
@@ -9,9 +11,8 @@ use editor::{
scroll::Autoscroll,
};
use gpui::{
Action, Animation, AnimationExt, AnyElement, AnyView, App, AppContext, Empty, Entity,
EventEmitter, FocusHandle, Focusable, Global, SharedString, Subscription, Task, Transformation,
WeakEntity, Window, percentage, prelude::*,
Action, AnyElement, AnyView, App, AppContext, Empty, Entity, EventEmitter, FocusHandle,
Focusable, Global, SharedString, Subscription, Task, WeakEntity, Window, prelude::*,
};
use language::{Buffer, Capability, DiskState, OffsetRangeExt, Point};
@@ -24,7 +25,6 @@ use std::{
collections::hash_map::Entry,
ops::Range,
sync::Arc,
time::Duration,
};
use ui::{IconButtonShape, KeyBinding, Tooltip, prelude::*, vertical_divider};
use util::ResultExt;
@@ -215,7 +215,11 @@ impl AgentDiffPane {
}
fn update_title(&mut self, cx: &mut Context<Self>) {
let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes");
let new_title = self
.thread
.read(cx)
.summary()
.unwrap_or("Agent Changes".into());
if new_title != self.title {
self.title = new_title;
cx.emit(EditorEvent::TitleChanged);
@@ -465,7 +469,11 @@ impl Item for AgentDiffPane {
}
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes");
let summary = self
.thread
.read(cx)
.summary()
.unwrap_or("Agent Changes".into());
Label::new(format!("Review: {}", summary))
.color(if params.selected {
Color::Default
@@ -970,20 +978,9 @@ impl ToolbarItemView for AgentDiffToolbar {
impl Render for AgentDiffToolbar {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let spinner_icon = div()
.px_0p5()
.id("generating")
.tooltip(Tooltip::text("Generating Changes…"))
.child(
Icon::new(IconName::LoadCircle)
.size(IconSize::Small)
.color(Color::Accent)
.with_animation(
"load_circle",
Animation::new(Duration::from_secs(3)).repeat(),
|icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
),
)
let generating_label = div()
.w(rems_from_px(110.)) // Arbitrary size so the label doesn't dance around
.child(AnimatedLabel::new("Generating"))
.into_any();
let Some(active_item) = self.active_item.as_ref() else {
@@ -1000,7 +997,7 @@ impl Render for AgentDiffToolbar {
let content = match state {
EditorState::Idle => return Empty.into_any(),
EditorState::Generating => vec![spinner_icon],
EditorState::Generating => vec![generating_label],
EditorState::Reviewing => vec![
h_flex()
.child(
@@ -1118,7 +1115,7 @@ impl Render for AgentDiffToolbar {
let is_generating = agent_diff.read(cx).thread.read(cx).is_generating();
if is_generating {
return div().px_2().child(spinner_icon).into_any();
return div().px_2().child(generating_label).into_any();
}
let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();

View File

@@ -10,8 +10,8 @@ use serde::{Deserialize, Serialize};
use anyhow::{Result, anyhow};
use assistant_context_editor::{
AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent,
ContextSummary, SlashCommandCompletionProvider, humanize_token_count,
make_lsp_adapter_delegate, render_remaining_tokens,
SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate,
render_remaining_tokens,
};
use assistant_settings::{AssistantDockPosition, AssistantSettings};
use assistant_slash_command::SlashCommandWorkingSet;
@@ -46,9 +46,7 @@ use ui::{
};
use util::{ResultExt as _, maybe};
use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::{
CollaboratorId, DraggedSelection, DraggedTab, ToggleZoom, ToolbarItemView, Workspace,
};
use workspace::{CollaboratorId, DraggedSelection, DraggedTab, ToolbarItemView, Workspace};
use zed_actions::agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding};
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
use zed_actions::{DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize};
@@ -57,10 +55,10 @@ use zed_llm_client::UsageLimit;
use crate::active_thread::{self, ActiveThread, ActiveThreadEvent};
use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent};
use crate::agent_diff::AgentDiff;
use crate::history_store::{HistoryStore, RecentEntry};
use crate::history_store::{HistoryEntry, HistoryStore, RecentEntry};
use crate::message_editor::{MessageEditor, MessageEditorEvent};
use crate::thread::{Thread, ThreadError, ThreadId, ThreadSummary, TokenUsageRatio};
use crate::thread_history::{HistoryEntryElement, ThreadHistory};
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{EntryTimeFormat, PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::ui::AgentOnboardingModal;
use crate::{
@@ -196,7 +194,7 @@ impl ActiveView {
}
pub fn thread(thread: Entity<Thread>, window: &mut Window, cx: &mut App) -> Self {
let summary = thread.read(cx).summary().or_default();
let summary = thread.read(cx).summary_or_default();
let editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
@@ -218,7 +216,7 @@ impl ActiveView {
}
EditorEvent::Blurred => {
if editor.read(cx).text(cx).is_empty() {
let summary = thread.read(cx).summary().or_default();
let summary = thread.read(cx).summary_or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
@@ -233,7 +231,7 @@ impl ActiveView {
let editor = editor.clone();
move |thread, event, window, cx| match event {
ThreadEvent::SummaryGenerated => {
let summary = thread.read(cx).summary().or_default();
let summary = thread.read(cx).summary_or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
@@ -296,8 +294,7 @@ impl ActiveView {
.read(cx)
.context()
.read(cx)
.summary()
.or_default();
.summary_or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
@@ -312,7 +309,7 @@ impl ActiveView {
let editor = editor.clone();
move |assistant_context, event, window, cx| match event {
ContextEvent::SummaryGenerated => {
let summary = assistant_context.read(cx).summary().or_default();
let summary = assistant_context.read(cx).summary_or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
@@ -359,13 +356,11 @@ pub struct AgentPanel {
previous_view: Option<ActiveView>,
history_store: Entity<HistoryStore>,
history: Entity<ThreadHistory>,
hovered_recent_history_item: Option<usize>,
assistant_dropdown_menu_handle: PopoverMenuHandle<ContextMenu>,
assistant_navigation_menu_handle: PopoverMenuHandle<ContextMenu>,
assistant_navigation_menu: Option<Entity<ContextMenu>>,
width: Option<Pixels>,
height: Option<Pixels>,
zoomed: bool,
pending_serialization: Option<Task<Result<()>>>,
hide_trial_upsell: bool,
_trial_markdown: Entity<Markdown>,
@@ -701,13 +696,11 @@ impl AgentPanel {
previous_view: None,
history_store: history_store.clone(),
history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)),
hovered_recent_history_item: None,
assistant_dropdown_menu_handle: PopoverMenuHandle::default(),
assistant_navigation_menu_handle: PopoverMenuHandle::default(),
assistant_navigation_menu: None,
width: None,
height: None,
zoomed: false,
pending_serialization: None,
hide_trial_upsell: false,
_trial_markdown: trial_markdown,
@@ -1149,17 +1142,6 @@ impl AgentPanel {
}
}
pub fn toggle_zoom(&mut self, _: &ToggleZoom, window: &mut Window, cx: &mut Context<Self>) {
if self.zoomed {
cx.emit(PanelEvent::ZoomOut);
} else {
if !self.focus_handle(cx).contains_focused(window, cx) {
cx.focus_self(window);
}
cx.emit(PanelEvent::ZoomIn);
}
}
pub fn open_agent_diff(
&mut self,
_: &OpenAgentDiff,
@@ -1432,15 +1414,6 @@ impl Panel for AgentPanel {
fn enabled(&self, cx: &App) -> bool {
AssistantSettings::get_global(cx).enabled
}
fn is_zoomed(&self, _window: &Window, _cx: &App) -> bool {
self.zoomed
}
fn set_zoomed(&mut self, zoomed: bool, _window: &mut Window, cx: &mut Context<Self>) {
self.zoomed = zoomed;
cx.notify();
}
}
impl AgentPanel {
@@ -1453,45 +1426,23 @@ impl AgentPanel {
..
} => {
let active_thread = self.thread.read(cx);
let state = if active_thread.is_empty() {
&ThreadSummary::Pending
} else {
active_thread.summary(cx)
};
let is_empty = active_thread.is_empty();
match state {
ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone())
let summary = active_thread.summary(cx);
if is_empty {
Label::new(Thread::DEFAULT_SUMMARY.clone())
.truncate()
.into_any_element(),
ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER)
.into_any_element()
} else if summary.is_none() {
Label::new(LOADING_SUMMARY_PLACEHOLDER)
.truncate()
.into_any_element(),
ThreadSummary::Ready(_) => div()
.into_any_element()
} else {
div()
.w_full()
.child(change_title_editor.clone())
.into_any_element(),
ThreadSummary::Error => h_flex()
.w_full()
.child(change_title_editor.clone())
.child(
ui::IconButton::new("retry-summary-generation", IconName::RotateCcw)
.on_click({
let active_thread = self.thread.clone();
move |_, _window, cx| {
active_thread.update(cx, |thread, cx| {
thread.regenerate_summary(cx);
});
}
})
.tooltip(move |_window, cx| {
cx.new(|_| {
Tooltip::new("Failed to generate title")
.meta("Click to try again")
})
.into()
}),
)
.into_any_element(),
.into_any_element()
}
}
ActiveView::PromptEditor {
@@ -1499,13 +1450,14 @@ impl AgentPanel {
context_editor,
..
} => {
let summary = context_editor.read(cx).context().read(cx).summary();
let context_editor = context_editor.read(cx);
let summary = context_editor.context().read(cx).summary();
match summary {
ContextSummary::Pending => Label::new(ContextSummary::DEFAULT)
None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone())
.truncate()
.into_any_element(),
ContextSummary::Content(summary) => {
Some(summary) => {
if summary.done {
div()
.w_full()
@@ -1517,28 +1469,6 @@ impl AgentPanel {
.into_any_element()
}
}
ContextSummary::Error => h_flex()
.w_full()
.child(title_editor.clone())
.child(
ui::IconButton::new("retry-summary-generation", IconName::RotateCcw)
.on_click({
let context_editor = context_editor.clone();
move |_, _window, cx| {
context_editor.update(cx, |context_editor, cx| {
context_editor.regenerate_summary(cx);
});
}
})
.tooltip(move |_window, cx| {
cx.new(|_| {
Tooltip::new("Failed to generate title")
.meta("Click to try again")
})
.into()
}),
)
.into_any_element(),
}
}
ActiveView::History => Label::new("History").truncate().into_any_element(),
@@ -1648,12 +1578,6 @@ impl AgentPanel {
}),
);
let zoom_in_label = if self.is_zoomed(window, cx) {
"Zoom Out"
} else {
"Zoom In"
};
let agent_extra_menu = PopoverMenu::new("agent-options-menu")
.trigger_with_tooltip(
IconButton::new("agent-options-menu", IconName::Ellipsis)
@@ -1740,8 +1664,7 @@ impl AgentPanel {
menu = menu
.action("Rules…", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenConfiguration))
.action(zoom_in_label, Box::new(ToggleZoom));
.action("Settings", Box::new(OpenConfiguration));
menu
}))
});
@@ -2135,7 +2058,6 @@ impl AgentPanel {
v_flex()
.size_full()
.bg(cx.theme().colors().panel_background)
.when(recent_history.is_empty(), |this| {
let configuration_error_ref = &configuration_error;
this.child(
@@ -2290,7 +2212,7 @@ impl AgentPanel {
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.child(
Label::new("Recent")
Label::new("Past Interactions")
.size(LabelSize::Small)
.color(Color::Muted),
)
@@ -2315,20 +2237,18 @@ impl AgentPanel {
v_flex()
.gap_1()
.children(
recent_history.into_iter().enumerate().map(|(index, entry)| {
recent_history.into_iter().map(|entry| {
// TODO: Add keyboard navigation.
let is_hovered = self.hovered_recent_history_item == Some(index);
HistoryEntryElement::new(entry.clone(), cx.entity().downgrade())
.hovered(is_hovered)
.on_hover(cx.listener(move |this, is_hovered, _window, cx| {
if *is_hovered {
this.hovered_recent_history_item = Some(index);
} else if this.hovered_recent_history_item == Some(index) {
this.hovered_recent_history_item = None;
}
cx.notify();
}))
.into_any_element()
match entry {
HistoryEntry::Thread(thread) => {
PastThread::new(thread, cx.entity().downgrade(), false, vec![], EntryTimeFormat::DateAndTime)
.into_any_element()
}
HistoryEntry::Context(context) => {
PastContext::new(context, cx.entity().downgrade(), false, vec![], EntryTimeFormat::DateAndTime)
.into_any_element()
}
}
}),
)
)
@@ -2440,6 +2360,9 @@ impl AgentPanel {
.occlude()
.child(match last_error {
ThreadError::PaymentRequired => self.render_payment_required_error(cx),
ThreadError::MaxMonthlySpendReached => {
self.render_max_monthly_spend_reached_error(cx)
}
ThreadError::ModelRequestLimitReached { plan } => {
self.render_model_request_limit_reached_error(plan, cx)
}
@@ -2499,6 +2422,56 @@ impl AgentPanel {
.into_any()
}
fn render_max_monthly_spend_reached_error(&self, cx: &mut Context<Self>) -> AnyElement {
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(ERROR_MESSAGE)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(ERROR_MESSAGE))
.child(
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
cx.listener(|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.open_url(&zed_urls::account_url(cx));
cx.notify();
}),
),
)
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.notify();
},
))),
)
.into_any()
}
fn render_model_request_limit_reached_error(
&self,
plan: Plan,
@@ -2771,16 +2744,42 @@ impl AgentPanel {
impl Render for AgentPanel {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
// WARNING: Changes to this element hierarchy can have
// non-obvious implications to the layout of children.
//
// If you need to change it, please confirm:
// - The message editor expands (⌘esc) correctly
// - When expanded, the buttons at the bottom of the panel are displayed correctly
// - Font size works as expected and can be changed with ⌘+/⌘-
// - Scrolling in all views works as expected
// - Files can be dropped into the panel
let content = v_flex()
let content = match &self.active_view {
ActiveView::Thread { .. } => v_flex()
.relative()
.justify_between()
.size_full()
.child(self.render_active_thread_or_empty_state(window, cx))
.children(self.render_tool_use_limit_reached(cx))
.child(h_flex().child(self.message_editor.clone()))
.children(self.render_last_error(cx))
.child(self.render_drag_target(cx))
.into_any(),
ActiveView::History => self.history.clone().into_any_element(),
ActiveView::PromptEditor {
context_editor,
buffer_search_bar,
..
} => self
.render_prompt_editor(context_editor, buffer_search_bar, window, cx)
.into_any(),
ActiveView::Configuration => v_flex()
.size_full()
.children(self.configuration.clone())
.into_any(),
};
let content = match self.active_view.which_font_size_used() {
WhichFontSize::AgentFont => {
WithRemSize::new(ThemeSettings::get_global(cx).agent_font_size(cx))
.size_full()
.child(content)
.into_any()
}
_ => content,
};
v_flex()
.key_context(self.key_context())
.justify_between()
.size_full()
@@ -2803,40 +2802,9 @@ impl Render for AgentPanel {
.on_action(cx.listener(Self::increase_font_size))
.on_action(cx.listener(Self::decrease_font_size))
.on_action(cx.listener(Self::reset_font_size))
.on_action(cx.listener(Self::toggle_zoom))
.child(self.render_toolbar(window, cx))
.children(self.render_trial_upsell(window, cx))
.map(|parent| match &self.active_view {
ActiveView::Thread { .. } => parent
.relative()
.child(self.render_active_thread_or_empty_state(window, cx))
.children(self.render_tool_use_limit_reached(cx))
.child(h_flex().child(self.message_editor.clone()))
.children(self.render_last_error(cx))
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::PromptEditor {
context_editor,
buffer_search_bar,
..
} => parent.child(self.render_prompt_editor(
context_editor,
buffer_search_bar,
window,
cx,
)),
ActiveView::Configuration => parent.children(self.configuration.clone()),
});
match self.active_view.which_font_size_used() {
WhichFontSize::AgentFont => {
WithRemSize::new(ThemeSettings::get_global(cx).agent_font_size(cx))
.size_full()
.child(content)
.into_any()
}
_ => content.into_any(),
}
.child(content)
}
}

View File

@@ -586,7 +586,10 @@ impl ThreadContextHandle {
}
pub fn title(&self, cx: &App) -> SharedString {
self.thread.read(cx).summary().or_default()
self.thread
.read(cx)
.summary()
.unwrap_or_else(|| "New thread".into())
}
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
@@ -594,7 +597,9 @@ impl ThreadContextHandle {
let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?;
let title = self
.thread
.read_with(cx, |thread, _cx| thread.summary().or_default())
.read_with(cx, |thread, _cx| {
thread.summary().unwrap_or_else(|| "New thread".into())
})
.ok()?;
let context = AgentContext::Thread(ThreadContext {
title,
@@ -637,7 +642,7 @@ impl TextThreadContextHandle {
}
pub fn title(&self, cx: &App) -> SharedString {
self.context.read(cx).summary().or_default()
self.context.read(cx).summary_or_default()
}
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
@@ -749,11 +754,11 @@ pub enum ImageStatus {
impl ImageContext {
pub fn eq_for_key(&self, other: &Self) -> bool {
self.original_image.id() == other.original_image.id()
self.original_image.id == other.original_image.id
}
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
self.original_image.id().hash(state);
self.original_image.id.hash(state);
}
pub fn image(&self) -> Option<LanguageModelImage> {
@@ -825,20 +830,23 @@ pub fn load_context(
prompt_store: &Option<Entity<PromptStore>>,
cx: &mut App,
) -> Task<ContextLoadResult> {
let load_tasks: Vec<_> = contexts
.into_iter()
.map(|context| match context {
AgentContextHandle::File(context) => context.load(cx),
AgentContextHandle::Directory(context) => context.load(project.clone(), cx),
AgentContextHandle::Symbol(context) => context.load(cx),
AgentContextHandle::Selection(context) => context.load(cx),
AgentContextHandle::FetchedUrl(context) => context.load(),
AgentContextHandle::Thread(context) => context.load(cx),
AgentContextHandle::TextThread(context) => context.load(cx),
AgentContextHandle::Rules(context) => context.load(prompt_store, cx),
AgentContextHandle::Image(context) => context.load(cx),
})
.collect();
let mut load_tasks = Vec::new();
for context in contexts.iter().cloned() {
match context {
AgentContextHandle::File(context) => load_tasks.push(context.load(cx)),
AgentContextHandle::Directory(context) => {
load_tasks.push(context.load(project.clone(), cx))
}
AgentContextHandle::Symbol(context) => load_tasks.push(context.load(cx)),
AgentContextHandle::Selection(context) => load_tasks.push(context.load(cx)),
AgentContextHandle::FetchedUrl(context) => load_tasks.push(context.load()),
AgentContextHandle::Thread(context) => load_tasks.push(context.load(cx)),
AgentContextHandle::TextThread(context) => load_tasks.push(context.load(cx)),
AgentContextHandle::Rules(context) => load_tasks.push(context.load(prompt_store, cx)),
AgentContextHandle::Image(context) => load_tasks.push(context.load(cx)),
}
}
cx.background_spawn(async move {
let load_results = future::join_all(load_tasks).await;

View File

@@ -381,16 +381,6 @@ impl ContextPicker {
cx.focus_self(window);
}
pub fn select_first(&mut self, window: &mut Window, cx: &mut Context<Self>) {
match &self.mode {
ContextPickerState::Default(entity) => entity.update(cx, |entity, cx| {
entity.select_first(&Default::default(), window, cx)
}),
// Other variants already select their first entry on open automatically
_ => {}
}
}
fn recent_menu_item(
&self,
context_picker: Entity<ContextPicker>,
@@ -942,8 +932,8 @@ impl MentionLink {
format!("[@{}]({}:{})", title, Self::THREAD, id)
}
ThreadContextEntry::Context { path, title } => {
let filename = path.file_name().unwrap_or_default().to_string_lossy();
let escaped_filename = urlencoding::encode(&filename);
let filename = path.file_name().unwrap_or_default();
let escaped_filename = urlencoding::encode(&filename.to_string_lossy()).to_string();
format!(
"[@{}]({}:{}{})",
title,

View File

@@ -84,12 +84,6 @@ impl ContextStrip {
}
}
/// Whether or not the context strip has items to display
pub fn has_context_items(&self, cx: &App) -> bool {
self.context_store.read(cx).context().next().is_some()
|| self.suggested_context(cx).is_some()
}
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
if let Some(workspace) = self.workspace.upgrade() {
let project = workspace.read(cx).project().read(cx);
@@ -110,14 +104,14 @@ impl ContextStrip {
}
}
fn suggested_context(&self, cx: &App) -> Option<SuggestedContext> {
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
match self.suggest_context_kind {
SuggestContextKind::File => self.suggested_file(cx),
SuggestContextKind::Thread => self.suggested_thread(cx),
}
}
fn suggested_file(&self, cx: &App) -> Option<SuggestedContext> {
fn suggested_file(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
let workspace = self.workspace.upgrade()?;
let active_item = workspace.read(cx).active_item(cx)?;
@@ -144,7 +138,7 @@ impl ContextStrip {
})
}
fn suggested_thread(&self, cx: &App) -> Option<SuggestedContext> {
fn suggested_thread(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
if !self.context_picker.read(cx).allow_threads() {
return None;
}
@@ -166,7 +160,7 @@ impl ContextStrip {
}
Some(SuggestedContext::Thread {
name: active_thread.summary().or_default(),
name: active_thread.summary_or_default(),
thread: weak_active_thread,
})
} else if let Some(active_context_editor) = panel.active_context_editor() {
@@ -180,7 +174,7 @@ impl ContextStrip {
}
Some(SuggestedContext::TextThread {
name: context.summary().or_default(),
name: context.summary_or_default(),
context: weak_context,
})
} else {
@@ -426,25 +420,12 @@ impl Render for ContextStrip {
})
.child(
PopoverMenu::new("context-picker")
.menu({
let context_picker = context_picker.clone();
move |window, cx| {
context_picker.update(cx, |this, cx| {
this.init(window, cx);
});
.menu(move |window, cx| {
context_picker.update(cx, |this, cx| {
this.init(window, cx);
});
Some(context_picker.clone())
}
})
.on_open({
let context_picker = context_picker.downgrade();
Rc::new(move |window, cx| {
context_picker
.update(cx, |context_picker, cx| {
context_picker.select_first(window, cx);
})
.ok();
})
Some(context_picker.clone())
})
.trigger_with_tooltip(
IconButton::new("add-context", IconName::Plus)

View File

@@ -75,7 +75,7 @@ impl Default for DebugAccountState {
Self {
enabled: false,
trial_expired: false,
plan: Plan::ZedFree,
plan: Plan::Free,
custom_prompt_usage: RequestUsage {
limit: UsageLimit::Unlimited,
amount: 0,

View File

@@ -1,4 +1,4 @@
use std::{collections::VecDeque, path::Path, sync::Arc};
use std::{collections::VecDeque, path::Path};
use anyhow::{Context as _, anyhow};
use assistant_context_editor::{AssistantContext, SavedContextMetadata};
@@ -34,20 +34,6 @@ impl HistoryEntry {
HistoryEntry::Context(context) => context.mtime.to_utc(),
}
}
pub fn id(&self) -> HistoryEntryId {
match self {
HistoryEntry::Thread(thread) => HistoryEntryId::Thread(thread.id.clone()),
HistoryEntry::Context(context) => HistoryEntryId::Context(context.path.clone()),
}
}
}
/// Generic identifier for a history entry.
#[derive(Clone, PartialEq, Eq)]
pub enum HistoryEntryId {
Thread(ThreadId),
Context(Arc<Path>),
}
#[derive(Clone, Debug)]
@@ -71,8 +57,8 @@ impl Eq for RecentEntry {}
impl RecentEntry {
pub(crate) fn summary(&self, cx: &App) -> SharedString {
match self {
RecentEntry::Thread(_, thread) => thread.read(cx).summary().or_default(),
RecentEntry::Context(context) => context.read(cx).summary().or_default(),
RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(),
RecentEntry::Context(context) => context.read(cx).summary_or_default(),
}
}
}

View File

@@ -338,27 +338,13 @@ impl InlineAssistant {
window: &mut Window,
cx: &mut App,
) {
let (snapshot, initial_selections, newest_selection) = editor.update(cx, |editor, cx| {
let selections = editor.selections.all::<Point>(cx);
let newest_selection = editor.selections.newest::<Point>(cx);
(editor.snapshot(window, cx), selections, newest_selection)
let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
(
editor.snapshot(window, cx),
editor.selections.all::<Point>(cx),
)
});
// Check if there is already an inline assistant that contains the
// newest selection, if there is, focus it
if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
for assist_id in &editor_assists.assist_ids {
let assist = &self.assists[assist_id];
let range = assist.range.to_point(&snapshot.buffer_snapshot);
if range.start.row <= newest_selection.start.row
&& newest_selection.end.row <= range.end.row
{
self.focus_assist(*assist_id, window, cx);
return;
}
}
}
let mut selections = Vec::<Selection<Point>>::new();
let mut newest_selection = None;
for mut selection in initial_selections {
@@ -1421,7 +1407,7 @@ impl InlineAssistant {
enum DeletedLines {}
let mut editor = Editor::for_multibuffer(multi_buffer, None, window, cx);
editor.disable_scrollbars_and_minimap(window, cx);
editor.disable_scrollbars_and_minimap(cx);
editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
editor.set_show_wrap_guides(false, cx);
editor.set_show_gutter(false, cx);

View File

@@ -451,7 +451,7 @@ impl<T: 'static> PromptEditor<T> {
editor.move_to_end(&Default::default(), window, cx)
});
}
} else if self.context_strip.read(cx).has_context_items(cx) {
} else {
self.context_strip.focus_handle(cx).focus(window);
}
}

View File

@@ -200,13 +200,7 @@ impl MessageEditor {
});
let profile_selector = cx.new(|cx| {
ProfileSelector::new(
fs,
thread.clone(),
thread_store,
editor.focus_handle(cx),
cx,
)
ProfileSelector::new(thread.clone(), thread_store, editor.focus_handle(cx), cx)
});
Self {
@@ -401,7 +395,7 @@ impl MessageEditor {
fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
if self.context_picker_menu_handle.is_deployed() {
cx.propagate();
} else if self.context_strip.read(cx).has_context_items(cx) {
} else {
self.context_strip.focus_handle(cx).focus(window);
}
}
@@ -1085,11 +1079,11 @@ impl MessageEditor {
let plan = user_store
.current_plan()
.map(|plan| match plan {
Plan::Free => zed_llm_client::Plan::ZedFree,
Plan::Free => zed_llm_client::Plan::Free,
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
})
.unwrap_or(zed_llm_client::Plan::ZedFree);
.unwrap_or(zed_llm_client::Plan::Free);
let usage = self.thread.read(cx).last_usage().or_else(|| {
maybe!({
let amount = user_store.model_request_usage_amount()?;

View File

@@ -1,13 +1,10 @@
use std::sync::Arc;
use assistant_settings::{
AgentProfile, AgentProfileId, AssistantDockPosition, AssistantSettings, GroupedAgentProfiles,
builtin_profiles,
};
use fs::Fs;
use gpui::{Action, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
use language_model::LanguageModelRegistry;
use settings::{Settings as _, SettingsStore, update_settings_file};
use settings::{Settings as _, SettingsStore};
use ui::{
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
prelude::*,
@@ -18,7 +15,6 @@ use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector};
pub struct ProfileSelector {
profiles: GroupedAgentProfiles,
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
thread_store: WeakEntity<ThreadStore>,
menu_handle: PopoverMenuHandle<ContextMenu>,
@@ -28,7 +24,6 @@ pub struct ProfileSelector {
impl ProfileSelector {
pub fn new(
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
thread_store: WeakEntity<ThreadStore>,
focus_handle: FocusHandle,
@@ -40,7 +35,6 @@ impl ProfileSelector {
Self {
profiles: GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx)),
fs,
thread,
thread_store,
menu_handle: PopoverMenuHandle::default(),
@@ -65,12 +59,8 @@ impl ProfileSelector {
ContextMenu::build(window, cx, |mut menu, _window, cx| {
let settings = AssistantSettings::get_global(cx);
for (profile_id, profile) in self.profiles.builtin.iter() {
menu = menu.item(self.menu_entry_for_profile(
profile_id.clone(),
profile,
settings,
cx,
));
menu =
menu.item(self.menu_entry_for_profile(profile_id.clone(), profile, settings));
}
if !self.profiles.custom.is_empty() {
@@ -80,7 +70,6 @@ impl ProfileSelector {
profile_id.clone(),
profile,
settings,
cx,
));
}
}
@@ -101,7 +90,6 @@ impl ProfileSelector {
profile_id: AgentProfileId,
profile: &AgentProfile,
settings: &AssistantSettings,
_cx: &App,
) -> ContextMenuEntry {
let documentation = match profile.name.to_lowercase().as_str() {
builtin_profiles::WRITE => Some("Get help to write anything."),
@@ -122,15 +110,15 @@ impl ProfileSelector {
};
entry.handler({
let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
let profile_id = profile_id.clone();
let profile = profile.clone();
let thread = self.thread.clone();
move |_window, cx| {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let profile_id = profile_id.clone();
move |settings, _cx| {
settings.set_profile(profile_id.clone());
}
thread.update(cx, |thread, cx| {
thread.set_configured_profile(Some(profile.clone()), cx);
});
thread_store
@@ -146,8 +134,14 @@ impl ProfileSelector {
impl Render for ProfileSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let settings = AssistantSettings::get_global(cx);
let profile_id = &settings.default_profile;
let profile = settings.profiles.get(profile_id);
let profile = self
.thread
.read_with(cx, |thread, _cx| thread.configured_profile())
.or_else(|| {
let profile_id = &settings.default_profile;
let profile = settings.profiles.get(profile_id);
profile.cloned()
});
let selected_profile = profile
.map(|profile| profile.name.clone())

View File

@@ -191,7 +191,7 @@ impl TerminalInlineAssistant {
};
self.prompt_history.retain(|prompt| *prompt != user_prompt);
self.prompt_history.push_back(user_prompt);
self.prompt_history.push_back(user_prompt.clone());
if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
self.prompt_history.pop_front();
}

View File

@@ -5,7 +5,7 @@ use std::sync::Arc;
use std::time::Instant;
use anyhow::{Result, anyhow};
use assistant_settings::{AssistantSettings, CompletionMode};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
@@ -22,7 +22,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
StopReason, TokenUsage,
};
@@ -36,7 +36,7 @@ use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use ui::Window;
use util::{ResultExt as _, post_inc};
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionRequestStatus;
@@ -324,7 +324,7 @@ pub enum QueueState {
pub struct Thread {
id: ThreadId,
updated_at: DateTime<Utc>,
summary: ThreadSummary,
summary: Option<SharedString>,
pending_summary: Task<Option<()>>,
detailed_summary_task: Task<Option<()>>,
detailed_summary_tx: postage::watch::Sender<DetailedSummaryState>,
@@ -359,33 +359,7 @@ pub struct Thread {
>,
remaining_turns: u32,
configured_model: Option<ConfiguredModel>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ThreadSummary {
Pending,
Generating,
Ready(SharedString),
Error,
}
impl ThreadSummary {
pub const DEFAULT: SharedString = SharedString::new_static("New Thread");
pub fn or_default(&self) -> SharedString {
self.unwrap_or(Self::DEFAULT)
}
pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
self.ready().unwrap_or_else(|| message.into())
}
pub fn ready(&self) -> Option<SharedString> {
match self {
ThreadSummary::Ready(summary) => Some(summary.clone()),
ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None,
}
}
configured_profile: Option<AgentProfile>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -406,11 +380,14 @@ impl Thread {
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
let assistant_settings = AssistantSettings::get_global(cx);
let profile_id = &assistant_settings.default_profile;
let configured_profile = assistant_settings.profiles.get(profile_id).cloned();
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
summary: ThreadSummary::Pending,
summary: None,
pending_summary: Task::ready(None),
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
@@ -448,6 +425,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
configured_profile,
}
}
@@ -458,7 +436,7 @@ impl Thread {
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
window: Option<&mut Window>, // None in headless mode
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
@@ -495,10 +473,17 @@ impl Thread {
.completion_mode
.unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);
let configured_profile = serialized.profile.and_then(|profile| {
AssistantSettings::get_global(cx)
.profiles
.get(&profile)
.cloned()
});
Self {
id,
updated_at: serialized.updated_at,
summary: ThreadSummary::Ready(serialized.summary),
summary: Some(serialized.summary),
pending_summary: Task::ready(None),
detailed_summary_task: Task::ready(None),
detailed_summary_tx,
@@ -568,6 +553,7 @@ impl Thread {
request_callback: None,
remaining_turns: u32::MAX,
configured_model,
configured_profile,
}
}
@@ -599,6 +585,10 @@ impl Thread {
self.last_prompt_id = PromptId::new();
}
pub fn summary(&self) -> Option<SharedString> {
self.summary.clone()
}
pub fn project_context(&self) -> SharedProjectContext {
self.project_context.clone()
}
@@ -619,25 +609,39 @@ impl Thread {
cx.notify();
}
pub fn summary(&self) -> &ThreadSummary {
&self.summary
pub fn configured_profile(&self) -> Option<AgentProfile> {
self.configured_profile.clone()
}
pub fn set_configured_profile(
&mut self,
profile: Option<AgentProfile>,
cx: &mut Context<Self>,
) {
self.configured_profile = profile;
cx.notify();
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
}
pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
let current_summary = match &self.summary {
ThreadSummary::Pending | ThreadSummary::Generating => return,
ThreadSummary::Ready(summary) => summary,
ThreadSummary::Error => &ThreadSummary::DEFAULT,
let Some(current_summary) = &self.summary else {
// Don't allow setting summary until generated
return;
};
let mut new_summary = new_summary.into();
if new_summary.is_empty() {
new_summary = ThreadSummary::DEFAULT;
new_summary = Self::DEFAULT_SUMMARY;
}
if current_summary != &new_summary {
self.summary = ThreadSummary::Ready(new_summary);
self.summary = Some(new_summary);
cx.emit(ThreadEvent::SummaryChanged);
}
}
@@ -880,13 +884,7 @@ impl Thread {
}
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
match &self.tool_use.tool_result(id)?.content {
LanguageModelToolResultContent::Text(str) => Some(str),
LanguageModelToolResultContent::Image(_) => {
// TODO: We should display image
None
}
}
Some(&self.tool_use.tool_result(id)?.content)
}
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
@@ -1057,7 +1055,7 @@ impl Thread {
let initial_project_snapshot = initial_project_snapshot.await;
this.read_with(cx, |this, cx| SerializedThread {
version: SerializedThread::VERSION.to_string(),
summary: this.summary().or_default(),
summary: this.summary_or_default(),
updated_at: this.updated_at(),
messages: this
.messages()
@@ -1128,6 +1126,10 @@ impl Thread {
provider: model.provider.id().0.to_string(),
model: model.model.id().0.to_string(),
}),
profile: this
.configured_profile
.as_ref()
.map(|profile| AgentProfileId(profile.name.clone().into())),
completion_mode: Some(this.completion_mode),
})
})
@@ -1653,7 +1655,7 @@ impl Thread {
// If there is a response without tool use, summarize the message. Otherwise,
// allow two tool uses before summarizing.
if matches!(thread.summary, ThreadSummary::Pending)
if thread.summary.is_none()
&& thread.messages.len() >= 2
&& (!thread.has_pending_tool_uses() || thread.messages.len() >= 6)
{
@@ -1688,6 +1690,10 @@ impl Thread {
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
} else if let Some(error) =
error.downcast_ref::<ModelRequestLimitReachedError>()
{
@@ -1763,7 +1769,6 @@ impl Thread {
pub fn summarize(&mut self, cx: &mut Context<Self>) {
let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
println!("No thread summary model");
return;
};
@@ -1778,17 +1783,13 @@ impl Thread {
let request = self.to_summarize_request(&model.model, added_user_message.into(), cx);
self.summary = ThreadSummary::Generating;
self.pending_summary = cx.spawn(async move |this, cx| {
let result = async {
async move {
let mut messages = model.model.stream_completion(request, &cx).await?;
let mut new_summary = String::new();
while let Some(event) = messages.next().await {
let Ok(event) = event else {
continue;
};
let event = event?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
@@ -1814,29 +1815,18 @@ impl Thread {
}
}
anyhow::Ok(new_summary)
this.update(cx, |this, cx| {
if !new_summary.is_empty() {
this.summary = Some(new_summary.into());
}
cx.emit(ThreadEvent::SummaryGenerated);
})?;
anyhow::Ok(())
}
.await;
this.update(cx, |this, cx| {
match result {
Ok(new_summary) => {
if new_summary.is_empty() {
this.summary = ThreadSummary::Error;
} else {
this.summary = ThreadSummary::Ready(new_summary.into());
}
}
Err(err) => {
this.summary = ThreadSummary::Error;
log::error!("Failed to generate thread summary: {}", err);
}
}
cx.emit(ThreadEvent::SummaryGenerated);
})
.log_err()?;
Some(())
.log_err()
.await
});
}
@@ -2243,7 +2233,7 @@ impl Thread {
.read(cx)
.enabled_tools(cx)
.iter()
.map(|tool| tool.name())
.map(|tool| tool.name().to_string())
.collect();
self.message_feedback.insert(message_id, feedback);
@@ -2446,8 +2436,9 @@ impl Thread {
pub fn to_markdown(&self, cx: &App) -> Result<String> {
let mut markdown = Vec::new();
let summary = self.summary().or_default();
writeln!(markdown, "# {summary}\n")?;
if let Some(summary) = self.summary() {
writeln!(markdown, "# {summary}\n")?;
};
for message in self.messages() {
writeln!(
@@ -2504,22 +2495,7 @@ impl Thread {
}
writeln!(markdown, "**\n")?;
match &tool_result.content {
LanguageModelToolResultContent::Text(str) => {
writeln!(markdown, "{}", str)?;
}
LanguageModelToolResultContent::Image(image) => {
writeln!(markdown, "![Image](data:base64,{})", image.source)?;
}
}
if let Some(output) = tool_result.output.as_ref() {
writeln!(
markdown,
"\n\nDebug Output:\n\n```json\n{}\n```\n",
serde_json::to_string_pretty(output)?
)?;
}
writeln!(markdown, "{}", tool_result.content)?;
}
}
@@ -2583,7 +2559,7 @@ impl Thread {
.read(cx)
.current_user()
.map(|user| user.github_login.clone());
let client = self.project.read(cx).client();
let client = self.project.read(cx).client().clone();
let serialize_task = self.serialize(cx);
cx.background_executor()
@@ -2702,6 +2678,8 @@ impl Thread {
pub enum ThreadError {
#[error("Payment required")]
PaymentRequired,
#[error("Max monthly spend reached")]
MaxMonthlySpendReached,
#[error("Model request limit reached")]
ModelRequestLimitReached { plan: Plan },
#[error("Message {header}: {message}")]
@@ -2770,7 +2748,7 @@ mod tests {
use assistant_tool::ToolRegistry;
use editor::EditorSettings;
use gpui::TestAppContext;
use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider};
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project};
use prompt_store::PromptBuilder;
use serde_json::json;
@@ -3271,196 +3249,6 @@ fn main() {{
assert_eq!(request.temperature, None);
}
#[gpui::test]
async fn test_thread_summary(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(cx, json!({})).await;
let (_, _thread_store, thread, _context_store, model) =
setup_test_environment(cx, project.clone()).await;
// Initial state should be pending
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Pending));
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
});
// Manually setting the summary should not be allowed in this state
thread.update(cx, |thread, cx| {
thread.set_summary("This should not work", cx);
});
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Pending));
});
// Send a message
thread.update(cx, |thread, cx| {
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
thread.send_to_model(model.clone(), None, cx);
});
let fake_model = model.as_fake();
simulate_successful_response(&fake_model, cx);
// Should start generating summary when there are >= 2 messages
thread.read_with(cx, |thread, _| {
assert_eq!(*thread.summary(), ThreadSummary::Generating);
});
// Should not be able to set the summary while generating
thread.update(cx, |thread, cx| {
thread.set_summary("This should not work either", cx);
});
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Generating));
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
});
cx.run_until_parked();
fake_model.stream_last_completion_response("Brief".into());
fake_model.stream_last_completion_response(" Introduction".into());
fake_model.end_last_completion_stream();
cx.run_until_parked();
// Summary should be set
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
assert_eq!(thread.summary().or_default(), "Brief Introduction");
});
// Now we should be able to set a summary
thread.update(cx, |thread, cx| {
thread.set_summary("Brief Intro", cx);
});
thread.read_with(cx, |thread, _| {
assert_eq!(thread.summary().or_default(), "Brief Intro");
});
// Test setting an empty summary (should default to DEFAULT)
thread.update(cx, |thread, cx| {
thread.set_summary("", cx);
});
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
});
}
#[gpui::test]
async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(cx, json!({})).await;
let (_, _thread_store, thread, _context_store, model) =
setup_test_environment(cx, project.clone()).await;
test_summarize_error(&model, &thread, cx);
// Now we should be able to set a summary
thread.update(cx, |thread, cx| {
thread.set_summary("Brief Intro", cx);
});
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
assert_eq!(thread.summary().or_default(), "Brief Intro");
});
}
#[gpui::test]
async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
init_test_settings(cx);
let project = create_test_project(cx, json!({})).await;
let (_, _thread_store, thread, _context_store, model) =
setup_test_environment(cx, project.clone()).await;
test_summarize_error(&model, &thread, cx);
// Sending another message should not trigger another summarize request
thread.update(cx, |thread, cx| {
thread.insert_user_message(
"How are you?",
ContextLoadResult::default(),
None,
vec![],
cx,
);
thread.send_to_model(model.clone(), None, cx);
});
let fake_model = model.as_fake();
simulate_successful_response(&fake_model, cx);
thread.read_with(cx, |thread, _| {
// State is still Error, not Generating
assert!(matches!(thread.summary(), ThreadSummary::Error));
});
// But the summarize request can be invoked manually
thread.update(cx, |thread, cx| {
thread.summarize(cx);
});
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Generating));
});
cx.run_until_parked();
fake_model.stream_last_completion_response("A successful summary".into());
fake_model.end_last_completion_stream();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Ready(_)));
assert_eq!(thread.summary().or_default(), "A successful summary");
});
}
fn test_summarize_error(
model: &Arc<dyn LanguageModel>,
thread: &Entity<Thread>,
cx: &mut TestAppContext,
) {
thread.update(cx, |thread, cx| {
thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx);
thread.send_to_model(model.clone(), None, cx);
});
let fake_model = model.as_fake();
simulate_successful_response(&fake_model, cx);
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Generating));
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
});
// Simulate summary request ending
cx.run_until_parked();
fake_model.end_last_completion_stream();
cx.run_until_parked();
// State is set to Error and default message
thread.read_with(cx, |thread, _| {
assert!(matches!(thread.summary(), ThreadSummary::Error));
assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT);
});
}
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.stream_last_completion_response("Assistant response".into());
fake_model.end_last_completion_stream();
cx.run_until_parked();
}
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
@@ -3517,29 +3305,9 @@ fn main() {{
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
let provider = Arc::new(FakeLanguageModelProvider);
let model = provider.test_model();
let model = FakeLanguageModel::default();
let model: Arc<dyn LanguageModel> = Arc::new(model);
cx.update(|_, cx| {
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(
Some(ConfiguredModel {
provider: provider.clone(),
model: model.clone(),
}),
cx,
);
registry.set_thread_summary_model(
Some(ConfiguredModel {
provider,
model: model.clone(),
}),
cx,
);
})
});
(workspace, thread_store, thread, context_store, model)
}

View File

@@ -2,11 +2,12 @@ use std::fmt::Display;
use std::ops::Range;
use std::sync::Arc;
use assistant_context_editor::SavedContextMetadata;
use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
use editor::{Editor, EditorEvent};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, ClickEvent, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
App, Empty, Entity, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
UniformListScrollHandle, WeakEntity, Window, uniform_list,
};
use time::{OffsetDateTime, UtcOffset};
@@ -17,6 +18,7 @@ use ui::{
use util::ResultExt;
use crate::history_store::{HistoryEntry, HistoryStore};
use crate::thread_store::SerializedThreadMetadata;
use crate::{AgentPanel, RemoveSelectedThread};
pub struct ThreadHistory {
@@ -24,14 +26,11 @@ pub struct ThreadHistory {
history_store: Entity<HistoryStore>,
scroll_handle: UniformListScrollHandle,
selected_index: usize,
hovered_index: Option<usize>,
search_editor: Entity<Editor>,
all_entries: Arc<Vec<HistoryEntry>>,
// When the search is empty, we display date separators between history entries
// This vector contains an enum of either a separator or an actual entry
separated_items: Vec<ListItemType>,
// Maps entry indexes to list item indexes
separated_item_indexes: Vec<u32>,
separated_items: Vec<HistoryListItem>,
_separated_items_task: Option<Task<()>>,
search_state: SearchState,
scrollbar_visibility: bool,
@@ -51,7 +50,7 @@ enum SearchState {
},
}
enum ListItemType {
enum HistoryListItem {
BucketSeparator(TimeBucket),
Entry {
index: usize,
@@ -59,11 +58,11 @@ enum ListItemType {
},
}
impl ListItemType {
impl HistoryListItem {
fn entry_index(&self) -> Option<usize> {
match self {
ListItemType::BucketSeparator(_) => None,
ListItemType::Entry { index, .. } => Some(*index),
HistoryListItem::BucketSeparator(_) => None,
HistoryListItem::Entry { index, .. } => Some(*index),
}
}
}
@@ -101,11 +100,9 @@ impl ThreadHistory {
history_store,
scroll_handle,
selected_index: 0,
hovered_index: None,
search_state: SearchState::Empty,
all_entries: Default::default(),
separated_items: Default::default(),
separated_item_indexes: Default::default(),
search_editor,
scrollbar_visibility: true,
scrollbar_state,
@@ -117,21 +114,35 @@ impl ThreadHistory {
}
fn update_all_entries(&mut self, cx: &mut Context<Self>) {
let new_entries: Arc<Vec<HistoryEntry>> = self
self.all_entries = self
.history_store
.update(cx, |store, cx| store.entries(cx))
.into();
self.set_selected_index(0, cx);
self.update_separated_items(cx);
match &self.search_state {
SearchState::Empty => {}
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
self.search(query.clone(), cx);
}
}
cx.notify();
}
fn update_separated_items(&mut self, cx: &mut Context<Self>) {
self._separated_items_task.take();
let mut items = Vec::with_capacity(new_entries.len() + 1);
let mut indexes = Vec::with_capacity(new_entries.len() + 1);
let mut separated_items = std::mem::take(&mut self.separated_items);
separated_items.clear();
let all_entries = self.all_entries.clone();
let bg_task = cx.background_spawn(async move {
let mut bucket = None;
let today = Local::now().naive_local().date();
for (index, entry) in new_entries.iter().enumerate() {
for (index, entry) in all_entries.iter().enumerate() {
let entry_date = entry
.updated_at()
.with_timezone(&Local)
@@ -141,50 +152,20 @@ impl ThreadHistory {
if Some(entry_bucket) != bucket {
bucket = Some(entry_bucket);
items.push(ListItemType::BucketSeparator(entry_bucket));
separated_items.push(HistoryListItem::BucketSeparator(entry_bucket));
}
indexes.push(items.len() as u32);
items.push(ListItemType::Entry {
separated_items.push(HistoryListItem::Entry {
index,
format: entry_bucket.into(),
});
}
(new_entries, items, indexes)
separated_items
});
let task = cx.spawn(async move |this, cx| {
let (new_entries, items, indexes) = bg_task.await;
let separated_items = bg_task.await;
this.update(cx, |this, cx| {
let previously_selected_entry =
this.all_entries.get(this.selected_index).map(|e| e.id());
this.all_entries = new_entries;
this.separated_items = items;
this.separated_item_indexes = indexes;
match &this.search_state {
SearchState::Empty => {
if this.selected_index >= this.all_entries.len() {
this.set_selected_entry_index(
this.all_entries.len().saturating_sub(1),
cx,
);
} else if let Some(prev_id) = previously_selected_entry {
if let Some(new_ix) = this
.all_entries
.iter()
.position(|probe| probe.id() == prev_id)
{
this.set_selected_entry_index(new_ix, cx);
}
}
}
SearchState::Searching { query, .. } | SearchState::Searched { query, .. } => {
this.search(query.clone(), cx);
}
}
this.separated_items = separated_items;
cx.notify();
})
.log_err();
@@ -252,7 +233,7 @@ impl ThreadHistory {
matches,
};
this.set_selected_entry_index(0, cx);
this.set_selected_index(0, cx);
cx.notify();
};
})
@@ -260,7 +241,10 @@ impl ThreadHistory {
}
});
self.search_state = SearchState::Searching { query, _task: task };
self.search_state = SearchState::Searching {
query: query.clone(),
_task: task,
};
cx.notify();
}
@@ -307,9 +291,9 @@ impl ThreadHistory {
let count = self.matched_count();
if count > 0 {
if self.selected_index == 0 {
self.set_selected_entry_index(count - 1, cx);
self.set_selected_index(count - 1, cx);
} else {
self.set_selected_entry_index(self.selected_index - 1, cx);
self.set_selected_index(self.selected_index - 1, cx);
}
}
}
@@ -323,9 +307,9 @@ impl ThreadHistory {
let count = self.matched_count();
if count > 0 {
if self.selected_index == count - 1 {
self.set_selected_entry_index(0, cx);
self.set_selected_index(0, cx);
} else {
self.set_selected_entry_index(self.selected_index + 1, cx);
self.set_selected_index(self.selected_index + 1, cx);
}
}
}
@@ -338,32 +322,21 @@ impl ThreadHistory {
) {
let count = self.matched_count();
if count > 0 {
self.set_selected_entry_index(0, cx);
self.set_selected_index(0, cx);
}
}
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
let count = self.matched_count();
if count > 0 {
self.set_selected_entry_index(count - 1, cx);
self.set_selected_index(count - 1, cx);
}
}
fn set_selected_entry_index(&mut self, entry_index: usize, cx: &mut Context<Self>) {
self.selected_index = entry_index;
let scroll_ix = match self.search_state {
SearchState::Empty | SearchState::Searching { .. } => self
.separated_item_indexes
.get(entry_index)
.map(|ix| *ix as usize)
.unwrap_or(entry_index + 1),
SearchState::Searched { .. } => entry_index,
};
fn set_selected_index(&mut self, index: usize, cx: &mut Context<Self>) {
self.selected_index = index;
self.scroll_handle
.scroll_to_item(scroll_ix, ScrollStrategy::Top);
.scroll_to_item(index, ScrollStrategy::Top);
cx.notify();
}
@@ -472,7 +445,7 @@ impl ThreadHistory {
.map(|(ix, m)| {
self.render_list_item(
Some(range_start + ix),
&ListItemType::Entry {
&HistoryListItem::Entry {
index: m.candidate_id,
format: EntryTimeFormat::DateAndTime,
},
@@ -490,36 +463,25 @@ impl ThreadHistory {
fn render_list_item(
&self,
list_entry_ix: Option<usize>,
item: &ListItemType,
item: &HistoryListItem,
highlight_positions: Vec<usize>,
cx: &Context<Self>,
cx: &App,
) -> AnyElement {
match item {
ListItemType::Entry { index, format } => match self.all_entries.get(*index) {
HistoryListItem::Entry { index, format } => match self.all_entries.get(*index) {
Some(entry) => h_flex()
.w_full()
.pb_1()
.child(
HistoryEntryElement::new(entry.clone(), self.agent_panel.clone())
.highlight_positions(highlight_positions)
.timestamp_format(*format)
.selected(list_entry_ix == Some(self.selected_index))
.hovered(list_entry_ix == self.hovered_index)
.on_hover(cx.listener(move |this, is_hovered, _window, cx| {
if *is_hovered {
this.hovered_index = list_entry_ix;
} else if this.hovered_index == list_entry_ix {
this.hovered_index = None;
}
cx.notify();
}))
.into_any_element(),
)
.child(self.render_history_entry(
entry,
list_entry_ix == Some(self.selected_index),
highlight_positions,
*format,
))
.into_any(),
None => Empty.into_any_element(),
},
ListItemType::BucketSeparator(bucket) => div()
HistoryListItem::BucketSeparator(bucket) => div()
.px(DynamicSpacing::Base06.rems(cx))
.pt_2()
.pb_1()
@@ -531,6 +493,33 @@ impl ThreadHistory {
.into_any_element(),
}
}
fn render_history_entry(
&self,
entry: &HistoryEntry,
is_active: bool,
highlight_positions: Vec<usize>,
format: EntryTimeFormat,
) -> AnyElement {
match entry {
HistoryEntry::Thread(thread) => PastThread::new(
thread.clone(),
self.agent_panel.clone(),
is_active,
highlight_positions,
format,
)
.into_any_element(),
HistoryEntry::Context(context) => PastContext::new(
context.clone(),
self.agent_panel.clone(),
is_active,
highlight_positions,
format,
)
.into_any_element(),
}
}
}
impl Focusable for ThreadHistory {
@@ -612,97 +601,155 @@ impl Render for ThreadHistory {
}
#[derive(IntoElement)]
pub struct HistoryEntryElement {
entry: HistoryEntry,
pub struct PastThread {
thread: SerializedThreadMetadata,
agent_panel: WeakEntity<AgentPanel>,
selected: bool,
hovered: bool,
highlight_positions: Vec<usize>,
timestamp_format: EntryTimeFormat,
on_hover: Box<dyn Fn(&bool, &mut Window, &mut App) + 'static>,
}
impl HistoryEntryElement {
pub fn new(entry: HistoryEntry, agent_panel: WeakEntity<AgentPanel>) -> Self {
impl PastThread {
pub fn new(
thread: SerializedThreadMetadata,
agent_panel: WeakEntity<AgentPanel>,
selected: bool,
highlight_positions: Vec<usize>,
timestamp_format: EntryTimeFormat,
) -> Self {
Self {
entry,
thread,
agent_panel,
selected: false,
hovered: false,
highlight_positions: vec![],
timestamp_format: EntryTimeFormat::DateAndTime,
on_hover: Box::new(|_, _, _| {}),
selected,
highlight_positions,
timestamp_format,
}
}
pub fn selected(mut self, selected: bool) -> Self {
self.selected = selected;
self
}
pub fn hovered(mut self, hovered: bool) -> Self {
self.hovered = hovered;
self
}
pub fn highlight_positions(mut self, positions: Vec<usize>) -> Self {
self.highlight_positions = positions;
self
}
pub fn on_hover(mut self, on_hover: impl Fn(&bool, &mut Window, &mut App) + 'static) -> Self {
self.on_hover = Box::new(on_hover);
self
}
pub fn timestamp_format(mut self, format: EntryTimeFormat) -> Self {
self.timestamp_format = format;
self
}
}
impl RenderOnce for HistoryEntryElement {
impl RenderOnce for PastThread {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let (id, summary, timestamp) = match &self.entry {
HistoryEntry::Thread(thread) => (
thread.id.to_string(),
thread.summary.clone(),
thread.updated_at.timestamp(),
),
HistoryEntry::Context(context) => (
context.path.to_string_lossy().to_string(),
context.title.clone().into(),
context.mtime.timestamp(),
),
};
let summary = self.thread.summary;
let thread_timestamp =
self.timestamp_format
.format_timestamp(&self.agent_panel, timestamp, cx);
let thread_timestamp = self.timestamp_format.format_timestamp(
&self.agent_panel,
self.thread.updated_at.timestamp(),
cx,
);
ListItem::new(SharedString::from(id))
ListItem::new(SharedString::from(self.thread.id.to_string()))
.rounded()
.toggle_state(self.selected)
.spacing(ListItemSpacing::Sparse)
.start_slot(
div().max_w_4_5().child(
HighlightedLabel::new(summary, self.highlight_positions)
.size(LabelSize::Small)
.truncate(),
),
)
.end_slot(
h_flex()
.w_full()
.gap_2()
.justify_between()
.child(
HighlightedLabel::new(summary, self.highlight_positions)
.size(LabelSize::Small)
.truncate(),
)
.gap_1p5()
.child(
Label::new(thread_timestamp)
.color(Color::Muted)
.size(LabelSize::XSmall),
)
.child(
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.tooltip(move |window, cx| {
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
})
.on_click({
let agent_panel = self.agent_panel.clone();
let id = self.thread.id.clone();
move |_event, _window, cx| {
agent_panel
.update(cx, |this, cx| {
this.delete_thread(&id, cx).detach_and_log_err(cx);
})
.ok();
}
}),
),
)
.on_hover(self.on_hover)
.end_slot::<IconButton>(if self.hovered || self.selected {
Some(
.on_click({
let agent_panel = self.agent_panel.clone();
let id = self.thread.id.clone();
move |_event, window, cx| {
agent_panel
.update(cx, |this, cx| {
this.open_thread_by_id(&id, window, cx)
.detach_and_log_err(cx);
})
.ok();
}
})
}
}
#[derive(IntoElement)]
pub struct PastContext {
context: SavedContextMetadata,
agent_panel: WeakEntity<AgentPanel>,
selected: bool,
highlight_positions: Vec<usize>,
timestamp_format: EntryTimeFormat,
}
impl PastContext {
pub fn new(
context: SavedContextMetadata,
agent_panel: WeakEntity<AgentPanel>,
selected: bool,
highlight_positions: Vec<usize>,
timestamp_format: EntryTimeFormat,
) -> Self {
Self {
context,
agent_panel,
selected,
highlight_positions,
timestamp_format,
}
}
}
impl RenderOnce for PastContext {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let summary = self.context.title;
let context_timestamp = self.timestamp_format.format_timestamp(
&self.agent_panel,
self.context.mtime.timestamp(),
cx,
);
ListItem::new(SharedString::from(
self.context.path.to_string_lossy().to_string(),
))
.rounded()
.toggle_state(self.selected)
.spacing(ListItemSpacing::Sparse)
.start_slot(
div().max_w_4_5().child(
HighlightedLabel::new(summary, self.highlight_positions)
.size(LabelSize::Small)
.truncate(),
),
)
.end_slot(
h_flex()
.gap_1p5()
.child(
Label::new(context_timestamp)
.color(Color::Muted)
.size(LabelSize::XSmall),
)
.child(
IconButton::new("delete", IconName::TrashAlt)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
@@ -712,70 +759,30 @@ impl RenderOnce for HistoryEntryElement {
})
.on_click({
let agent_panel = self.agent_panel.clone();
let f: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static> =
match &self.entry {
HistoryEntry::Thread(thread) => {
let id = thread.id.clone();
Box::new(move |_event, _window, cx| {
agent_panel
.update(cx, |this, cx| {
this.delete_thread(&id, cx)
.detach_and_log_err(cx);
})
.ok();
})
}
HistoryEntry::Context(context) => {
let path = context.path.clone();
Box::new(move |_event, _window, cx| {
agent_panel
.update(cx, |this, cx| {
this.delete_context(path.clone(), cx)
.detach_and_log_err(cx);
})
.ok();
})
}
};
f
let path = self.context.path.clone();
move |_event, _window, cx| {
agent_panel
.update(cx, |this, cx| {
this.delete_context(path.clone(), cx)
.detach_and_log_err(cx);
})
.ok();
}
}),
)
} else {
None
})
.on_click({
let agent_panel = self.agent_panel.clone();
let f: Box<dyn Fn(&ClickEvent, &mut Window, &mut App) + 'static> = match &self.entry
{
HistoryEntry::Thread(thread) => {
let id = thread.id.clone();
Box::new(move |_event, window, cx| {
agent_panel
.update(cx, |this, cx| {
this.open_thread_by_id(&id, window, cx)
.detach_and_log_err(cx);
})
.ok();
})
}
HistoryEntry::Context(context) => {
let path = context.path.clone();
Box::new(move |_event, window, cx| {
agent_panel
.update(cx, |this, cx| {
this.open_saved_prompt_editor(path.clone(), window, cx)
.detach_and_log_err(cx);
})
.ok();
})
}
};
f
})
),
)
.on_click({
let agent_panel = self.agent_panel.clone();
let path = self.context.path.clone();
move |_event, window, cx| {
agent_panel
.update(cx, |this, cx| {
this.open_saved_prompt_editor(path.clone(), window, cx)
.detach_and_log_err(cx);
})
.ok();
}
})
}
}

View File

@@ -19,7 +19,7 @@ use gpui::{
};
use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage};
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
@@ -386,25 +386,6 @@ impl ThreadStore {
})
}
pub fn create_thread_from_serialized(
&mut self,
serialized: SerializedThread,
cx: &mut Context<Self>,
) -> Entity<Thread> {
cx.new(|cx| {
Thread::deserialize(
ThreadId::new(),
serialized,
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
None,
cx,
)
})
}
pub fn open_thread(
&self,
id: &ThreadId,
@@ -430,7 +411,7 @@ impl ThreadStore {
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),
Some(window),
window,
cx,
)
})
@@ -505,8 +486,8 @@ impl ThreadStore {
ToolSource::Native,
&profile
.tools
.into_iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool))
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
);
@@ -530,32 +511,32 @@ impl ThreadStore {
});
}
// Enable all the tools from all context servers, but disable the ones that are explicitly disabled
for (context_server_id, preset) in profile.context_servers {
for (context_server_id, preset) in &profile.context_servers {
self.tools.update(cx, |tools, cx| {
tools.disable(
ToolSource::ContextServer {
id: context_server_id.into(),
id: context_server_id.clone().into(),
},
&preset
.tools
.into_iter()
.filter_map(|(tool, enabled)| (!enabled).then(|| tool))
.iter()
.filter_map(|(tool, enabled)| (!enabled).then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
})
}
} else {
for (context_server_id, preset) in profile.context_servers {
for (context_server_id, preset) in &profile.context_servers {
self.tools.update(cx, |tools, cx| {
tools.enable(
ToolSource::ContextServer {
id: context_server_id.into(),
id: context_server_id.clone().into(),
},
&preset
.tools
.into_iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool))
.iter()
.filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
.collect::<Vec<_>>(),
cx,
)
@@ -676,6 +657,8 @@ pub struct SerializedThread {
pub model: Option<SerializedLanguageModel>,
#[serde(default)]
pub completion_mode: Option<CompletionMode>,
#[serde(default)]
pub profile: Option<AgentProfileId>,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -794,7 +777,7 @@ pub struct SerializedToolUse {
pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
pub content: LanguageModelToolResultContent,
pub content: Arc<str>,
pub output: Option<serde_json::Value>,
}
@@ -821,6 +804,7 @@ impl LegacySerializedThread {
exceeded_window_error: None,
model: None,
completion_mode: None,
profile: None,
}
}
}

View File

@@ -1,16 +1,14 @@
use std::sync::Arc;
use anyhow::Result;
use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use assistant_tool::{AnyToolCard, Tool, ToolResultOutput, ToolUseStatus, ToolWorkingSet};
use collections::HashMap;
use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
LanguageModelToolUse, LanguageModelToolUseId, Role,
};
use project::Project;
use ui::{IconName, Window};
@@ -54,19 +52,15 @@ impl ToolUseState {
/// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
///
/// Accepts a function to filter the tools that should be used to populate the state.
///
/// If `window` is `None` (e.g., when in headless mode or when running evals),
/// tool cards won't be deserialized
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
project: Entity<Project>,
window: Option<&mut Window>, // None in headless mode
window: &mut Window,
cx: &mut App,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
let mut window = window;
for message in messages {
match message.role {
@@ -111,17 +105,12 @@ impl ToolUseState {
},
);
if let Some(window) = &mut window {
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) = tool.deserialize_card(
output,
project.clone(),
window,
cx,
) {
this.tool_result_cards.insert(tool_use_id, card);
}
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) =
tool.deserialize_card(output, project.clone(), window, cx)
{
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
@@ -176,16 +165,10 @@ impl ToolUseState {
let status = (|| {
if let Some(tool_result) = tool_result {
let content = tool_result
.content
.to_str()
.map(|str| str.to_owned().into())
.unwrap_or_default();
return if tool_result.is_error {
ToolUseStatus::Error(content)
ToolUseStatus::Error(tool_result.content.clone().into())
} else {
ToolUseStatus::Finished(content)
ToolUseStatus::Finished(tool_result.content.clone().into())
};
}
@@ -416,45 +399,21 @@ impl ToolUseState {
let tool_result = output.content;
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
// Protect from overly large output
// Protect from clearly large output
let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX);
let content = match tool_result {
ToolResultContent::Text(text) => {
let text = if text.len() < tool_output_limit {
text
} else {
let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
format!(
"Tool result too long. The first {} bytes:\n\n{}",
truncated.len(),
truncated
)
};
LanguageModelToolResultContent::Text(text.into())
}
ToolResultContent::Image(language_model_image) => {
if language_model_image.estimate_tokens() < tool_output_limit {
LanguageModelToolResultContent::Image(language_model_image)
} else {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: "Tool responded with an image that would exceeded the remaining tokens".into(),
is_error: true,
output: None,
},
);
let tool_result = if tool_result.len() <= tool_output_limit {
tool_result
} else {
let truncated = truncate_lines_to_byte_limit(&tool_result, tool_output_limit);
return old_use;
}
}
format!(
"Tool result too long. The first {} bytes:\n\n{}",
truncated.len(),
truncated
)
};
self.tool_results.insert(
@@ -462,13 +421,12 @@ impl ToolUseState {
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content,
content: tool_result.into(),
is_error: false,
output: output.output,
},
);
old_use
self.pending_tool_uses_by_id.remove(&tool_use_id)
}
Err(err) => {
self.tool_results.insert(
@@ -476,7 +434,7 @@ impl ToolUseState {
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: LanguageModelToolResultContent::Text(err.to_string().into()),
content: err.to_string().into(),
is_error: true,
output: None,
},

View File

@@ -1,8 +1,8 @@
use std::sync::OnceLock;
use collections::HashMap;
use component::ComponentId;
use gpui::{App, Entity, WeakEntity};
use linkme::distributed_slice;
use std::sync::OnceLock;
use ui::{AnyElement, Component, ComponentScope, Window};
use workspace::Workspace;
@@ -12,15 +12,9 @@ use crate::ActiveThread;
pub type PreviewFn =
fn(WeakEntity<Workspace>, Entity<ActiveThread>, &mut Window, &mut App) -> Option<AnyElement>;
pub struct AgentPreviewFn(fn() -> (ComponentId, PreviewFn));
impl AgentPreviewFn {
pub const fn new(f: fn() -> (ComponentId, PreviewFn)) -> Self {
Self(f)
}
}
inventory::collect!(AgentPreviewFn);
/// Distributed slice for preview registration functions
#[distributed_slice]
pub static __ALL_AGENT_PREVIEWS: [fn() -> (ComponentId, PreviewFn)] = [..];
/// Trait that must be implemented by components that provide agent previews.
pub trait AgentPreview: Component + Sized {
@@ -42,14 +36,16 @@ pub trait AgentPreview: Component + Sized {
#[macro_export]
macro_rules! register_agent_preview {
($type:ty) => {
inventory::submit! {
$crate::ui::preview::AgentPreviewFn::new(|| {
(
<$type as component::Component>::id(),
<$type as $crate::ui::preview::AgentPreview>::agent_preview,
)
})
}
#[linkme::distributed_slice($crate::ui::preview::__ALL_AGENT_PREVIEWS)]
static __REGISTER_AGENT_PREVIEW: fn() -> (
component::ComponentId,
$crate::ui::preview::PreviewFn,
) = || {
(
<$type as component::Component>::id(),
<$type as $crate::ui::preview::AgentPreview>::agent_preview,
)
};
};
}
@@ -60,8 +56,8 @@ static AGENT_PREVIEW_REGISTRY: OnceLock<HashMap<ComponentId, PreviewFn>> = OnceL
fn get_or_init_registry() -> &'static HashMap<ComponentId, PreviewFn> {
AGENT_PREVIEW_REGISTRY.get_or_init(|| {
let mut map = HashMap::default();
for register_fn in inventory::iter::<AgentPreviewFn>() {
let (id, preview_fn) = (register_fn.0)();
for register_fn in __ALL_AGENT_PREVIEWS.iter() {
let (id, preview_fn) = register_fn();
map.insert(id, preview_fn);
}
map

View File

@@ -39,7 +39,7 @@ impl RenderOnce for UsageCallout {
let (title, message, button_text, url) = if is_limit_reached {
match self.plan {
Plan::ZedFree => (
Plan::Free => (
"Out of free prompts",
"Upgrade to continue, wait for the next reset, or switch to API key."
.to_string(),
@@ -61,7 +61,7 @@ impl RenderOnce for UsageCallout {
}
} else {
match self.plan {
Plan::ZedFree => (
Plan::Free => (
"Reaching free plan limit soon",
format!(
"{remaining} remaining - Upgrade to increase limit, or switch providers",
@@ -120,7 +120,7 @@ impl Component for UsageCallout {
single_example(
"Approaching limit (90%)",
UsageCallout::new(
Plan::ZedFree,
Plan::Free,
RequestUsage {
limit: UsageLimit::Limited(50),
amount: 45, // 90% of limit
@@ -131,7 +131,7 @@ impl Component for UsageCallout {
single_example(
"Limit reached (100%)",
UsageCallout::new(
Plan::ZedFree,
Plan::Free,
RequestUsage {
limit: UsageLimit::Limited(50),
amount: 50, // 100% of limit

View File

@@ -534,26 +534,12 @@ pub enum RequestContent {
ToolResult {
tool_use_id: String,
is_error: bool,
content: ToolResultContent,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<CacheControl>,
},
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolResultContent {
Plain(String),
Multipart(Vec<ToolResultPart>),
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolResultPart {
Text { text: String },
Image { source: ImageSource },
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {

View File

@@ -163,26 +163,20 @@ impl AskPassSession {
#[cfg(unix)]
fn get_shell_safe_zed_path() -> anyhow::Result<String> {
let zed_path = std::env::current_exe()
.context("Failed to determine current executable path for use in askpass")?
.context("Failed to figure out current executable path for use in askpass")?
.to_string_lossy()
// see https://github.com/rust-lang/rust/issues/69343
.trim_end_matches(" (deleted)")
.to_string();
// NOTE: this was previously enabled, however, it caused errors when it shouldn't have
// (see https://github.com/zed-industries/zed/issues/29819)
// The zed path failing to execute within the askpass script results in very vague ssh
// authentication failed errors, so this was done to try and surface a better error
//
// use std::os::unix::fs::MetadataExt;
// let metadata = std::fs::metadata(&zed_path)
// .context("Failed to check metadata of Zed executable path for use in askpass")?;
// let is_executable = metadata.is_file() && metadata.mode() & 0o111 != 0;
// anyhow::ensure!(
// is_executable,
// "Failed to verify Zed executable path for use in askpass"
// );
// sanity check on unix systems that the path exists and is executable
// todo(windows): implement this check for windows (or just use `is-executable` crate)
use std::os::unix::fs::MetadataExt;
let metadata = std::fs::metadata(&zed_path)
.context("Failed to check metadata of Zed executable path for use in askpass")?;
let is_executable = metadata.is_file() && metadata.mode() & 0o111 != 0;
anyhow::ensure!(
is_executable,
"Failed to verify Zed executable path for use in askpass"
);
// As of writing, this can only be fail if the path contains a null byte, which shouldn't be possible
// but shlex has annotated the error as #[non_exhaustive] so we can't make it a compile error if other
// errors are introduced in the future :(

View File

@@ -46,6 +46,7 @@ serde_json.workspace = true
settings.workspace = true
smallvec.workspace = true
smol.workspace = true
strum.workspace = true
telemetry_events.workspace = true
text.workspace = true
theme.workspace = true

View File

@@ -2,33 +2,22 @@ mod context;
mod context_editor;
mod context_history;
mod context_store;
mod patch;
mod slash_command;
mod slash_command_picker;
use std::sync::Arc;
use client::Client;
use gpui::{App, Context};
use workspace::Workspace;
use gpui::App;
pub use crate::context::*;
pub use crate::context_editor::*;
pub use crate::context_history::*;
pub use crate::context_store::*;
pub use crate::patch::*;
pub use crate::slash_command::*;
pub fn init(client: Arc<Client>, cx: &mut App) {
pub fn init(client: Arc<Client>, _cx: &mut App) {
context_store::init(&client.into());
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
cx.observe_new(
|workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
workspace
.register_action(ContextEditor::quote_selection)
.register_action(ContextEditor::insert_selection)
.register_action(ContextEditor::copy_code)
.register_action(ContextEditor::handle_insert_dragged_files);
},
)
.detach();
}

View File

@@ -1,7 +1,8 @@
#[cfg(test)]
mod context_tests;
use anyhow::{Context as _, Result, anyhow, bail};
use crate::patch::{AssistantEdit, AssistantPatch, AssistantPatchStatus};
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_slash_command::{
SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection,
@@ -21,8 +22,8 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{
LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolUseId, MessageContent, PaymentRequiredError, Role, StopReason,
report_assistant_event,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason, report_assistant_event,
};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
@@ -36,6 +37,7 @@ use std::{
iter, mem,
ops::Range,
path::Path,
str::FromStr as _,
sync::Arc,
time::{Duration, Instant},
};
@@ -120,6 +122,14 @@ impl MessageStatus {
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestType {
/// Request a normal chat response from the model.
Chat,
/// Add a preamble to the message, which tells the model to return a structured response that suggests edits.
SuggestEdits,
}
#[derive(Clone, Debug)]
pub enum ContextOperation {
InsertMessage {
@@ -133,7 +143,7 @@ pub enum ContextOperation {
version: clock::Global,
},
UpdateSummary {
summary: ContextSummaryContent,
summary: ContextSummary,
version: clock::Global,
},
SlashCommandStarted {
@@ -203,7 +213,7 @@ impl ContextOperation {
version: language::proto::deserialize_version(&update.version),
}),
proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
summary: ContextSummaryContent {
summary: ContextSummary {
text: update.summary,
done: update.done,
timestamp: language::proto::deserialize_timestamp(
@@ -447,12 +457,17 @@ impl ContextOperation {
pub enum ContextEvent {
ShowAssistError(SharedString),
ShowPaymentRequiredError,
ShowMaxMonthlySpendReachedError,
MessagesEdited,
SummaryChanged,
SummaryGenerated,
StreamedCompletion,
StartedThoughtProcess(Range<language::Anchor>),
EndedThoughtProcess(language::Anchor),
PatchesUpdated {
removed: Vec<Range<language::Anchor>>,
updated: Vec<Range<language::Anchor>>,
},
InvokedSlashCommandChanged {
command_id: InvokedSlashCommandId,
},
@@ -466,73 +481,11 @@ pub enum ContextEvent {
Operation(ContextOperation),
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ContextSummary {
Pending,
Content(ContextSummaryContent),
Error,
}
#[derive(Default, Clone, Debug, Eq, PartialEq)]
pub struct ContextSummaryContent {
#[derive(Clone, Default, Debug)]
pub struct ContextSummary {
pub text: String,
pub done: bool,
pub timestamp: clock::Lamport,
}
impl ContextSummary {
pub const DEFAULT: &str = "New Text Thread";
pub fn or_default(&self) -> SharedString {
self.unwrap_or(Self::DEFAULT)
}
pub fn unwrap_or(&self, message: impl Into<SharedString>) -> SharedString {
self.content()
.map_or_else(|| message.into(), |content| content.text.clone().into())
}
pub fn content(&self) -> Option<&ContextSummaryContent> {
match self {
ContextSummary::Content(content) => Some(content),
ContextSummary::Pending | ContextSummary::Error => None,
}
}
fn content_as_mut(&mut self) -> Option<&mut ContextSummaryContent> {
match self {
ContextSummary::Content(content) => Some(content),
ContextSummary::Pending | ContextSummary::Error => None,
}
}
fn content_or_set_empty(&mut self) -> &mut ContextSummaryContent {
match self {
ContextSummary::Content(content) => content,
ContextSummary::Pending | ContextSummary::Error => {
let content = ContextSummaryContent::default();
*self = ContextSummary::Content(content);
self.content_as_mut().unwrap()
}
}
}
pub fn is_pending(&self) -> bool {
matches!(self, ContextSummary::Pending)
}
fn timestamp(&self) -> Option<clock::Lamport> {
match self {
ContextSummary::Content(content) => Some(content.timestamp),
ContextSummary::Pending | ContextSummary::Error => None,
}
}
}
impl PartialOrd for ContextSummary {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.timestamp().partial_cmp(&other.timestamp())
}
timestamp: clock::Lamport,
}
#[derive(Clone, Debug, Eq, PartialEq)]
@@ -652,6 +605,26 @@ struct PendingCompletion {
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct InvokedSlashCommandId(clock::Lamport);
#[derive(Clone, Debug)]
pub struct XmlTag {
pub kind: XmlTagKind,
pub range: Range<text::Anchor>,
pub is_open_tag: bool,
}
#[derive(Copy, Clone, Debug, strum::EnumString, PartialEq, Eq, strum::AsRefStr)]
#[strum(serialize_all = "snake_case")]
pub enum XmlTagKind {
Patch,
Title,
Edit,
Path,
Description,
OldText,
NewText,
Operation,
}
pub struct AssistantContext {
id: ContextId,
timestamp: clock::Lamport,
@@ -668,7 +641,7 @@ pub struct AssistantContext {
message_anchors: Vec<MessageAnchor>,
contents: Vec<Content>,
messages_metadata: HashMap<MessageId, MessageMetadata>,
summary: ContextSummary,
summary: Option<ContextSummary>,
summary_task: Task<Option<()>>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
@@ -680,6 +653,8 @@ pub struct AssistantContext {
_subscriptions: Vec<Subscription>,
telemetry: Option<Arc<Telemetry>>,
language_registry: Arc<LanguageRegistry>,
patches: Vec<AssistantPatch>,
xml_tags: Vec<XmlTag>,
project: Option<Entity<Project>>,
prompt_builder: Arc<PromptBuilder>,
}
@@ -694,6 +669,18 @@ impl ContextAnnotation for ParsedSlashCommand {
}
}
impl ContextAnnotation for AssistantPatch {
fn range(&self) -> &Range<language::Anchor> {
&self.range
}
}
impl ContextAnnotation for XmlTag {
fn range(&self) -> &Range<language::Anchor> {
&self.range
}
}
impl EventEmitter<ContextEvent> for AssistantContext {}
impl AssistantContext {
@@ -755,7 +742,7 @@ impl AssistantContext {
slash_command_output_sections: Vec::new(),
thought_process_output_sections: Vec::new(),
edits_since_last_parse: edits_since_last_slash_command_parse,
summary: ContextSummary::Pending,
summary: None,
summary_task: Task::ready(None),
completion_count: Default::default(),
pending_completions: Default::default(),
@@ -770,6 +757,8 @@ impl AssistantContext {
project,
language_registry,
slash_commands,
patches: Vec::new(),
xml_tags: Vec::new(),
prompt_builder,
};
@@ -814,7 +803,7 @@ impl AssistantContext {
.collect(),
summary: self
.summary
.content()
.as_ref()
.map(|summary| summary.text.clone())
.unwrap_or_default(),
slash_command_output_sections: self
@@ -1000,10 +989,12 @@ impl AssistantContext {
summary: new_summary,
..
} => {
if self.summary.timestamp().map_or(true, |current_timestamp| {
new_summary.timestamp > current_timestamp
}) {
self.summary = ContextSummary::Content(new_summary);
if self
.summary
.as_ref()
.map_or(true, |summary| new_summary.timestamp > summary.timestamp)
{
self.summary = Some(new_summary);
summary_generated = true;
}
}
@@ -1161,8 +1152,50 @@ impl AssistantContext {
self.path.as_ref()
}
pub fn summary(&self) -> &ContextSummary {
&self.summary
pub fn summary(&self) -> Option<&ContextSummary> {
self.summary.as_ref()
}
pub fn patch_containing(&self, position: Point, cx: &App) -> Option<&AssistantPatch> {
let buffer = self.buffer.read(cx);
let index = self.patches.binary_search_by(|patch| {
let patch_range = patch.range.to_point(&buffer);
if position < patch_range.start {
Ordering::Greater
} else if position > patch_range.end {
Ordering::Less
} else {
Ordering::Equal
}
});
if let Ok(ix) = index {
Some(&self.patches[ix])
} else {
None
}
}
pub fn patch_ranges(&self) -> impl Iterator<Item = Range<language::Anchor>> + '_ {
self.patches.iter().map(|patch| patch.range.clone())
}
pub fn patch_for_range(
&self,
range: &Range<language::Anchor>,
cx: &App,
) -> Option<&AssistantPatch> {
let buffer = self.buffer.read(cx);
let index = self.patch_index_for_range(range, buffer).ok()?;
Some(&self.patches[index])
}
fn patch_index_for_range(
&self,
tagged_range: &Range<text::Anchor>,
buffer: &text::BufferSnapshot,
) -> Result<usize, usize> {
self.patches
.binary_search_by(|probe| probe.range.cmp(&tagged_range, buffer))
}
pub fn parsed_slash_commands(&self) -> &[ParsedSlashCommand] {
@@ -1244,7 +1277,7 @@ impl AssistantContext {
let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
return;
};
let request = self.to_completion_request(Some(&model.model), cx);
let request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
let debounce = self.token_count.is_some();
self.pending_token_count = cx.spawn(async move |this, cx| {
async move {
@@ -1390,7 +1423,7 @@ impl AssistantContext {
}
let request = {
let mut req = self.to_completion_request(Some(&model), cx);
let mut req = self.to_completion_request(Some(&model), RequestType::Chat, cx);
// Skip the last message because it's likely to change and
// therefore would be a waste to cache.
req.messages.pop();
@@ -1465,6 +1498,8 @@ impl AssistantContext {
let mut removed_parsed_slash_command_ranges = Vec::new();
let mut updated_parsed_slash_commands = Vec::new();
let mut removed_patches = Vec::new();
let mut updated_patches = Vec::new();
while let Some(mut row_range) = row_ranges.next() {
while let Some(next_row_range) = row_ranges.peek() {
if row_range.end >= next_row_range.start {
@@ -1489,6 +1524,13 @@ impl AssistantContext {
cx,
);
self.invalidate_pending_slash_commands(&buffer, cx);
self.reparse_patches_in_range(
start..end,
&buffer,
&mut updated_patches,
&mut removed_patches,
cx,
);
}
if !updated_parsed_slash_commands.is_empty()
@@ -1499,6 +1541,13 @@ impl AssistantContext {
updated: updated_parsed_slash_commands,
});
}
if !updated_patches.is_empty() || !removed_patches.is_empty() {
cx.emit(ContextEvent::PatchesUpdated {
removed: removed_patches,
updated: updated_patches,
});
}
}
fn reparse_slash_commands_in_range(
@@ -1589,6 +1638,267 @@ impl AssistantContext {
}
}
fn reparse_patches_in_range(
&mut self,
range: Range<text::Anchor>,
buffer: &BufferSnapshot,
updated: &mut Vec<Range<text::Anchor>>,
removed: &mut Vec<Range<text::Anchor>>,
cx: &mut Context<Self>,
) {
// Rebuild the XML tags in the edited range.
let intersecting_tags_range =
self.indices_intersecting_buffer_range(&self.xml_tags, range.clone(), cx);
let new_tags = self.parse_xml_tags_in_range(buffer, range.clone(), cx);
self.xml_tags
.splice(intersecting_tags_range.clone(), new_tags);
// Find which patches intersect the changed range.
let intersecting_patches_range =
self.indices_intersecting_buffer_range(&self.patches, range.clone(), cx);
// Reparse all tags after the last unchanged patch before the change.
let mut tags_start_ix = 0;
if let Some(preceding_unchanged_patch) =
self.patches[..intersecting_patches_range.start].last()
{
tags_start_ix = match self.xml_tags.binary_search_by(|tag| {
tag.range
.start
.cmp(&preceding_unchanged_patch.range.end, buffer)
.then(Ordering::Less)
}) {
Ok(ix) | Err(ix) => ix,
};
}
// Rebuild the patches in the range.
let new_patches = self.parse_patches(tags_start_ix, range.end, buffer, cx);
updated.extend(new_patches.iter().map(|patch| patch.range.clone()));
let removed_patches = self.patches.splice(intersecting_patches_range, new_patches);
removed.extend(
removed_patches
.map(|patch| patch.range)
.filter(|range| !updated.contains(&range)),
);
}
fn parse_xml_tags_in_range(
&self,
buffer: &BufferSnapshot,
range: Range<text::Anchor>,
cx: &App,
) -> Vec<XmlTag> {
let mut messages = self.messages(cx).peekable();
let mut tags = Vec::new();
let mut lines = buffer.text_for_range(range).lines();
let mut offset = lines.offset();
while let Some(line) = lines.next() {
while let Some(message) = messages.peek() {
if offset < message.offset_range.end {
break;
} else {
messages.next();
}
}
let is_assistant_message = messages
.peek()
.map_or(false, |message| message.role == Role::Assistant);
if is_assistant_message {
for (start_ix, _) in line.match_indices('<') {
let mut name_start_ix = start_ix + 1;
let closing_bracket_ix = line[start_ix..].find('>').map(|i| start_ix + i);
if let Some(closing_bracket_ix) = closing_bracket_ix {
let end_ix = closing_bracket_ix + 1;
let mut is_open_tag = true;
if line[name_start_ix..closing_bracket_ix].starts_with('/') {
name_start_ix += 1;
is_open_tag = false;
}
let tag_inner = &line[name_start_ix..closing_bracket_ix];
let tag_name_len = tag_inner
.find(|c: char| c.is_whitespace())
.unwrap_or(tag_inner.len());
if let Ok(kind) = XmlTagKind::from_str(&tag_inner[..tag_name_len]) {
tags.push(XmlTag {
range: buffer.anchor_after(offset + start_ix)
..buffer.anchor_before(offset + end_ix),
is_open_tag,
kind,
});
};
}
}
}
offset = lines.offset();
}
tags
}
fn parse_patches(
&mut self,
tags_start_ix: usize,
buffer_end: text::Anchor,
buffer: &BufferSnapshot,
cx: &App,
) -> Vec<AssistantPatch> {
let mut new_patches = Vec::new();
let mut pending_patch = None;
let mut patch_tag_depth = 0;
let mut tags = self.xml_tags[tags_start_ix..].iter().peekable();
'tags: while let Some(tag) = tags.next() {
if tag.range.start.cmp(&buffer_end, buffer).is_gt() && patch_tag_depth == 0 {
break;
}
if tag.kind == XmlTagKind::Patch && tag.is_open_tag {
patch_tag_depth += 1;
let patch_start = tag.range.start;
let mut edits = Vec::<Result<AssistantEdit>>::new();
let mut patch = AssistantPatch {
range: patch_start..patch_start,
title: String::new().into(),
edits: Default::default(),
status: crate::AssistantPatchStatus::Pending,
};
while let Some(tag) = tags.next() {
if tag.kind == XmlTagKind::Patch && !tag.is_open_tag {
patch_tag_depth -= 1;
if patch_tag_depth == 0 {
patch.range.end = tag.range.end;
// Include the line immediately after this <patch> tag if it's empty.
let patch_end_offset = patch.range.end.to_offset(buffer);
let mut patch_end_chars = buffer.chars_at(patch_end_offset);
if patch_end_chars.next() == Some('\n')
&& patch_end_chars.next().map_or(true, |ch| ch == '\n')
{
let messages = self.messages_for_offsets(
[patch_end_offset, patch_end_offset + 1],
cx,
);
if messages.len() == 1 {
patch.range.end = buffer.anchor_before(patch_end_offset + 1);
}
}
edits.sort_unstable_by(|a, b| {
if let (Ok(a), Ok(b)) = (a, b) {
a.path.cmp(&b.path)
} else {
Ordering::Equal
}
});
patch.edits = edits.into();
patch.status = AssistantPatchStatus::Ready;
new_patches.push(patch);
continue 'tags;
}
}
if tag.kind == XmlTagKind::Title && tag.is_open_tag {
let content_start = tag.range.end;
while let Some(tag) = tags.next() {
if tag.kind == XmlTagKind::Title && !tag.is_open_tag {
let content_end = tag.range.start;
patch.title =
trimmed_text_in_range(buffer, content_start..content_end)
.into();
break;
}
}
}
if tag.kind == XmlTagKind::Edit && tag.is_open_tag {
let mut path = None;
let mut old_text = None;
let mut new_text = None;
let mut operation = None;
let mut description = None;
while let Some(tag) = tags.next() {
if tag.kind == XmlTagKind::Edit && !tag.is_open_tag {
edits.push(AssistantEdit::new(
path,
operation,
old_text,
new_text,
description,
));
break;
}
if tag.is_open_tag
&& [
XmlTagKind::Path,
XmlTagKind::OldText,
XmlTagKind::NewText,
XmlTagKind::Operation,
XmlTagKind::Description,
]
.contains(&tag.kind)
{
let kind = tag.kind;
let content_start = tag.range.end;
if let Some(tag) = tags.peek() {
if tag.kind == kind && !tag.is_open_tag {
let tag = tags.next().unwrap();
let content_end = tag.range.start;
let content = trimmed_text_in_range(
buffer,
content_start..content_end,
);
match kind {
XmlTagKind::Path => path = Some(content),
XmlTagKind::Operation => operation = Some(content),
XmlTagKind::OldText => {
old_text = Some(content).filter(|s| !s.is_empty())
}
XmlTagKind::NewText => {
new_text = Some(content).filter(|s| !s.is_empty())
}
XmlTagKind::Description => {
description =
Some(content).filter(|s| !s.is_empty())
}
_ => {}
}
}
}
}
}
}
}
patch.edits = edits.into();
pending_patch = Some(patch);
}
}
if let Some(mut pending_patch) = pending_patch {
let patch_start = pending_patch.range.start.to_offset(buffer);
if let Some(message) = self.message_for_offset(patch_start, cx) {
if message.anchor_range.end == text::Anchor::MAX {
pending_patch.range.end = text::Anchor::MAX;
} else {
let message_end = buffer.anchor_after(message.offset_range.end - 1);
pending_patch.range.end = message_end;
}
} else {
pending_patch.range.end = text::Anchor::MAX;
}
new_patches.push(pending_patch);
}
new_patches
}
pub fn pending_command_for_position(
&mut self,
position: language::Anchor,
@@ -1993,7 +2303,11 @@ impl AssistantContext {
})
}
pub fn assist(&mut self, cx: &mut Context<Self>) -> Option<MessageAnchor> {
pub fn assist(
&mut self,
request_type: RequestType,
cx: &mut Context<Self>,
) -> Option<MessageAnchor> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry.default_model()?;
let last_message_id = self.get_last_valid_message_id(cx)?;
@@ -2008,7 +2322,7 @@ impl AssistantContext {
// Compute which messages to cache, including the last one.
self.mark_cache_anchors(&model.cache_configuration(), false, cx);
let request = self.to_completion_request(Some(&model), cx);
let request = self.to_completion_request(Some(&model), request_type, cx);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
@@ -2154,6 +2468,12 @@ impl AssistantContext {
metadata.status = MessageStatus::Canceled;
});
Some(error.to_string())
} else if error.is::<MaxMonthlySpendReachedError>() {
cx.emit(ContextEvent::ShowMaxMonthlySpendReachedError);
this.update_metadata(assistant_message_id, cx, |metadata| {
metadata.status = MessageStatus::Canceled;
});
Some(error.to_string())
} else {
let error_message = error
.chain()
@@ -2243,6 +2563,7 @@ impl AssistantContext {
pub fn to_completion_request(
&self,
model: Option<&Arc<dyn LanguageModel>>,
request_type: RequestType,
cx: &App,
) -> LanguageModelRequest {
let buffer = self.buffer.read(cx);
@@ -2323,6 +2644,25 @@ impl AssistantContext {
}
}
if let RequestType::SuggestEdits = request_type {
if let Ok(preamble) = self.prompt_builder.generate_suggest_edits_prompt() {
let last_elem_index = completion_request.messages.len();
completion_request
.messages
.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text(preamble)],
cache: false,
});
// The preamble message should be sent right before the last actual user message.
completion_request
.messages
.swap(last_elem_index, last_elem_index.saturating_sub(1));
}
}
completion_request
}
@@ -2357,6 +2697,17 @@ impl AssistantContext {
ranges.push(message.anchor_range.clone());
}
}
let buffer = self.buffer.read(cx).text_snapshot();
let mut updated = Vec::new();
let mut removed = Vec::new();
for range in ranges {
self.reparse_patches_in_range(range, &buffer, &mut updated, &mut removed, cx);
}
if !updated.is_empty() || !removed.is_empty() {
cx.emit(ContextEvent::PatchesUpdated { removed, updated })
}
}
pub fn update_metadata(
@@ -2629,12 +2980,12 @@ impl AssistantContext {
return;
};
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_pending()) {
if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
if !model.provider.is_authenticated(cx) {
return;
}
let mut request = self.to_completion_request(Some(&model.model), cx);
let mut request = self.to_completion_request(Some(&model.model), RequestType::Chat, cx);
request.messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![
@@ -2646,20 +2997,17 @@ impl AssistantContext {
// If there is no summary, it is set with `done: false` so that "Loading Summary…" can
// be displayed.
match self.summary {
ContextSummary::Pending | ContextSummary::Error => {
self.summary = ContextSummary::Content(ContextSummaryContent {
text: "".to_string(),
done: false,
timestamp: clock::Lamport::default(),
});
replace_old = true;
}
ContextSummary::Content(_) => {}
if self.summary.is_none() {
self.summary = Some(ContextSummary {
text: "".to_string(),
done: false,
timestamp: clock::Lamport::default(),
});
replace_old = true;
}
self.summary_task = cx.spawn(async move |this, cx| {
let result = async {
async move {
let stream = model.model.stream_completion_text(request, &cx);
let mut messages = stream.await?;
@@ -2670,7 +3018,7 @@ impl AssistantContext {
this.update(cx, |this, cx| {
let version = this.version.clone();
let timestamp = this.next_timestamp();
let summary = this.summary.content_or_set_empty();
let summary = this.summary.get_or_insert(ContextSummary::default());
if !replaced && replace_old {
summary.text.clear();
replaced = true;
@@ -2692,19 +3040,10 @@ impl AssistantContext {
}
}
this.read_with(cx, |this, _cx| {
if let Some(summary) = this.summary.content() {
if summary.text.is_empty() {
bail!("Model generated an empty summary");
}
}
Ok(())
})??;
this.update(cx, |this, cx| {
let version = this.version.clone();
let timestamp = this.next_timestamp();
if let Some(summary) = this.summary.content_as_mut() {
if let Some(summary) = this.summary.as_mut() {
summary.done = true;
summary.timestamp = timestamp;
let operation = ContextOperation::UpdateSummary {
@@ -2719,18 +3058,8 @@ impl AssistantContext {
anyhow::Ok(())
}
.await;
if let Err(err) = result {
this.update(cx, |this, cx| {
this.summary = ContextSummary::Error;
cx.emit(ContextEvent::SummaryChanged);
})
.log_err();
log::error!("Error generating context summary: {}", err);
}
Some(())
.log_err()
.await
});
}
}
@@ -2844,7 +3173,7 @@ impl AssistantContext {
let (old_path, summary) = this.read_with(cx, |this, _| {
let path = this.path.clone();
let summary = if let Some(summary) = this.summary.content() {
let summary = if let Some(summary) = this.summary.as_ref() {
if summary.done {
Some(summary.text.clone())
} else {
@@ -2898,12 +3227,39 @@ impl AssistantContext {
pub fn set_custom_summary(&mut self, custom_summary: String, cx: &mut Context<Self>) {
let timestamp = self.next_timestamp();
let summary = self.summary.content_or_set_empty();
let summary = self.summary.get_or_insert(ContextSummary::default());
summary.timestamp = timestamp;
summary.done = true;
summary.text = custom_summary;
cx.emit(ContextEvent::SummaryChanged);
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Text Thread");
pub fn summary_or_default(&self) -> SharedString {
self.summary
.as_ref()
.map(|summary| summary.text.clone().into())
.unwrap_or(Self::DEFAULT_SUMMARY)
}
}
fn trimmed_text_in_range(buffer: &BufferSnapshot, range: Range<text::Anchor>) -> String {
let mut is_start = true;
let mut content = buffer
.text_for_range(range)
.map(|mut chunk| {
if is_start {
chunk = chunk.trim_start_matches('\n');
if !chunk.is_empty() {
is_start = false;
}
}
chunk
})
.collect::<String>();
content.truncate(content.trim_end().len());
content
}
#[derive(Debug, Default)]
@@ -3119,7 +3475,7 @@ impl SavedContext {
let timestamp = next_timestamp.tick();
operations.push(ContextOperation::UpdateSummary {
summary: ContextSummaryContent {
summary: ContextSummary {
text: self.summary,
done: true,
timestamp,

View File

@@ -1,6 +1,6 @@
use crate::{
AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, ContextSummary,
InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId,
ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus,
};
use anyhow::Result;
use assistant_slash_command::{
@@ -16,10 +16,7 @@ use futures::{
};
use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*};
use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate};
use language_model::{
ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role,
fake_provider::{FakeLanguageModel, FakeLanguageModelProvider},
};
use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
use project::Project;
@@ -35,10 +32,13 @@ use std::{
rc::Rc,
sync::{Arc, atomic::AtomicBool},
};
use text::{ReplicaId, ToOffset, network::Network};
use text::{OffsetRangeExt as _, ReplicaId, ToOffset, network::Network};
use ui::{IconName, Window};
use unindent::Unindent;
use util::RandomCharIter;
use util::{
RandomCharIter,
test::{generate_marked_text, marked_text_ranges},
};
use workspace::Workspace;
#[gpui::test]
@@ -664,6 +664,401 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
}
}
#[gpui::test]
async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
cx.update(|cx| {
init_test(cx);
cx.update_global(|settings_store: &mut SettingsStore, cx| {
settings_store
.set_user_settings(
r#"{ "assistant": { "enable_experimental_live_diffs": true } }"#,
cx,
)
.unwrap()
})
});
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [Path::new("/root")], cx).await;
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
// Create a new context
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry.clone(),
Some(project),
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
// Insert an assistant message to simulate a response.
let assistant_message_id = context.update(cx, |context, cx| {
let user_message_id = context.messages(cx).next().unwrap().id;
context
.insert_message_after(user_message_id, Role::Assistant, MessageStatus::Done, cx)
.unwrap()
.id
});
// No edit tags
edit(
&context,
"
«one
two
»",
cx,
);
expect_patches(
&context,
"
one
two
",
&[],
cx,
);
// Partial edit step tag is added
edit(
&context,
"
one
two
«
<patch»",
cx,
);
expect_patches(
&context,
"
one
two
<patch",
&[],
cx,
);
// The rest of the step tag is added. The unclosed
// step is treated as incomplete.
edit(
&context,
"
one
two
<patch«>
<edit>»",
cx,
);
expect_patches(
&context,
"
one
two
«<patch>
<edit>»",
&[&[]],
cx,
);
// The full patch is added
edit(
&context,
"
one
two
<patch>
<edit>«
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn one</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
also,»",
cx,
);
expect_patches(
&context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn one</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn one".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
// The step is manually edited.
edit(
&context,
"
one
two
<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>«fn zero»</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
also,",
cx,
);
expect_patches(
&context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
// When setting the message role to User, the steps are cleared.
context.update(cx, |context, cx| {
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
});
expect_patches(
&context,
"
one
two
<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
also,",
&[],
cx,
);
// When setting the message role back to Assistant, the steps are reparsed.
context.update(cx, |context, cx| {
context.cycle_message_roles(HashSet::from_iter([assistant_message_id]), cx);
});
expect_patches(
&context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
// Ensure steps are re-parsed when deserializing.
let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
let deserialized_context = cx.new(|cx| {
AssistantContext::deserialize(
serialized_context,
Path::new("").into(),
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
None,
None,
cx,
)
});
expect_patches(
&deserialized_context,
"
one
two
«<patch>
<edit>
<description>add a `two` function</description>
<path>src/lib.rs</path>
<operation>insert_after</operation>
<old_text>fn zero</old_text>
<new_text>
fn two() {}
</new_text>
</edit>
</patch>
»
also,",
&[&[AssistantEdit {
path: "src/lib.rs".into(),
kind: AssistantEditKind::InsertAfter {
old_text: "fn zero".into(),
new_text: "fn two() {}".into(),
description: Some("add a `two` function".into()),
},
}]],
cx,
);
fn edit(
context: &Entity<AssistantContext>,
new_text_marked_with_edits: &str,
cx: &mut TestAppContext,
) {
context.update(cx, |context, cx| {
context.buffer.update(cx, |buffer, cx| {
buffer.edit_via_marked_text(&new_text_marked_with_edits.unindent(), None, cx);
});
});
cx.executor().run_until_parked();
}
#[track_caller]
fn expect_patches(
context: &Entity<AssistantContext>,
expected_marked_text: &str,
expected_suggestions: &[&[AssistantEdit]],
cx: &mut TestAppContext,
) {
let expected_marked_text = expected_marked_text.unindent();
let (expected_text, _) = marked_text_ranges(&expected_marked_text, false);
let (buffer_text, ranges, patches) = context.update(cx, |context, cx| {
context.buffer.read_with(cx, |buffer, _| {
let ranges = context
.patches
.iter()
.map(|entry| entry.range.to_offset(buffer))
.collect::<Vec<_>>();
(
buffer.text(),
ranges,
context
.patches
.iter()
.map(|step| step.edits.clone())
.collect::<Vec<_>>(),
)
})
});
assert_eq!(buffer_text, expected_text);
let actual_marked_text = generate_marked_text(&expected_text, &ranges, false);
assert_eq!(actual_marked_text, expected_marked_text);
assert_eq!(
patches
.iter()
.map(|patch| {
patch
.iter()
.map(|edit| {
let edit = edit.as_ref().unwrap();
AssistantEdit {
path: edit.path.clone(),
kind: edit.kind.clone(),
}
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>(),
expected_suggestions
);
}
}
#[gpui::test]
async fn test_serialization(cx: &mut TestAppContext) {
cx.update(init_test);
@@ -1180,187 +1575,6 @@ fn test_mark_cache_anchors(cx: &mut App) {
);
}
#[gpui::test]
async fn test_summarization(cx: &mut TestAppContext) {
let (context, fake_model) = setup_context_editor_with_fake_model(cx);
// Initial state should be pending
context.read_with(cx, |context, _| {
assert!(matches!(context.summary(), ContextSummary::Pending));
assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
});
let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
context.update(cx, |context, cx| {
context
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap();
});
// Send a message
context.update(cx, |context, cx| {
context.assist(cx);
});
simulate_successful_response(&fake_model, cx);
// Should start generating summary when there are >= 2 messages
context.read_with(cx, |context, _| {
assert!(!context.summary().content().unwrap().done);
});
cx.run_until_parked();
fake_model.stream_last_completion_response("Brief".into());
fake_model.stream_last_completion_response(" Introduction".into());
fake_model.end_last_completion_stream();
cx.run_until_parked();
// Summary should be set
context.read_with(cx, |context, _| {
assert_eq!(context.summary().or_default(), "Brief Introduction");
});
// We should be able to manually set a summary
context.update(cx, |context, cx| {
context.set_custom_summary("Brief Intro".into(), cx);
});
context.read_with(cx, |context, _| {
assert_eq!(context.summary().or_default(), "Brief Intro");
});
}
#[gpui::test]
async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) {
let (context, fake_model) = setup_context_editor_with_fake_model(cx);
test_summarize_error(&fake_model, &context, cx);
// Now we should be able to set a summary
context.update(cx, |context, cx| {
context.set_custom_summary("Brief Intro".into(), cx);
});
context.read_with(cx, |context, _| {
assert_eq!(context.summary().or_default(), "Brief Intro");
});
}
#[gpui::test]
async fn test_thread_summary_error_retry(cx: &mut TestAppContext) {
let (context, fake_model) = setup_context_editor_with_fake_model(cx);
test_summarize_error(&fake_model, &context, cx);
// Sending another message should not trigger another summarize request
context.update(cx, |context, cx| {
context.assist(cx);
});
simulate_successful_response(&fake_model, cx);
context.read_with(cx, |context, _| {
// State is still Error, not Generating
assert!(matches!(context.summary(), ContextSummary::Error));
});
// But the summarize request can be invoked manually
context.update(cx, |context, cx| {
context.summarize(true, cx);
});
context.read_with(cx, |context, _| {
assert!(!context.summary().content().unwrap().done);
});
cx.run_until_parked();
fake_model.stream_last_completion_response("A successful summary".into());
fake_model.end_last_completion_stream();
cx.run_until_parked();
context.read_with(cx, |context, _| {
assert_eq!(context.summary().or_default(), "A successful summary");
});
}
fn test_summarize_error(
model: &Arc<FakeLanguageModel>,
context: &Entity<AssistantContext>,
cx: &mut TestAppContext,
) {
let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone());
context.update(cx, |context, cx| {
context
.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
.unwrap();
});
// Send a message
context.update(cx, |context, cx| {
context.assist(cx);
});
simulate_successful_response(&model, cx);
context.read_with(cx, |context, _| {
assert!(!context.summary().content().unwrap().done);
});
// Simulate summary request ending
cx.run_until_parked();
model.end_last_completion_stream();
cx.run_until_parked();
// State is set to Error and default message
context.read_with(cx, |context, _| {
assert_eq!(*context.summary(), ContextSummary::Error);
assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT);
});
}
fn setup_context_editor_with_fake_model(
cx: &mut TestAppContext,
) -> (Entity<AssistantContext>, Arc<FakeLanguageModel>) {
let registry = Arc::new(LanguageRegistry::test(cx.executor().clone()));
let fake_provider = Arc::new(FakeLanguageModelProvider);
let fake_model = Arc::new(fake_provider.test_model());
cx.update(|cx| {
init_test(cx);
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(
Some(ConfiguredModel {
provider: fake_provider.clone(),
model: fake_model.clone(),
}),
cx,
)
})
});
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
(context, fake_model)
}
fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.stream_last_completion_response("Assistant response".into());
fake_model.end_last_completion_stream();
cx.run_until_parked();
}
fn messages(context: &Entity<AssistantContext>, cx: &App) -> Vec<(MessageId, Role, Range<usize>)> {
context
.read(cx)

View File

@@ -9,11 +9,11 @@ use client::{proto, zed_urls};
use collections::{BTreeSet, HashMap, HashSet, hash_map};
use editor::{
Anchor, Editor, EditorEvent, MenuInlineCompletionsPolicy, MultiBuffer, MultiBufferSnapshot,
RowExt, ToOffset as _, ToPoint,
ProposedChangeLocation, ProposedChangesEditor, RowExt, ToOffset as _, ToPoint,
actions::{MoveToEndOfLine, Newline, ShowCompletions},
display_map::{
BlockPlacement, BlockProperties, BlockStyle, Crease, CreaseMetadata, CustomBlockId, FoldId,
RenderBlock, ToDisplayPoint,
BlockContext, BlockId, BlockPlacement, BlockProperties, BlockStyle, Crease, CreaseMetadata,
CustomBlockId, FoldId, RenderBlock, ToDisplayPoint,
},
scroll::Autoscroll,
};
@@ -21,11 +21,12 @@ use editor::{FoldPlaceholder, display_map::CreaseId};
use fs::Fs;
use futures::FutureExt;
use gpui::{
Animation, AnimationExt, AnyElement, AnyView, App, ClipboardEntry, ClipboardItem, Empty,
Entity, EventEmitter, FocusHandle, Focusable, FontWeight, Global, InteractiveElement,
IntoElement, ParentElement, Pixels, Render, RenderImage, SharedString, Size,
StatefulInteractiveElement, Styled, Subscription, Task, Transformation, WeakEntity, actions,
div, img, impl_internal_actions, percentage, point, prelude::*, pulsating_between, size,
Animation, AnimationExt, AnyElement, AnyView, App, AsyncWindowContext, ClipboardEntry,
ClipboardItem, CursorStyle, Empty, Entity, EventEmitter, FocusHandle, Focusable, FontWeight,
Global, InteractiveElement, IntoElement, ParentElement, Pixels, Render, RenderImage,
SharedString, Size, StatefulInteractiveElement, Styled, Subscription, Task, Transformation,
WeakEntity, actions, div, img, impl_internal_actions, percentage, point, prelude::*,
pulsating_between, size,
};
use indexed_docs::IndexedDocsStore;
use language::{
@@ -68,14 +69,14 @@ use workspace::{
Save, Toast, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
item::{self, FollowableItem, Item, ItemHandle},
notifications::NotificationId,
pane,
pane::{self, SaveIntent},
searchable::{SearchEvent, SearchableItem},
};
use crate::{
AssistantContext, CacheStatus, Content, ContextEvent, ContextId, InvokedSlashCommandId,
InvokedSlashCommandStatus, Message, MessageId, MessageMetadata, MessageStatus,
ParsedSlashCommand, PendingSlashCommandStatus,
AssistantContext, AssistantPatch, AssistantPatchStatus, CacheStatus, Content, ContextEvent,
ContextId, InvokedSlashCommandId, InvokedSlashCommandStatus, Message, MessageId,
MessageMetadata, MessageStatus, ParsedSlashCommand, PendingSlashCommandStatus, RequestType,
};
use crate::{
ThoughtProcessOutputSection, slash_command::SlashCommandCompletionProvider,
@@ -89,6 +90,7 @@ actions!(
ConfirmCommand,
CopyCode,
CycleMessageRole,
Edit,
InsertIntoEditor,
QuoteSelection,
Split,
@@ -109,11 +111,24 @@ struct ScrollPosition {
cursor: Anchor,
}
struct PatchViewState {
crease_id: CreaseId,
editor: Option<PatchEditorState>,
update_task: Option<Task<()>>,
}
struct PatchEditorState {
editor: WeakEntity<ProposedChangesEditor>,
opened_patch: AssistantPatch,
}
type MessageHeader = MessageMetadata;
#[derive(Clone)]
enum AssistError {
FileRequired,
PaymentRequired,
MaxMonthlySpendReached,
Message(SharedString),
}
@@ -189,6 +204,8 @@ pub struct ContextEditor {
pending_slash_command_creases: HashMap<Range<language::Anchor>, CreaseId>,
invoked_slash_command_creases: HashMap<InvokedSlashCommandId, CreaseId>,
_subscriptions: Vec<Subscription>,
patches: HashMap<Range<language::Anchor>, PatchViewState>,
active_patch: Option<Range<language::Anchor>>,
last_error: Option<AssistError>,
show_accept_terms: bool,
pub(crate) slash_menu_handle:
@@ -225,7 +242,7 @@ impl ContextEditor {
let editor = cx.new(|cx| {
let mut editor =
Editor::for_buffer(context.read(cx).buffer().clone(), None, window, cx);
editor.disable_scrollbars_and_minimap(window, cx);
editor.disable_scrollbars_and_minimap(cx);
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
editor.set_show_line_numbers(false, cx);
editor.set_show_git_diff_gutter(false, cx);
@@ -257,6 +274,7 @@ impl ContextEditor {
let slash_command_sections = context.read(cx).slash_command_output_sections().to_vec();
let thought_process_sections = context.read(cx).thought_process_output_sections().to_vec();
let patch_ranges = context.read(cx).patch_ranges().collect::<Vec<_>>();
let slash_commands = context.read(cx).slash_commands().clone();
let mut this = Self {
context,
@@ -274,6 +292,8 @@ impl ContextEditor {
pending_slash_command_creases: HashMap::default(),
invoked_slash_command_creases: HashMap::default(),
_subscriptions,
patches: HashMap::default(),
active_patch: None,
last_error: None,
show_accept_terms: false,
slash_menu_handle: Default::default(),
@@ -304,6 +324,7 @@ impl ContextEditor {
window,
cx,
);
this.patches_updated(&Vec::new(), &patch_ranges, window, cx);
this
}
@@ -349,10 +370,37 @@ impl ContextEditor {
if self.sending_disabled(cx) {
return;
}
self.send_to_model(window, cx);
self.send_to_model(RequestType::Chat, window, cx);
}
fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
fn edit(&mut self, _: &Edit, window: &mut Window, cx: &mut Context<Self>) {
if self.sending_disabled(cx) {
return;
}
self.send_to_model(RequestType::SuggestEdits, window, cx);
}
fn focus_active_patch(&mut self, window: &mut Window, cx: &mut Context<Self>) -> bool {
if let Some((_range, patch)) = self.active_patch() {
if let Some(editor) = patch
.editor
.as_ref()
.and_then(|state| state.editor.upgrade())
{
editor.focus_handle(cx).focus(window);
return true;
}
}
false
}
fn send_to_model(
&mut self,
request_type: RequestType,
window: &mut Window,
cx: &mut Context<Self>,
) {
let provider = LanguageModelRegistry::read_global(cx)
.default_model()
.map(|default| default.provider);
@@ -365,9 +413,19 @@ impl ContextEditor {
return;
}
if self.focus_active_patch(window, cx) {
return;
}
self.last_error = None;
if let Some(user_message) = self.context.update(cx, |context, cx| context.assist(cx)) {
if request_type == RequestType::SuggestEdits && !self.context.read(cx).contains_files(cx) {
self.last_error = Some(AssistError::FileRequired);
cx.notify();
} else if let Some(user_message) = self
.context
.update(cx, |context, cx| context.assist(request_type, cx))
{
let new_selection = {
let cursor = user_message
.start
@@ -632,6 +690,9 @@ impl ContextEditor {
}
});
}
ContextEvent::PatchesUpdated { removed, updated } => {
self.patches_updated(removed, updated, window, cx);
}
ContextEvent::ParsedSlashCommandsUpdated { removed, updated } => {
self.editor.update(cx, |editor, cx| {
let buffer = editor.buffer().read(cx).snapshot(cx);
@@ -731,6 +792,9 @@ impl ContextEditor {
ContextEvent::ShowPaymentRequiredError => {
self.last_error = Some(AssistError::PaymentRequired);
}
ContextEvent::ShowMaxMonthlySpendReachedError => {
self.last_error = Some(AssistError::MaxMonthlySpendReached);
}
}
}
@@ -833,6 +897,131 @@ impl ContextEditor {
});
}
fn patches_updated(
&mut self,
removed: &Vec<Range<text::Anchor>>,
updated: &Vec<Range<text::Anchor>>,
window: &mut Window,
cx: &mut Context<ContextEditor>,
) {
let this = cx.entity().downgrade();
let mut editors_to_close = Vec::new();
self.editor.update(cx, |editor, cx| {
let snapshot = editor.snapshot(window, cx);
let multibuffer = &snapshot.buffer_snapshot;
let (&excerpt_id, _, _) = multibuffer.as_singleton().unwrap();
let mut removed_crease_ids = Vec::new();
let mut ranges_to_unfold: Vec<Range<Anchor>> = Vec::new();
for range in removed {
if let Some(state) = self.patches.remove(range) {
let patch_start = multibuffer
.anchor_in_excerpt(excerpt_id, range.start)
.unwrap();
let patch_end = multibuffer
.anchor_in_excerpt(excerpt_id, range.end)
.unwrap();
editors_to_close.extend(state.editor.and_then(|state| state.editor.upgrade()));
ranges_to_unfold.push(patch_start..patch_end);
removed_crease_ids.push(state.crease_id);
}
}
editor.unfold_ranges(&ranges_to_unfold, true, false, cx);
editor.remove_creases(removed_crease_ids, cx);
for range in updated {
let Some(patch) = self.context.read(cx).patch_for_range(&range, cx).cloned() else {
continue;
};
let path_count = patch.path_count();
let patch_start = multibuffer
.anchor_in_excerpt(excerpt_id, patch.range.start)
.unwrap();
let patch_end = multibuffer
.anchor_in_excerpt(excerpt_id, patch.range.end)
.unwrap();
let render_block: RenderBlock = Arc::new({
let this = this.clone();
let patch_range = range.clone();
move |cx: &mut BlockContext| {
let max_width = cx.max_width;
let gutter_width = cx.margins.gutter.full_width();
let block_id = cx.block_id;
let selected = cx.selected;
let window = &mut cx.window;
this.update(cx.app, |this, cx| {
this.render_patch_block(
patch_range.clone(),
max_width,
gutter_width,
block_id,
selected,
window,
cx,
)
})
.ok()
.flatten()
.unwrap_or_else(|| Empty.into_any())
}
});
let height = path_count as u32 + 1;
let crease = Crease::block(
patch_start..patch_end,
height,
BlockStyle::Flex,
render_block.clone(),
);
let should_refold;
if let Some(state) = self.patches.get_mut(&range) {
if let Some(editor_state) = &state.editor {
if editor_state.opened_patch != patch {
state.update_task = Some({
let this = this.clone();
cx.spawn_in(window, async move |_, cx| {
Self::update_patch_editor(this.clone(), patch, cx)
.await
.log_err();
})
});
}
}
should_refold =
snapshot.intersects_fold(patch_start.to_offset(&snapshot.buffer_snapshot));
} else {
let crease_id = editor.insert_creases([crease.clone()], cx)[0];
self.patches.insert(
range.clone(),
PatchViewState {
crease_id,
editor: None,
update_task: None,
},
);
should_refold = true;
}
if should_refold {
editor.unfold_ranges(&[patch_start..patch_end], true, false, cx);
editor.fold_creases(vec![crease], false, window, cx);
}
}
});
for editor in editors_to_close {
self.close_patch_editor(editor, window, cx);
}
self.update_active_patch(window, cx);
}
fn insert_thought_process_output_sections(
&mut self,
sections: impl IntoIterator<
@@ -961,12 +1150,177 @@ impl ContextEditor {
}
EditorEvent::SelectionsChanged { .. } => {
self.scroll_position = self.cursor_scroll_position(window, cx);
self.update_active_patch(window, cx);
}
_ => {}
}
cx.emit(event.clone());
}
fn active_patch(&self) -> Option<(Range<text::Anchor>, &PatchViewState)> {
let patch = self.active_patch.as_ref()?;
Some((patch.clone(), self.patches.get(&patch)?))
}
fn update_active_patch(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let newest_cursor = self.editor.update(cx, |editor, cx| {
editor.selections.newest::<Point>(cx).head()
});
let context = self.context.read(cx);
let new_patch = context.patch_containing(newest_cursor, cx).cloned();
if new_patch.as_ref().map(|p| &p.range) == self.active_patch.as_ref() {
return;
}
if let Some(old_patch_range) = self.active_patch.take() {
if let Some(patch_state) = self.patches.get_mut(&old_patch_range) {
if let Some(state) = patch_state.editor.take() {
if let Some(editor) = state.editor.upgrade() {
self.close_patch_editor(editor, window, cx);
}
}
}
}
if let Some(new_patch) = new_patch {
self.active_patch = Some(new_patch.range.clone());
if let Some(patch_state) = self.patches.get_mut(&new_patch.range) {
let mut editor = None;
if let Some(state) = &patch_state.editor {
if let Some(opened_editor) = state.editor.upgrade() {
editor = Some(opened_editor);
}
}
if let Some(editor) = editor {
self.workspace
.update(cx, |workspace, cx| {
workspace.activate_item(&editor, true, false, window, cx);
})
.ok();
} else {
patch_state.update_task = Some(cx.spawn_in(window, async move |this, cx| {
Self::open_patch_editor(this, new_patch, cx).await.log_err();
}));
}
}
}
}
fn close_patch_editor(
&mut self,
editor: Entity<ProposedChangesEditor>,
window: &mut Window,
cx: &mut Context<ContextEditor>,
) {
self.workspace
.update(cx, |workspace, cx| {
if let Some(pane) = workspace.pane_for(&editor) {
pane.update(cx, |pane, cx| {
let item_id = editor.entity_id();
if !editor.read(cx).focus_handle(cx).is_focused(window) {
pane.close_item_by_id(item_id, SaveIntent::Skip, window, cx)
.detach_and_log_err(cx);
}
});
}
})
.ok();
}
async fn open_patch_editor(
this: WeakEntity<Self>,
patch: AssistantPatch,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let project = this.read_with(cx, |this, _| this.project.clone())?;
let resolved_patch = patch.resolve(project.clone(), cx).await;
let editor = cx.new_window_entity(|window, cx| {
let editor = ProposedChangesEditor::new(
patch.title.clone(),
resolved_patch
.edit_groups
.iter()
.map(|(buffer, groups)| ProposedChangeLocation {
buffer: buffer.clone(),
ranges: groups
.iter()
.map(|group| group.context_range.clone())
.collect(),
})
.collect(),
Some(project.clone()),
window,
cx,
);
resolved_patch.apply(&editor, cx);
editor
})?;
this.update(cx, |this, _| {
if let Some(patch_state) = this.patches.get_mut(&patch.range) {
patch_state.editor = Some(PatchEditorState {
editor: editor.downgrade(),
opened_patch: patch,
});
patch_state.update_task.take();
}
})?;
this.read_with(cx, |this, _| this.workspace.clone())?
.update_in(cx, |workspace, window, cx| {
workspace.add_item_to_active_pane(Box::new(editor.clone()), None, false, window, cx)
})
.log_err();
Ok(())
}
async fn update_patch_editor(
this: WeakEntity<Self>,
patch: AssistantPatch,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let project = this.update(cx, |this, _| this.project.clone())?;
let resolved_patch = patch.resolve(project.clone(), cx).await;
this.update_in(cx, |this, window, cx| {
let patch_state = this.patches.get_mut(&patch.range)?;
let locations = resolved_patch
.edit_groups
.iter()
.map(|(buffer, groups)| ProposedChangeLocation {
buffer: buffer.clone(),
ranges: groups
.iter()
.map(|group| group.context_range.clone())
.collect(),
})
.collect();
if let Some(state) = &mut patch_state.editor {
if let Some(editor) = state.editor.upgrade() {
editor.update(cx, |editor, cx| {
editor.set_title(patch.title.clone(), cx);
editor.reset_locations(locations, window, cx);
resolved_patch.apply(editor, cx);
});
state.opened_patch = patch;
} else {
patch_state.editor.take();
}
}
patch_state.update_task.take();
Some(())
})?;
Ok(())
}
fn handle_editor_search_event(
&mut self,
_: &Entity<Editor>,
@@ -1590,7 +1944,7 @@ impl ContextEditor {
&mut self,
cx: &mut Context<Self>,
) -> (String, CopyMetadata, Vec<text::Selection<usize>>) {
let (mut selection, creases) = self.editor.update(cx, |editor, cx| {
let (selection, creases) = self.editor.update(cx, |editor, cx| {
let mut selection = editor.selections.newest_adjusted(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
@@ -1648,18 +2002,7 @@ impl ContextEditor {
} else if message.offset_range.end >= selection.range().start {
let range = cmp::max(message.offset_range.start, selection.range().start)
..cmp::min(message.offset_range.end, selection.range().end);
if range.is_empty() {
let snapshot = context.buffer().read(cx).snapshot();
let point = snapshot.offset_to_point(range.start);
selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
selection.end = snapshot.point_to_offset(cmp::min(
Point::new(point.row + 1, 0),
snapshot.max_point(),
));
for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
text.push_str(chunk);
}
} else {
if !range.is_empty() {
for chunk in context.buffer().read(cx).text_for_range(range) {
text.push_str(chunk);
}
@@ -1867,12 +2210,119 @@ impl ContextEditor {
}
pub fn title(&self, cx: &App) -> SharedString {
self.context.read(cx).summary().or_default()
self.context.read(cx).summary_or_default()
}
pub fn regenerate_summary(&mut self, cx: &mut Context<Self>) {
self.context
.update(cx, |context, cx| context.summarize(true, cx));
fn render_patch_block(
&mut self,
range: Range<text::Anchor>,
max_width: Pixels,
gutter_width: Pixels,
id: BlockId,
selected: bool,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<AnyElement> {
let snapshot = self
.editor
.update(cx, |editor, cx| editor.snapshot(window, cx));
let (excerpt_id, _buffer_id, _) = snapshot.buffer_snapshot.as_singleton().unwrap();
let excerpt_id = *excerpt_id;
let anchor = snapshot
.buffer_snapshot
.anchor_in_excerpt(excerpt_id, range.start)
.unwrap();
let theme = cx.theme().clone();
let patch = self.context.read(cx).patch_for_range(&range, cx)?;
let paths = patch
.paths()
.map(|p| SharedString::from(p.to_string()))
.collect::<BTreeSet<_>>();
Some(
v_flex()
.id(id)
.bg(theme.colors().editor_background)
.ml(gutter_width)
.pb_1()
.w(max_width - gutter_width)
.rounded_sm()
.border_1()
.border_color(theme.colors().border_variant)
.overflow_hidden()
.hover(|style| style.border_color(theme.colors().text_accent))
.when(selected, |this| {
this.border_color(theme.colors().text_accent)
})
.cursor(CursorStyle::PointingHand)
.on_click(cx.listener(move |this, _, window, cx| {
this.editor.update(cx, |editor, cx| {
editor.change_selections(None, window, cx, |selections| {
selections.select_ranges(vec![anchor..anchor]);
});
});
this.focus_active_patch(window, cx);
}))
.child(
div()
.px_2()
.py_1()
.overflow_hidden()
.text_ellipsis()
.border_b_1()
.border_color(theme.colors().border_variant)
.bg(theme.colors().element_background)
.child(
Label::new(patch.title.clone())
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.children(paths.into_iter().map(|path| {
h_flex()
.px_2()
.pt_1()
.gap_1p5()
.child(Icon::new(IconName::File).size(IconSize::Small))
.child(Label::new(path).size(LabelSize::Small))
}))
.when(patch.status == AssistantPatchStatus::Pending, |div| {
div.child(
h_flex()
.pt_1()
.px_2()
.gap_1()
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::XSmall)
.color(Color::Muted)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(
delta,
)))
},
),
)
.child(
Label::new("Generating…")
.color(Color::Muted)
.size(LabelSize::Small)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 0.8)),
|label, delta| label.alpha(delta),
),
),
)
})
.into_any(),
)
}
fn render_notice(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
@@ -1906,24 +2356,11 @@ impl ContextEditor {
.log_err();
if let Some(client) = client {
cx.spawn(async move |context_editor, cx| {
match client.authenticate_and_connect(true, cx).await {
util::ConnectionResult::Timeout => {
log::error!("Authentication timeout")
}
util::ConnectionResult::ConnectionReset => {
log::error!("Connection reset")
}
util::ConnectionResult::Result(r) => {
if r.log_err().is_some() {
context_editor
.update(cx, |_, cx| cx.notify())
.ok();
}
}
}
cx.spawn(async move |this, cx| {
client.authenticate_and_connect(true, cx).await?;
this.update(cx, |_, cx| cx.notify())
})
.detach()
.detach_and_log_err(cx)
}
})),
)
@@ -2014,7 +2451,13 @@ impl ContextEditor {
button.tooltip(move |_, _| tooltip.clone())
})
.layer(ElevationIndex::ModalSurface)
.child(Label::new("Send"))
.child(Label::new(
if AssistantSettings::get_global(cx).are_live_diffs_enabled(cx) {
"Chat"
} else {
"Send"
},
))
.children(
KeyBinding::for_action_in(&Assist, &focus_handle, window, cx)
.map(|binding| binding.into_any_element()),
@@ -2038,6 +2481,50 @@ impl ContextEditor {
has_configuration_error || needs_to_accept_terms
}
fn render_edit_button(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle(cx).clone();
let (style, tooltip) = match token_state(&self.context, cx) {
Some(TokenState::NoTokensLeft { .. }) => (
ButtonStyle::Tinted(TintColor::Error),
Some(Tooltip::text("Token limit reached")(window, cx)),
),
Some(TokenState::HasMoreTokens {
over_warn_threshold,
..
}) => {
let (style, tooltip) = if over_warn_threshold {
(
ButtonStyle::Tinted(TintColor::Warning),
Some(Tooltip::text("Token limit is close to exhaustion")(
window, cx,
)),
)
} else {
(ButtonStyle::Filled, None)
};
(style, tooltip)
}
None => (ButtonStyle::Filled, None),
};
ButtonLike::new("edit_button")
.disabled(self.sending_disabled(cx))
.style(style)
.when_some(tooltip, |button, tooltip| {
button.tooltip(move |_, _| tooltip.clone())
})
.layer(ElevationIndex::ModalSurface)
.child(Label::new("Suggest Edits"))
.children(
KeyBinding::for_action_in(&Edit, &focus_handle, window, cx)
.map(|binding| binding.into_any_element()),
)
.on_click(move |_event, window, cx| {
focus_handle.dispatch_action(&Edit, window, cx);
})
}
fn render_inject_context_menu(&self, cx: &mut Context<Self>) -> impl IntoElement {
slash_command_picker::SlashCommandSelector::new(
self.slash_commands.clone(),
@@ -2113,7 +2600,11 @@ impl ContextEditor {
.elevation_2(cx)
.occlude()
.child(match last_error {
AssistError::FileRequired => self.render_file_required_error(cx),
AssistError::PaymentRequired => self.render_payment_required_error(cx),
AssistError::MaxMonthlySpendReached => {
self.render_max_monthly_spend_reached_error(cx)
}
AssistError::Message(error_message) => {
self.render_assist_error(error_message, cx)
}
@@ -2122,6 +2613,41 @@ impl ContextEditor {
)
}
fn render_file_required_error(&self, cx: &mut Context<Self>) -> AnyElement {
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::Warning).color(Color::Warning))
.child(
Label::new("Suggest Edits needs a file to edit").weight(FontWeight::MEDIUM),
),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(
"To include files, type /file or /tab in your prompt.",
)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, _window, cx| {
this.last_error = None;
cx.notify();
},
))),
)
.into_any()
}
fn render_payment_required_error(&self, cx: &mut Context<Self>) -> AnyElement {
const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
@@ -2162,6 +2688,48 @@ impl ContextEditor {
.into_any()
}
fn render_max_monthly_spend_reached_error(&self, cx: &mut Context<Self>) -> AnyElement {
const ERROR_MESSAGE: &str = "You have reached your maximum monthly spend. Increase your spend limit to continue using Zed LLMs.";
v_flex()
.gap_0p5()
.child(
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Max Monthly Spend Reached").weight(FontWeight::MEDIUM)),
)
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(ERROR_MESSAGE)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.child(
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
cx.listener(|this, _, _window, cx| {
this.last_error = None;
cx.open_url(&zed_urls::account_url(cx));
cx.notify();
}),
),
)
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, _window, cx| {
this.last_error = None;
cx.notify();
},
))),
)
.into_any()
}
fn render_assist_error(
&self,
error_message: &SharedString,
@@ -2520,6 +3088,7 @@ impl Render for ContextEditor {
.capture_action(cx.listener(ContextEditor::paste))
.capture_action(cx.listener(ContextEditor::cycle_message_role))
.capture_action(cx.listener(ContextEditor::confirm_command))
.on_action(cx.listener(ContextEditor::edit))
.on_action(cx.listener(ContextEditor::assist))
.on_action(cx.listener(ContextEditor::split))
.on_action(move |_: &ToggleModelSelector, window, cx| {
@@ -2572,6 +3141,20 @@ impl Render for ContextEditor {
h_flex()
.w_full()
.justify_end()
.when(
AssistantSettings::get_global(cx).are_live_diffs_enabled(cx),
|buttons| {
buttons
.items_center()
.gap_1p5()
.child(self.render_edit_button(window, cx))
.child(
Label::new("or")
.size(LabelSize::Small)
.color(Color::Muted),
)
},
)
.child(self.render_send_button(window, cx)),
),
),
@@ -3044,7 +3627,7 @@ fn invoked_slash_command_fold_placeholder(
.gap_2()
.bg(cx.theme().colors().surface_background)
.rounded_sm()
.child(Label::new(format!("/{}", command.name)))
.child(Label::new(format!("/{}", command.name.clone())))
.map(|parent| match &command.status {
InvokedSlashCommandStatus::Running(_) => {
parent.child(Icon::new(IconName::ArrowCircle).with_animation(
@@ -3213,77 +3796,9 @@ pub fn make_lsp_adapter_delegate(
#[cfg(test)]
mod tests {
use super::*;
use fs::FakeFs;
use gpui::{App, TestAppContext, VisualTestContext};
use language::{Buffer, LanguageRegistry};
use prompt_store::PromptBuilder;
use gpui::App;
use language::Buffer;
use unindent::Unindent;
use util::path;
#[gpui::test]
async fn test_copy_paste_no_selection(cx: &mut TestAppContext) {
cx.update(init_test);
let fs = FakeFs::new(cx.executor());
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
let context = cx.new(|cx| {
AssistantContext::local(
registry,
None,
None,
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
cx,
)
});
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let workspace = window.root(cx).unwrap();
let cx = &mut VisualTestContext::from_window(*window, cx);
let context_editor = window
.update(cx, |_, window, cx| {
cx.new(|cx| {
ContextEditor::for_context(
context,
fs,
workspace.downgrade(),
project,
None,
window,
cx,
)
})
})
.unwrap();
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.set_text("abc\ndef\nghi", window, cx);
editor.move_to_beginning(&Default::default(), window, cx);
})
});
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.copy(&Default::default(), window, cx);
editor.paste(&Default::default(), window, cx);
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
})
});
context_editor.update_in(cx, |context_editor, window, cx| {
context_editor.editor.update(cx, |editor, cx| {
editor.cut(&Default::default(), window, cx);
assert_eq!(editor.text(cx), "abc\ndef\nghi");
editor.paste(&Default::default(), window, cx);
assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
})
});
}
#[gpui::test]
fn test_find_code_blocks(cx: &mut App) {
@@ -3358,17 +3873,4 @@ mod tests {
assert_eq!(range, expected, "unexpected result on row {:?}", row);
}
}
fn init_test(cx: &mut App) {
let settings_store = SettingsStore::test(cx);
prompt_store::init(cx);
LanguageModelRegistry::test(cx);
cx.set_global(settings_store);
language::init(cx);
assistant_settings::init(cx);
Project::init_settings(cx);
theme::init(theme::LoadThemes::JustBase, cx);
workspace::init_settings(cx);
editor::init_settings(cx);
}
}

View File

@@ -648,10 +648,7 @@ impl ContextStore {
if context.replica_id() == ReplicaId::default() {
Some(proto::ContextMetadata {
context_id: context.id().to_proto(),
summary: context
.summary()
.content()
.map(|summary| summary.text.clone()),
summary: context.summary().map(|summary| summary.text.clone()),
})
} else {
None

View File

@@ -0,0 +1,957 @@
use anyhow::{Context as _, Result, anyhow};
use collections::HashMap;
use editor::ProposedChangesEditor;
use futures::{TryFutureExt as _, future};
use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString};
use language::{AutoindentMode, Buffer, BufferSnapshot};
use project::{Project, ProjectPath};
use std::{cmp, ops::Range, path::Path, sync::Arc};
use text::{AnchorRangeExt as _, Bias, OffsetRangeExt as _, Point};
#[derive(Clone, Debug)]
pub struct AssistantPatch {
pub range: Range<language::Anchor>,
pub title: SharedString,
pub edits: Arc<[Result<AssistantEdit>]>,
pub status: AssistantPatchStatus,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum AssistantPatchStatus {
Pending,
Ready,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AssistantEdit {
pub path: String,
pub kind: AssistantEditKind,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AssistantEditKind {
Update {
old_text: String,
new_text: String,
description: Option<String>,
},
Create {
new_text: String,
description: Option<String>,
},
InsertBefore {
old_text: String,
new_text: String,
description: Option<String>,
},
InsertAfter {
old_text: String,
new_text: String,
description: Option<String>,
},
Delete {
old_text: String,
},
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedPatch {
pub edit_groups: HashMap<Entity<Buffer>, Vec<ResolvedEditGroup>>,
pub errors: Vec<AssistantPatchResolutionError>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedEditGroup {
pub context_range: Range<language::Anchor>,
pub edits: Vec<ResolvedEdit>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ResolvedEdit {
range: Range<language::Anchor>,
new_text: String,
description: Option<String>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AssistantPatchResolutionError {
pub edit_ix: usize,
pub message: String,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(rows: usize, cols: usize) -> Self {
SearchMatrix {
cols,
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
}
}
fn get(&self, row: usize, col: usize) -> SearchState {
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
self.data[row * self.cols + col] = cost;
}
}
impl ResolvedPatch {
pub fn apply(&self, editor: &ProposedChangesEditor, cx: &mut App) {
for (buffer, groups) in &self.edit_groups {
let branch = editor.branch_buffer_for_base(buffer).unwrap();
Self::apply_edit_groups(groups, &branch, cx);
}
editor.recalculate_all_buffer_diffs();
}
fn apply_edit_groups(groups: &Vec<ResolvedEditGroup>, buffer: &Entity<Buffer>, cx: &mut App) {
let mut edits = Vec::new();
for group in groups {
for suggestion in &group.edits {
edits.push((suggestion.range.clone(), suggestion.new_text.clone()));
}
}
buffer.update(cx, |buffer, cx| {
buffer.edit(
edits,
Some(AutoindentMode::Block {
original_indent_columns: Vec::new(),
}),
cx,
);
});
}
}
impl ResolvedEdit {
pub fn try_merge(&mut self, other: &Self, buffer: &text::BufferSnapshot) -> bool {
let range = &self.range;
let other_range = &other.range;
// Don't merge if we don't contain the other suggestion.
if range.start.cmp(&other_range.start, buffer).is_gt()
|| range.end.cmp(&other_range.end, buffer).is_lt()
{
return false;
}
let other_offset_range = other_range.to_offset(buffer);
let offset_range = range.to_offset(buffer);
// If the other range is empty at the start of this edit's range, combine the new text
if other_offset_range.is_empty() && other_offset_range.start == offset_range.start {
self.new_text = format!("{}\n{}", other.new_text, self.new_text);
self.range.start = other_range.start;
if let Some((description, other_description)) =
self.description.as_mut().zip(other.description.as_ref())
{
*description = format!("{}\n{}", other_description, description)
}
} else {
if let Some((description, other_description)) =
self.description.as_mut().zip(other.description.as_ref())
{
description.push('\n');
description.push_str(other_description);
}
}
true
}
}
impl AssistantEdit {
pub fn new(
path: Option<String>,
operation: Option<String>,
old_text: Option<String>,
new_text: Option<String>,
description: Option<String>,
) -> Result<Self> {
let path = path.ok_or_else(|| anyhow!("missing path"))?;
let operation = operation.ok_or_else(|| anyhow!("missing operation"))?;
let kind = match operation.as_str() {
"update" => AssistantEditKind::Update {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
description,
},
"insert_before" => AssistantEditKind::InsertBefore {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
description,
},
"insert_after" => AssistantEditKind::InsertAfter {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
description,
},
"delete" => AssistantEditKind::Delete {
old_text: old_text.ok_or_else(|| anyhow!("missing old_text"))?,
},
"create" => AssistantEditKind::Create {
description,
new_text: new_text.ok_or_else(|| anyhow!("missing new_text"))?,
},
_ => Err(anyhow!("unknown operation {operation:?}"))?,
};
Ok(Self { path, kind })
}
pub async fn resolve(
&self,
project: Entity<Project>,
mut cx: AsyncApp,
) -> Result<(Entity<Buffer>, ResolvedEdit)> {
let path = self.path.clone();
let kind = self.kind.clone();
let buffer = project
.update(&mut cx, |project, cx| {
let project_path = project
.find_project_path(Path::new(&path), cx)
.or_else(|| {
// If we couldn't find a project path for it, put it in the active worktree
// so that when we create the buffer, it can be saved.
let worktree = project
.active_entry()
.and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
.or_else(|| project.worktrees(cx).next())?;
let worktree = worktree.read(cx);
Some(ProjectPath {
worktree_id: worktree.id(),
path: Arc::from(Path::new(&path)),
})
})
.with_context(|| format!("worktree not found for {:?}", path))?;
anyhow::Ok(project.open_buffer(project_path, cx))
})??
.await?;
let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
let suggestion = cx
.background_spawn(async move { kind.resolve(&snapshot) })
.await;
Ok((buffer, suggestion))
}
}
impl AssistantEditKind {
fn resolve(self, snapshot: &BufferSnapshot) -> ResolvedEdit {
match self {
Self::Update {
old_text,
new_text,
description,
} => {
let range = Self::resolve_location(&snapshot, &old_text);
ResolvedEdit {
range,
new_text,
description,
}
}
Self::Create {
new_text,
description,
} => ResolvedEdit {
range: text::Anchor::MIN..text::Anchor::MAX,
description,
new_text,
},
Self::InsertBefore {
old_text,
mut new_text,
description,
} => {
let range = Self::resolve_location(&snapshot, &old_text);
new_text.push('\n');
ResolvedEdit {
range: range.start..range.start,
new_text,
description,
}
}
Self::InsertAfter {
old_text,
mut new_text,
description,
} => {
let range = Self::resolve_location(&snapshot, &old_text);
new_text.insert(0, '\n');
ResolvedEdit {
range: range.end..range.end,
new_text,
description,
}
}
Self::Delete { old_text } => {
let range = Self::resolve_location(&snapshot, &old_text);
ResolvedEdit {
range,
new_text: String::new(),
description: None,
}
}
}
}
fn resolve_location(buffer: &text::BufferSnapshot, search_query: &str) -> Range<text::Anchor> {
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
const WHITESPACE_INSERTION_COST: u32 = 1;
const WHITESPACE_DELETION_COST: u32 = 1;
let buffer_len = buffer.len();
let query_len = search_query.len();
let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
let mut leading_deletion_cost = 0_u32;
for (row, query_byte) in search_query.bytes().enumerate() {
let deletion_cost = if query_byte.is_ascii_whitespace() {
WHITESPACE_DELETION_COST
} else {
DELETION_COST
};
leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
);
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
WHITESPACE_INSERTION_COST
} else {
INSERTION_COST
};
let up = SearchState::new(
matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
SearchDirection::Up,
);
let left = SearchState::new(
matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_byte == *buffer_byte {
matrix.get(row, col).cost
} else {
matrix
.get(row, col)
.cost
.saturating_add(deletion_cost + insertion_cost)
},
SearchDirection::Diagonal,
);
matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
}
}
// Traceback to find the best match
let mut best_buffer_end = buffer_len;
let mut best_cost = u32::MAX;
for col in 1..=buffer_len {
let cost = matrix.get(query_len, col).cost;
if cost < best_cost {
best_cost = cost;
best_buffer_end = col;
}
}
let mut query_ix = query_len;
let mut buffer_ix = best_buffer_end;
while query_ix > 0 && buffer_ix > 0 {
let current = matrix.get(query_ix, buffer_ix);
match current.direction {
SearchDirection::Diagonal => {
query_ix -= 1;
buffer_ix -= 1;
}
SearchDirection::Up => {
query_ix -= 1;
}
SearchDirection::Left => {
buffer_ix -= 1;
}
}
}
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
start.column = 0;
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
if end.column > 0 {
end.column = buffer.line_len(end.row);
}
buffer.anchor_after(start)..buffer.anchor_before(end)
}
}
impl AssistantPatch {
pub async fn resolve(&self, project: Entity<Project>, cx: &mut AsyncApp) -> ResolvedPatch {
let mut resolve_tasks = Vec::new();
for (ix, edit) in self.edits.iter().enumerate() {
if let Ok(edit) = edit.as_ref() {
resolve_tasks.push(
edit.resolve(project.clone(), cx.clone())
.map_err(move |error| (ix, error)),
);
}
}
let edits = future::join_all(resolve_tasks).await;
let mut errors = Vec::new();
let mut edits_by_buffer = HashMap::default();
for entry in edits {
match entry {
Ok((buffer, edit)) => {
edits_by_buffer
.entry(buffer)
.or_insert_with(Vec::new)
.push(edit);
}
Err((edit_ix, error)) => errors.push(AssistantPatchResolutionError {
edit_ix,
message: error.to_string(),
}),
}
}
// Expand the context ranges of each edit and group edits with overlapping context ranges.
let mut edit_groups_by_buffer = HashMap::default();
for (buffer, edits) in edits_by_buffer {
if let Ok(snapshot) = buffer.update(cx, |buffer, _| buffer.text_snapshot()) {
edit_groups_by_buffer.insert(buffer, Self::group_edits(edits, &snapshot));
}
}
ResolvedPatch {
edit_groups: edit_groups_by_buffer,
errors,
}
}
fn group_edits(
mut edits: Vec<ResolvedEdit>,
snapshot: &text::BufferSnapshot,
) -> Vec<ResolvedEditGroup> {
let mut edit_groups = Vec::<ResolvedEditGroup>::new();
// Sort edits by their range so that earlier, larger ranges come first
edits.sort_by(|a, b| a.range.cmp(&b.range, &snapshot));
// Merge overlapping edits
edits.dedup_by(|a, b| b.try_merge(a, &snapshot));
// Create context ranges for each edit
for edit in edits {
let context_range = {
let edit_point_range = edit.range.to_point(&snapshot);
let start_row = edit_point_range.start.row.saturating_sub(5);
let end_row = cmp::min(edit_point_range.end.row + 5, snapshot.max_point().row);
let start = snapshot.anchor_before(Point::new(start_row, 0));
let end = snapshot.anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
start..end
};
if let Some(last_group) = edit_groups.last_mut() {
if last_group
.context_range
.end
.cmp(&context_range.start, &snapshot)
.is_ge()
{
// Merge with the previous group if context ranges overlap
last_group.context_range.end = context_range.end;
last_group.edits.push(edit);
} else {
// Create a new group
edit_groups.push(ResolvedEditGroup {
context_range,
edits: vec![edit],
});
}
} else {
// Create the first group
edit_groups.push(ResolvedEditGroup {
context_range,
edits: vec![edit],
});
}
}
edit_groups
}
pub fn path_count(&self) -> usize {
self.paths().count()
}
pub fn paths(&self) -> impl '_ + Iterator<Item = &str> {
let mut prev_path = None;
self.edits.iter().filter_map(move |edit| {
if let Ok(edit) = edit {
let path = Some(edit.path.as_str());
if path != prev_path {
prev_path = path;
return path;
}
}
None
})
}
}
impl PartialEq for AssistantPatch {
fn eq(&self, other: &Self) -> bool {
self.range == other.range
&& self.title == other.title
&& Arc::ptr_eq(&self.edits, &other.edits)
}
}
impl Eq for AssistantPatch {}
#[cfg(test)]
mod tests {
use super::*;
use gpui::App;
use language::{
Language, LanguageConfig, LanguageMatcher, language_settings::AllLanguageSettings,
};
use settings::SettingsStore;
use ui::BorrowAppContext;
use unindent::Unindent as _;
use util::test::{generate_marked_text, marked_text_ranges};
#[gpui::test]
fn test_resolve_location(cx: &mut App) {
assert_location_resolution(
concat!(
" Lorem\n",
"« ipsum\n",
" dolor sit amet»\n",
" consecteur",
),
"ipsum\ndolor",
cx,
);
assert_location_resolution(
&"
«fn foo1(a: usize) -> usize {
40
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"fn foo1(b: usize) {\n40\n}",
cx,
);
assert_location_resolution(
&"
fn main() {
« Foo
.bar()
.baz()
.qux()»
}
fn foo2(b: usize) -> usize {
42
}
"
.unindent(),
"Foo.bar.baz.qux()",
cx,
);
assert_location_resolution(
&"
class Something {
one() { return 1; }
« two() { return 2222; }
three() { return 333; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
» seven() { return 7; }
eight() { return 8; }
}
"
.unindent(),
&"
two() { return 2222; }
four() { return 4444; }
five() { return 5555; }
six() { return 6666; }
"
.unindent(),
cx,
);
}
#[gpui::test]
fn test_resolve_edits(cx: &mut App) {
init_test(cx);
assert_edits(
"
/// A person
struct Person {
name: String,
age: usize,
}
/// A dog
struct Dog {
weight: f32,
}
impl Person {
fn name(&self) -> &str {
&self.name
}
}
"
.unindent(),
vec![
AssistantEditKind::Update {
old_text: "
name: String,
"
.unindent(),
new_text: "
first_name: String,
last_name: String,
"
.unindent(),
description: None,
},
AssistantEditKind::Update {
old_text: "
fn name(&self) -> &str {
&self.name
}
"
.unindent(),
new_text: "
fn name(&self) -> String {
format!(\"{} {}\", self.first_name, self.last_name)
}
"
.unindent(),
description: None,
},
],
"
/// A person
struct Person {
first_name: String,
last_name: String,
age: usize,
}
/// A dog
struct Dog {
weight: f32,
}
impl Person {
fn name(&self) -> String {
format!(\"{} {}\", self.first_name, self.last_name)
}
}
"
.unindent(),
cx,
);
// Ensure InsertBefore merges correctly with Update of the same text
assert_edits(
"
fn foo() {
}
"
.unindent(),
vec![
AssistantEditKind::InsertBefore {
old_text: "
fn foo() {"
.unindent(),
new_text: "
fn bar() {
qux();
}"
.unindent(),
description: Some("implement bar".into()),
},
AssistantEditKind::Update {
old_text: "
fn foo() {
}"
.unindent(),
new_text: "
fn foo() {
bar();
}"
.unindent(),
description: Some("call bar in foo".into()),
},
AssistantEditKind::InsertAfter {
old_text: "
fn foo() {
}
"
.unindent(),
new_text: "
fn qux() {
// todo
}
"
.unindent(),
description: Some("implement qux".into()),
},
],
"
fn bar() {
qux();
}
fn foo() {
bar();
}
fn qux() {
// todo
}
"
.unindent(),
cx,
);
// Correctly indent new text when replacing multiple adjacent indented blocks.
assert_edits(
"
impl Numbers {
fn one() {
1
}
fn two() {
2
}
fn three() {
3
}
}
"
.unindent(),
vec![
AssistantEditKind::Update {
old_text: "
fn one() {
1
}
"
.unindent(),
new_text: "
fn one() {
101
}
"
.unindent(),
description: None,
},
AssistantEditKind::Update {
old_text: "
fn two() {
2
}
"
.unindent(),
new_text: "
fn two() {
102
}
"
.unindent(),
description: None,
},
AssistantEditKind::Update {
old_text: "
fn three() {
3
}
"
.unindent(),
new_text: "
fn three() {
103
}
"
.unindent(),
description: None,
},
],
"
impl Numbers {
fn one() {
101
}
fn two() {
102
}
fn three() {
103
}
}
"
.unindent(),
cx,
);
assert_edits(
"
impl Person {
fn set_name(&mut self, name: String) {
self.name = name;
}
fn name(&self) -> String {
return self.name;
}
}
"
.unindent(),
vec![
AssistantEditKind::Update {
old_text: "self.name = name;".unindent(),
new_text: "self._name = name;".unindent(),
description: None,
},
AssistantEditKind::Update {
old_text: "return self.name;\n".unindent(),
new_text: "return self._name;\n".unindent(),
description: None,
},
],
"
impl Person {
fn set_name(&mut self, name: String) {
self._name = name;
}
fn name(&self) -> String {
return self._name;
}
}
"
.unindent(),
cx,
);
}
fn init_test(cx: &mut App) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
cx.update_global::<SettingsStore, _>(|settings, cx| {
settings.update_user_settings::<AllLanguageSettings>(cx, |_| {});
});
}
#[track_caller]
fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) {
let (text, _) = marked_text_ranges(text_with_expected_range, false);
let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
let snapshot = buffer.read(cx).snapshot();
let range = AssistantEditKind::resolve_location(&snapshot, query).to_offset(&snapshot);
let text_with_actual_range = generate_marked_text(&text, &[range], false);
pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
}
#[track_caller]
fn assert_edits(
old_text: String,
edits: Vec<AssistantEditKind>,
new_text: String,
cx: &mut App,
) {
let buffer =
cx.new(|cx| Buffer::local(old_text, cx).with_language(Arc::new(rust_lang()), cx));
let snapshot = buffer.read(cx).snapshot();
let resolved_edits = edits
.into_iter()
.map(|kind| kind.resolve(&snapshot))
.collect();
let edit_groups = AssistantPatch::group_edits(resolved_edits, &snapshot);
ResolvedPatch::apply_edit_groups(&edit_groups, &buffer, cx);
let actual_new_text = buffer.read(cx).text();
pretty_assertions::assert_eq!(actual_new_text, new_text);
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(language::tree_sitter_rust::LANGUAGE.into()),
)
.with_indents_query(
r#"
(call_expression) @indent
(field_expression) @indent
(_ "(" ")" @end) @indent
(_ "{" "}" @end) @indent
"#,
)
.unwrap()
}
}

View File

@@ -278,8 +278,8 @@ impl CompletionProvider for SlashCommandCompletionProvider {
buffer.anchor_after(Point::new(position.row, first_arg_start.start as u32));
let arguments = call
.arguments
.into_iter()
.filter_map(|argument| Some(line.get(argument)?.to_string()))
.iter()
.filter_map(|argument| Some(line.get(argument.clone())?.to_string()))
.collect::<Vec<_>>();
let argument_range = first_arg_start..buffer_position;
(

View File

@@ -41,7 +41,6 @@ pub enum NotifyWhenAgentWaiting {
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
#[schemars(deny_unknown_fields)]
pub enum AssistantProviderContentV1 {
#[serde(rename = "zed.dev")]
ZedDotDev { default_model: Option<CloudModel> },
@@ -86,6 +85,7 @@ pub struct AssistantSettings {
pub thread_summary_model: Option<LanguageModelSelection>,
pub inline_alternatives: Vec<LanguageModelSelection>,
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
pub default_profile: AgentProfileId,
pub profiles: IndexMap<AgentProfileId, AgentProfile>,
pub always_allow_tool_actions: bool,
@@ -94,7 +94,6 @@ pub struct AssistantSettings {
pub single_file_review: bool,
pub model_parameters: Vec<LanguageModelParameters>,
pub preferred_completion_mode: CompletionMode,
pub enable_feedback: bool,
}
impl AssistantSettings {
@@ -107,6 +106,10 @@ impl AssistantSettings {
.and_then(|m| m.temperature)
}
pub fn are_live_diffs_enabled(&self, _cx: &App) -> bool {
false
}
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
self.inline_assistant_model = Some(LanguageModelSelection {
provider: provider.into(),
@@ -254,6 +257,7 @@ impl AssistantSettingsContent {
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
profiles: None,
always_allow_tool_actions: None,
@@ -262,7 +266,6 @@ impl AssistantSettingsContent {
single_file_review: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
enable_feedback: None,
},
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
@@ -285,6 +288,7 @@ impl AssistantSettingsContent {
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
profiles: None,
always_allow_tool_actions: None,
@@ -293,7 +297,6 @@ impl AssistantSettingsContent {
single_file_review: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
enable_feedback: None,
},
None => AssistantSettingsContentV2::default(),
}
@@ -547,7 +550,6 @@ impl AssistantSettingsContent {
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
#[schemars(deny_unknown_fields)]
pub enum VersionedAssistantSettingsContent {
#[serde(rename = "1")]
V1(AssistantSettingsContentV1),
@@ -568,6 +570,7 @@ impl Default for VersionedAssistantSettingsContent {
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
default_profile: None,
profiles: None,
always_allow_tool_actions: None,
@@ -576,13 +579,11 @@ impl Default for VersionedAssistantSettingsContent {
single_file_review: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
enable_feedback: None,
})
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
#[schemars(deny_unknown_fields)]
pub struct AssistantSettingsContentV2 {
/// Whether the Assistant is enabled.
///
@@ -614,6 +615,10 @@ pub struct AssistantSettingsContentV2 {
thread_summary_model: Option<LanguageModelSelection>,
/// Additional models with which to generate alternatives when performing inline assists.
inline_alternatives: Option<Vec<LanguageModelSelection>>,
/// Enable experimental live diffs in the assistant panel.
///
/// Default: false
enable_experimental_live_diffs: Option<bool>,
/// The default profile to use in the Agent.
///
/// Default: write
@@ -651,10 +656,6 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: normal
preferred_completion_mode: Option<CompletionMode>,
/// Whether to show thumb buttons for feedback in the agent panel.
///
/// Default: true
enable_feedback: Option<bool>,
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)]
@@ -692,7 +693,7 @@ impl JsonSchema for LanguageModelProviderSetting {
schemars::schema::SchemaObject {
enum_values: Some(vec![
"anthropic".into(),
"amazon-bedrock".into(),
"bedrock".into(),
"google".into(),
"lmstudio".into(),
"ollama".into(),
@@ -745,7 +746,6 @@ pub struct ContextServerPresetContent {
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct AssistantSettingsContentV1 {
/// Whether the Assistant is enabled.
///
@@ -775,7 +775,6 @@ pub struct AssistantSettingsContentV1 {
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct LegacyAssistantSettingsContent {
/// Whether to show the assistant panel button in the status bar.
///
@@ -846,6 +845,10 @@ impl Settings for AssistantSettings {
.thread_summary_model
.or(settings.thread_summary_model.take());
merge(&mut settings.inline_alternatives, value.inline_alternatives);
merge(
&mut settings.enable_experimental_live_diffs,
value.enable_experimental_live_diffs,
);
merge(
&mut settings.always_allow_tool_actions,
value.always_allow_tool_actions,
@@ -861,7 +864,6 @@ impl Settings for AssistantSettings {
&mut settings.preferred_completion_mode,
value.preferred_completion_mode,
);
merge(&mut settings.enable_feedback, value.enable_feedback);
settings
.model_parameters
@@ -992,13 +994,13 @@ mod tests {
dock: None,
default_width: None,
default_height: None,
enable_experimental_live_diffs: None,
default_profile: None,
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
single_file_review: None,
enable_feedback: None,
model_parameters: Vec::new(),
preferred_completion_mode: None,
},

View File

@@ -49,37 +49,6 @@ impl ActionLog {
is_created: bool,
cx: &mut Context<Self>,
) -> &mut TrackedBuffer {
let status = if is_created {
if let Some(tracked) = self.tracked_buffers.remove(&buffer) {
match tracked.status {
TrackedBufferStatus::Created {
existing_file_content,
} => TrackedBufferStatus::Created {
existing_file_content,
},
TrackedBufferStatus::Modified | TrackedBufferStatus::Deleted => {
TrackedBufferStatus::Created {
existing_file_content: Some(tracked.diff_base),
}
}
}
} else if buffer
.read(cx)
.file()
.map_or(false, |file| file.disk_state().exists())
{
TrackedBufferStatus::Created {
existing_file_content: Some(buffer.read(cx).as_rope().clone()),
}
} else {
TrackedBufferStatus::Created {
existing_file_content: None,
}
}
} else {
TrackedBufferStatus::Modified
};
let tracked_buffer = self
.tracked_buffers
.entry(buffer.clone())
@@ -91,21 +60,36 @@ impl ActionLog {
let text_snapshot = buffer.read(cx).text_snapshot();
let diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
let (diff_update_tx, diff_update_rx) = mpsc::unbounded();
let diff_base;
let base_text;
let status;
let unreviewed_changes;
if is_created {
diff_base = Rope::default();
let existing_file_content = if buffer
.read(cx)
.file()
.map_or(false, |file| file.disk_state().exists())
{
Some(text_snapshot.as_rope().clone())
} else {
None
};
base_text = Rope::default();
status = TrackedBufferStatus::Created {
existing_file_content,
};
unreviewed_changes = Patch::new(vec![Edit {
old: 0..1,
new: 0..text_snapshot.max_point().row + 1,
}])
} else {
diff_base = buffer.read(cx).as_rope().clone();
base_text = buffer.read(cx).as_rope().clone();
status = TrackedBufferStatus::Modified;
unreviewed_changes = Patch::default();
}
TrackedBuffer {
buffer: buffer.clone(),
diff_base,
base_text,
unreviewed_changes,
snapshot: text_snapshot.clone(),
status,
@@ -200,7 +184,7 @@ impl ActionLog {
.context("buffer not tracked")?;
let rebase = cx.background_spawn({
let mut base_text = tracked_buffer.diff_base.clone();
let mut base_text = tracked_buffer.base_text.clone();
let old_snapshot = tracked_buffer.snapshot.clone();
let new_snapshot = buffer_snapshot.clone();
let unreviewed_changes = tracked_buffer.unreviewed_changes.clone();
@@ -226,7 +210,7 @@ impl ActionLog {
))
})??;
let (new_base_text, new_diff_base) = rebase.await;
let (new_base_text, new_base_text_rope) = rebase.await;
let diff_snapshot = BufferDiff::update_diff(
diff.clone(),
buffer_snapshot.clone(),
@@ -245,23 +229,24 @@ impl ActionLog {
.background_spawn({
let diff_snapshot = diff_snapshot.clone();
let buffer_snapshot = buffer_snapshot.clone();
let new_diff_base = new_diff_base.clone();
let new_base_text_rope = new_base_text_rope.clone();
async move {
let mut unreviewed_changes = Patch::default();
for hunk in diff_snapshot.hunks_intersecting_range(
Anchor::MIN..Anchor::MAX,
&buffer_snapshot,
) {
let old_range = new_diff_base
let old_range = new_base_text_rope
.offset_to_point(hunk.diff_base_byte_range.start)
..new_diff_base.offset_to_point(hunk.diff_base_byte_range.end);
..new_base_text_rope
.offset_to_point(hunk.diff_base_byte_range.end);
let new_range = hunk.range.start..hunk.range.end;
unreviewed_changes.push(point_to_row_edit(
Edit {
old: old_range,
new: new_range,
},
&new_diff_base,
&new_base_text_rope,
&buffer_snapshot.as_rope(),
));
}
@@ -279,7 +264,7 @@ impl ActionLog {
.tracked_buffers
.get_mut(&buffer)
.context("buffer not tracked")?;
tracked_buffer.diff_base = new_diff_base;
tracked_buffer.base_text = new_base_text_rope;
tracked_buffer.snapshot = buffer_snapshot;
tracked_buffer.unreviewed_changes = unreviewed_changes;
cx.notify();
@@ -298,6 +283,7 @@ impl ActionLog {
/// Mark a buffer as edited, so we can refresh it in the context
pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.edited_since_project_diagnostics_check = true;
self.tracked_buffers.remove(&buffer);
self.track_buffer_internal(buffer.clone(), true, cx);
}
@@ -360,11 +346,11 @@ impl ActionLog {
true
} else {
let old_range = tracked_buffer
.diff_base
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.diff_base.point_to_offset(cmp::min(
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.diff_base.max_point(),
tracked_buffer.base_text.max_point(),
));
let new_range = tracked_buffer
.snapshot
@@ -373,7 +359,7 @@ impl ActionLog {
Point::new(edit.new.end, 0),
tracked_buffer.snapshot.max_point(),
));
tracked_buffer.diff_base.replace(
tracked_buffer.base_text.replace(
old_range,
&tracked_buffer
.snapshot
@@ -431,7 +417,7 @@ impl ActionLog {
}
TrackedBufferStatus::Deleted => {
buffer.update(cx, |buffer, cx| {
buffer.set_text(tracked_buffer.diff_base.to_string(), cx)
buffer.set_text(tracked_buffer.base_text.to_string(), cx)
});
let save = self
.project
@@ -478,14 +464,14 @@ impl ActionLog {
if revert {
let old_range = tracked_buffer
.diff_base
.base_text
.point_to_offset(Point::new(edit.old.start, 0))
..tracked_buffer.diff_base.point_to_offset(cmp::min(
..tracked_buffer.base_text.point_to_offset(cmp::min(
Point::new(edit.old.end, 0),
tracked_buffer.diff_base.max_point(),
tracked_buffer.base_text.max_point(),
));
let old_text = tracked_buffer
.diff_base
.base_text
.chunks_in_range(old_range)
.collect::<String>();
edits_to_revert.push((new_range, old_text));
@@ -506,7 +492,7 @@ impl ActionLog {
TrackedBufferStatus::Deleted => false,
_ => {
tracked_buffer.unreviewed_changes.clear();
tracked_buffer.diff_base = tracked_buffer.snapshot.as_rope().clone();
tracked_buffer.base_text = tracked_buffer.snapshot.as_rope().clone();
tracked_buffer.schedule_diff_update(ChangeAuthor::User, cx);
true
}
@@ -669,7 +655,7 @@ enum TrackedBufferStatus {
struct TrackedBuffer {
buffer: Entity<Buffer>,
diff_base: Rope,
base_text: Rope,
unreviewed_changes: Patch<u32>,
status: TrackedBufferStatus,
version: clock::Global,
@@ -1108,86 +1094,6 @@ mod tests {
);
}
#[gpui::test(iterations = 10)]
async fn test_overwriting_previously_edited_files(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/dir"),
json!({
"file1": "Lorem ipsum dolor"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.append(" sit amet consecteur", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
cx.run_until_parked();
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![HunkStatus {
range: Point::new(0, 0)..Point::new(0, 37),
diff_status: DiffHunkStatusKind::Modified,
old_text: "Lorem ipsum dolor".into(),
}],
)]
);
cx.update(|cx| {
action_log.update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.set_text("rewritten", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.await
.unwrap();
cx.run_until_parked();
assert_eq!(
unreviewed_hunks(&action_log, cx),
vec![(
buffer.clone(),
vec![HunkStatus {
range: Point::new(0, 0)..Point::new(0, 9),
diff_status: DiffHunkStatusKind::Added,
old_text: "".into(),
}],
)]
);
action_log
.update(cx, |log, cx| {
log.reject_edits_in_ranges(buffer.clone(), vec![2..5], cx)
})
.await
.unwrap();
cx.run_until_parked();
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
assert_eq!(
buffer.read_with(cx, |buffer, _cx| buffer.text()),
"Lorem ipsum dolor"
);
}
#[gpui::test(iterations = 10)]
async fn test_deleting_files(cx: &mut TestAppContext) {
init_test(cx);
@@ -1695,7 +1601,7 @@ mod tests {
cx.run_until_parked();
action_log.update(cx, |log, cx| {
let tracked_buffer = log.tracked_buffers.get(&buffer).unwrap();
let mut old_text = tracked_buffer.diff_base.clone();
let mut old_text = tracked_buffer.base_text.clone();
let new_text = buffer.read(cx).as_rope();
for edit in tracked_buffer.unreviewed_changes.edits() {
let old_start = old_text.point_to_offset(Point::new(edit.new.start, 0));

View File

@@ -19,7 +19,6 @@ use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use icons::IconName;
use language_model::LanguageModel;
use language_model::LanguageModelImage;
use language_model::LanguageModelRequest;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
@@ -66,50 +65,21 @@ impl ToolUseStatus {
#[derive(Debug)]
pub struct ToolResultOutput {
pub content: ToolResultContent,
pub content: String,
pub output: Option<serde_json::Value>,
}
#[derive(Debug, PartialEq, Eq)]
pub enum ToolResultContent {
Text(String),
Image(LanguageModelImage),
}
impl ToolResultContent {
pub fn len(&self) -> usize {
match self {
ToolResultContent::Text(str) => str.len(),
ToolResultContent::Image(image) => image.len(),
}
}
pub fn is_empty(&self) -> bool {
match self {
ToolResultContent::Text(str) => str.is_empty(),
ToolResultContent::Image(image) => image.is_empty(),
}
}
pub fn as_str(&self) -> Option<&str> {
match self {
ToolResultContent::Text(str) => Some(str),
ToolResultContent::Image(_) => None,
}
}
}
impl From<String> for ToolResultOutput {
fn from(value: String) -> Self {
ToolResultOutput {
content: ToolResultContent::Text(value),
content: value,
output: None,
}
}
}
impl Deref for ToolResultOutput {
type Target = ToolResultContent;
type Target = String;
fn deref(&self) -> &Self::Target {
&self.content

View File

@@ -17,14 +17,14 @@ eval = []
[dependencies]
aho-corasick.workspace = true
anyhow.workspace = true
assistant_settings.workspace = true
assistant_tool.workspace = true
assistant_settings.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
derive_more.workspace = true
editor.workspace = true
derive_more.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -36,7 +36,7 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
markdown.workspace = true
linkme.workspace = true
open.workspace = true
paths.workspace = true
portable-pty.workspace = true

View File

@@ -40,12 +40,13 @@ use crate::find_path_tool::FindPathTool;
use crate::grep_tool::GrepTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::read_file_tool::ReadFileTool;
use crate::thinking_tool::ThinkingTool;
pub use edit_file_tool::{EditFileMode, EditFileToolInput};
pub use edit_file_tool::EditFileToolInput;
pub use find_path_tool::FindPathToolInput;
pub use open_tool::OpenTool;
pub use read_file_tool::{ReadFileTool, ReadFileToolInput};
pub use read_file_tool::ReadFileToolInput;
pub use terminal_tool::TerminalTool;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {

View File

@@ -20,8 +20,7 @@ use language_model::{
LanguageModelToolChoice, MessageContent, Role,
};
use project::{AgentLocation, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde::Serialize;
use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
use streaming_diff::{CharOperation, StreamingDiff};
@@ -51,10 +50,10 @@ pub enum EditAgentOutputEvent {
OldTextNotFound(SharedString),
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Clone, Debug)]
pub struct EditAgentOutput {
pub raw_edits: String,
pub parser_metrics: EditParserMetrics,
pub _raw_edits: String,
pub _parser_metrics: EditParserMetrics,
}
#[derive(Clone)]
@@ -187,8 +186,8 @@ impl EditAgent {
}
Ok(EditAgentOutput {
raw_edits,
parser_metrics: EditParserMetrics::default(),
_raw_edits: raw_edits,
_parser_metrics: EditParserMetrics::default(),
})
}
@@ -427,8 +426,8 @@ impl EditAgent {
}
}
Ok(EditAgentOutput {
raw_edits,
parser_metrics: parser.finish(),
_raw_edits: raw_edits,
_parser_metrics: parser.finish(),
})
});
(output, rx)

View File

@@ -1,6 +1,4 @@
use derive_more::{Add, AddAssign};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::{cmp, mem, ops::Range};
@@ -15,9 +13,7 @@ pub enum EditParserEvent {
NewTextChunk { chunk: String, done: bool },
}
#[derive(
Clone, Debug, Default, PartialEq, Eq, Add, AddAssign, Serialize, Deserialize, JsonSchema,
)]
#[derive(Clone, Debug, Default, PartialEq, Eq, Add, AddAssign)]
pub struct EditParserMetrics {
pub tags: usize,
pub mismatched_tags: usize,

View File

@@ -1,9 +1,5 @@
use super::*;
use crate::{
ReadFileToolInput,
edit_file_tool::{EditFileMode, EditFileToolInput},
grep_tool::GrepToolInput,
};
use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
use Role::*;
use anyhow::anyhow;
use assistant_tool::ToolRegistry;
@@ -14,8 +10,8 @@ use futures::{FutureExt, future::LocalBoxFuture};
use gpui::{AppContext, TestAppContext};
use indoc::{formatdoc, indoc};
use language_model::{
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, SelectedModel,
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
@@ -25,7 +21,6 @@ use std::{
cmp::Reverse,
fmt::{self, Display},
io::Write as _,
str::FromStr,
sync::mpsc,
};
use util::path;
@@ -76,7 +71,7 @@ fn eval_extract_handle_command_output() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
},
)],
),
@@ -132,7 +127,7 @@ fn eval_delete_run_git_blame() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
},
)],
),
@@ -187,7 +182,7 @@ fn eval_translate_doc_comments() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
},
)],
),
@@ -302,7 +297,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
},
)],
),
@@ -377,7 +372,7 @@ fn eval_disable_cursor_blinking() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
},
)],
),
@@ -571,7 +566,7 @@ fn eval_from_pixels_constructor() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
},
)],
),
@@ -648,7 +643,7 @@ fn eval_zode() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Create,
create_or_overwrite: true,
},
),
],
@@ -893,7 +888,7 @@ fn eval_add_overwrite_test() {
EditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
},
),
],
@@ -956,7 +951,7 @@ fn tool_result(
tool_use_id: LanguageModelToolUseId::from(id.into()),
tool_name: name.into(),
is_error: false,
content: LanguageModelToolResultContent::Text(result.into()),
content: result.into(),
output: None,
})
}
@@ -1121,7 +1116,7 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.sample.edit_output.parser_metrics.clone();
cumulative_parser_metrics += output.sample.edit_output._parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
@@ -1202,9 +1197,9 @@ impl Display for EvalOutput {
writeln!(
f,
"Parser Metrics:\n{:#?}",
self.sample.edit_output.parser_metrics
self.sample.edit_output._parser_metrics
)?;
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output.raw_edits)?;
writeln!(f, "Raw Edits:\n{}", self.sample.edit_output._raw_edits)?;
Ok(())
}
}
@@ -1217,7 +1212,7 @@ fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usiz
passed_count as f64 / evaluated_count as f64
};
print!(
"\r\x1b[KEvaluated {}/{} ({:.2}% passed)",
"\r\x1b[KEvaluated {}/{} ({:.2}%)",
evaluated_count,
iterations,
passed_ratio * 100.0
@@ -1256,21 +1251,13 @@ impl EditAgentTest {
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let agent_model = SelectedModel::from_str(
&std::env::var("ZED_AGENT_MODEL")
.unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
)
.unwrap();
let judge_model = SelectedModel::from_str(
&std::env::var("ZED_JUDGE_MODEL")
.unwrap_or("anthropic/claude-3-7-sonnet-latest".into()),
)
.unwrap();
let (agent_model, judge_model) = cx
.update(|cx| {
cx.spawn(async move |cx| {
let agent_model = Self::load_model(&agent_model, cx).await;
let judge_model = Self::load_model(&judge_model, cx).await;
let agent_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
let judge_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
(agent_model.unwrap(), judge_model.unwrap())
})
})
@@ -1285,17 +1272,15 @@ impl EditAgentTest {
}
async fn load_model(
selected_model: &SelectedModel,
provider: &str,
id: &str,
cx: &mut AsyncApp,
) -> Result<Arc<dyn LanguageModel>> {
let (provider, model) = cx.update(|cx| {
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
.find(|model| {
model.provider_id() == selected_model.provider
&& model.id() == selected_model.model
})
.find(|model| model.provider_id().0 == provider && model.id().0 == id)
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
(provider, model)

View File

@@ -1249,7 +1249,7 @@ pub struct ActiveDiagnosticGroup {
}
#[derive(Debug, PartialEq, Eq)]
#[allow(clippy::large_enum_variant)]
pub(crate) enum ActiveDiagnostic {
None,
All,

View File

@@ -1,12 +1,11 @@
use crate::{
Templates,
edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent},
edit_agent::{EditAgent, EditAgentOutputEvent},
schema::json_schema_for,
};
use anyhow::{Result, anyhow};
use assistant_tool::{
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput,
ToolUseStatus,
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus,
};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
@@ -21,7 +20,6 @@ use language::{
language_settings::SoftWrap,
};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -76,22 +74,12 @@ pub struct EditFileToolInput {
/// </example>
pub path: PathBuf,
/// The mode of operation on the file. Possible values:
/// - 'edit': Make granular edits to an existing file.
/// - 'create': Create a new file if it doesn't exist.
/// - 'overwrite': Replace the entire contents of an existing file.
/// If true, this tool will recreate the file from scratch.
/// If false, this tool will produce granular edits to an existing file.
///
/// When a file already exists or you just created it, prefer editing
/// When a file already exists or you just created it, always prefer editing
/// it as opposed to recreating it from scratch.
pub mode: EditFileMode,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum EditFileMode {
Edit,
Create,
Overwrite,
pub create_or_overwrite: bool,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -99,7 +87,6 @@ pub struct EditFileToolOutput {
pub original_path: PathBuf,
pub new_text: String,
pub old_text: String,
pub raw_output: Option<EditAgentOutput>,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -205,11 +192,7 @@ impl Tool for EditFileTool {
.as_ref()
.map_or(false, |file| file.disk_state().exists())
})?;
let create_or_overwrite = match input.mode {
EditFileMode::Create | EditFileMode::Overwrite => true,
_ => false,
};
if !create_or_overwrite && !exists {
if !input.create_or_overwrite && !exists {
return Err(anyhow!("{} not found", input.path.display()));
}
@@ -221,7 +204,7 @@ impl Tool for EditFileTool {
})
.await;
let (output, mut events) = if create_or_overwrite {
let (output, mut events) = if input.create_or_overwrite {
edit_agent.overwrite(
buffer.clone(),
input.display_description.clone(),
@@ -264,7 +247,7 @@ impl Tool for EditFileTool {
EditAgentOutputEvent::OldTextNotFound(_) => hallucinated_old_text = true,
}
}
let agent_output = output.await?;
output.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
@@ -284,7 +267,6 @@ impl Tool for EditFileTool {
original_path: project_path.path.to_path_buf(),
new_text: new_text.clone(),
old_text: old_text.clone(),
raw_output: Some(agent_output),
};
if let Some(card) = card_clone {
@@ -307,10 +289,7 @@ impl Tool for EditFileTool {
}
} else {
Ok(ToolResultOutput {
content: ToolResultContent::Text(format!(
"Edited {}:\n\n```diff\n{}\n```",
input_path, diff
)),
content: format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff),
output: serde_json::to_value(output).ok(),
})
}
@@ -356,7 +335,7 @@ pub struct EditFileToolCard {
project: Entity<Project>,
diff_task: Option<Task<Result<()>>>,
preview_expanded: bool,
error_expanded: Option<Entity<Markdown>>,
error_expanded: bool,
full_height_expanded: bool,
total_lines: Option<u32>,
editor_unique_id: EntityId,
@@ -380,7 +359,7 @@ impl EditFileToolCard {
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.disable_scrollbars_and_minimap(window, cx);
editor.disable_scrollbars_and_minimap(cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
@@ -399,7 +378,7 @@ impl EditFileToolCard {
multibuffer,
diff_task: None,
preview_expanded: true,
error_expanded: None,
error_expanded: false,
full_height_expanded: false,
total_lines: None,
}
@@ -456,9 +435,9 @@ impl ToolCard for EditFileToolCard {
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let error_message = match status {
ToolUseStatus::Error(err) => Some(err),
_ => None,
let (failed, error_message) = match status {
ToolUseStatus::Error(err) => (true, Some(err.to_string())),
_ => (false, None),
};
let path_label_button = h_flex()
@@ -546,11 +525,9 @@ impl ToolCard for EditFileToolCard {
.gap_1()
.justify_between()
.rounded_t_md()
.when(error_message.is_none(), |header| {
header.bg(codeblock_header_bg)
})
.when(!failed, |header| header.bg(codeblock_header_bg))
.child(path_label_button)
.when_some(error_message, |header, error_message| {
.when(failed, |header| {
header.child(
h_flex()
.gap_1()
@@ -562,28 +539,19 @@ impl ToolCard for EditFileToolCard {
.child(
Disclosure::new(
("edit-file-error-disclosure", self.editor_unique_id),
self.error_expanded.is_some(),
self.error_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener({
let error_message = error_message.clone();
move |this, _event, _window, cx| {
if this.error_expanded.is_some() {
this.error_expanded.take();
} else {
this.error_expanded = Some(cx.new(|cx| {
Markdown::new(error_message.clone(), None, None, cx)
}))
}
cx.notify();
}
})),
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.error_expanded = !this.error_expanded;
},
)),
),
)
})
.when(error_message.is_none() && self.has_diff(), |header| {
.when(!failed && self.has_diff(), |header| {
header.child(
Disclosure::new(
("edit-file-disclosure", self.editor_unique_id),
@@ -655,7 +623,7 @@ impl ToolCard for EditFileToolCard {
.p_3()
.gap_1()
.border_t_1()
.rounded_b_md()
.rounded_md()
.border_color(border_color)
.bg(cx.theme().colors().editor_background);
@@ -690,12 +658,12 @@ impl ToolCard for EditFileToolCard {
v_flex()
.mb_2()
.border_1()
.when(error_message.is_some(), |card| card.border_dashed())
.when(failed, |card| card.border_dashed())
.border_color(border_color)
.rounded_md()
.overflow_hidden()
.child(codeblock_header)
.when_some(self.error_expanded.as_ref(), |card, error_markdown| {
.when(failed && self.error_expanded, |card| {
card.child(
v_flex()
.p_2()
@@ -715,81 +683,65 @@ impl ToolCard for EditFileToolCard {
.rounded_md()
.text_ui_sm(cx)
.bg(cx.theme().colors().editor_background)
.child(MarkdownElement::new(
error_markdown.clone(),
markdown_style(window, cx),
)),
.children(
error_message
.map(|error| div().child(error).into_any_element()),
),
),
)
})
.when(!self.has_diff() && error_message.is_none(), |card| {
.when(!self.has_diff() && !failed, |card| {
card.child(waiting_for_diff)
})
.when(self.preview_expanded && self.has_diff(), |card| {
card.child(
v_flex()
.relative()
.h_full()
.when(!self.full_height_expanded, |editor_container| {
editor_container
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
})
.overflow_hidden()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.child(editor)
.when(
!self.full_height_expanded && is_collapsible,
|editor_container| editor_container.child(gradient_overlay),
),
)
.when(is_collapsible, |card| {
.when(
!failed && self.preview_expanded && self.has_diff(),
|card| {
card.child(
h_flex()
.id(("expand-button", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
v_flex()
.relative()
.h_full()
.when(!self.full_height_expanded, |editor_container| {
editor_container
.max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
})
.overflow_hidden()
.border_t_1()
.rounded_b_md()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
.child(editor)
.when(
!self.full_height_expanded && is_collapsible,
|editor_container| editor_container.child(gradient_overlay),
),
)
})
})
}
}
fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let ui_font_size = TextSize::Default.rems(cx);
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_fallbacks: theme_settings.ui_font.fallbacks.clone(),
font_features: Some(theme_settings.ui_font.features.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
MarkdownStyle {
base_text_style: text_style.clone(),
selection_background_color: cx.theme().players().local().selection,
..Default::default()
.when(is_collapsible, |card| {
card.child(
h_flex()
.id(("expand-button", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
.border_t_1()
.rounded_b_md()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| {
style.bg(cx.theme().colors().element_hover.opacity(0.1))
})
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
)
})
},
)
}
}
@@ -890,7 +842,7 @@ mod tests {
let input = serde_json::to_value(EditFileToolInput {
display_description: "Some edit".into(),
path: "root/nonexistent_file.txt".into(),
mode: EditFileMode::Edit,
create_or_overwrite: false,
})
.unwrap();
Arc::new(EditFileTool)

View File

@@ -1,8 +1,6 @@
use crate::{schema::json_schema_for, ui::ToolCallCardHeader};
use anyhow::{Result, anyhow};
use assistant_tool::{
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use editor::Editor;
use futures::channel::oneshot::{self, Receiver};
use gpui::{
@@ -40,12 +38,6 @@ pub struct FindPathToolInput {
pub offset: usize,
}
#[derive(Debug, Serialize, Deserialize)]
struct FindPathToolOutput {
glob: String,
paths: Vec<PathBuf>,
}
const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool;
@@ -119,18 +111,10 @@ impl Tool for FindPathTool {
)
.unwrap();
}
let output = FindPathToolOutput {
glob,
paths: matches.clone(),
};
for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) {
write!(&mut message, "\n{}", mat.display()).unwrap();
}
Ok(ToolResultOutput {
content: ToolResultContent::Text(message),
output: Some(serde_json::to_value(output)?),
})
Ok(message.into())
}
});
@@ -139,18 +123,6 @@ impl Tool for FindPathTool {
card: Some(card.into()),
}
}
fn deserialize_card(
self: Arc<Self>,
output: serde_json::Value,
_project: Entity<Project>,
_window: &mut Window,
cx: &mut App,
) -> Option<assistant_tool::AnyToolCard> {
let output = serde_json::from_value::<FindPathToolOutput>(output).ok()?;
let card = cx.new(|_| FindPathToolCard::from_output(output));
Some(card.into())
}
}
fn search_paths(glob: &str, project: Entity<Project>, cx: &mut App) -> Task<Result<Vec<PathBuf>>> {
@@ -208,15 +180,6 @@ impl FindPathToolCard {
_receiver_task: Some(_receiver_task),
}
}
fn from_output(output: FindPathToolOutput) -> Self {
Self {
glob: output.glob,
paths: output.paths,
expanded: false,
_receiver_task: None,
}
}
}
impl ToolCard for FindPathToolCard {

View File

@@ -752,9 +752,9 @@ mod tests {
match task.output.await {
Ok(result) => {
if cfg!(windows) {
result.content.as_str().unwrap().replace("root\\", "root/")
result.content.replace("root\\", "root/")
} else {
result.content.as_str().unwrap().to_string()
result.content
}
}
Err(e) => panic!("Failed to run grep tool: {}", e),

View File

@@ -1,17 +1,13 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::outline;
use assistant_tool::{ActionLog, Tool, ToolResult};
use assistant_tool::{ToolResultContent, outline};
use gpui::{AnyWindowHandle, App, Entity, Task};
use project::{ImageItem, image_store};
use assistant_tool::ToolResultOutput;
use indoc::formatdoc;
use itertools::Itertools;
use language::{Anchor, Point};
use language_model::{
LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat,
};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::{AgentLocation, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -90,7 +86,7 @@ impl Tool for ReadFileTool {
_request: Arc<LanguageModelRequest>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
model: Arc<dyn LanguageModel>,
_model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
@@ -104,42 +100,6 @@ impl Tool for ReadFileTool {
};
let file_path = input.path.clone();
if image_store::is_image_file(&project, &project_path, cx) {
if !model.supports_images() {
return Task::ready(Err(anyhow!(
"Attempted to read an image, but Zed doesn't currently sending images to {}.",
model.name().0
)))
.into();
}
let task = cx.spawn(async move |cx| -> Result<ToolResultOutput> {
let image_entity: Entity<ImageItem> = cx
.update(|cx| {
project.update(cx, |project, cx| {
project.open_image(project_path.clone(), cx)
})
})?
.await?;
let image =
image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?;
let language_model_image = cx
.update(|cx| LanguageModelImage::from_image(image, cx))?
.await
.ok_or_else(|| anyhow!("Failed to process image"))?;
Ok(ToolResultOutput {
content: ToolResultContent::Image(language_model_image),
output: None,
})
});
return task.into();
}
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
@@ -322,10 +282,7 @@ mod test {
.output
})
.await;
assert_eq!(
result.unwrap().content.as_str(),
Some("This is a small file content")
);
assert_eq!(result.unwrap().content, "This is a small file content");
}
#[gpui::test]
@@ -365,7 +322,6 @@ mod test {
})
.await;
let content = result.unwrap();
let content = content.as_str().unwrap();
assert_eq!(
content.lines().skip(4).take(6).collect::<Vec<_>>(),
vec![
@@ -409,8 +365,6 @@ mod test {
.collect::<Vec<_>>();
pretty_assertions::assert_eq!(
content
.as_str()
.unwrap()
.lines()
.skip(4)
.take(expected_content.len())
@@ -454,10 +408,7 @@ mod test {
.output
})
.await;
assert_eq!(
result.unwrap().content.as_str(),
Some("Line 2\nLine 3\nLine 4")
);
assert_eq!(result.unwrap().content, "Line 2\nLine 3\nLine 4");
}
#[gpui::test]
@@ -497,7 +448,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2"));
assert_eq!(result.unwrap().content, "Line 1\nLine 2");
// end_line of 0 should result in at least 1 line
let result = cx
@@ -520,7 +471,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap().content.as_str(), Some("Line 1"));
assert_eq!(result.unwrap().content, "Line 1");
// when start_line > end_line, should still return at least 1 line
let result = cx
@@ -543,7 +494,7 @@ mod test {
.output
})
.await;
assert_eq!(result.unwrap().content.as_str(), Some("Line 3"));
assert_eq!(result.unwrap().content, "Line 3");
}
fn init_test(cx: &mut TestAppContext) {

View File

@@ -1,19 +1,14 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{FutureExt as _, future::Shared};
use gpui::{
AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, TextStyleRefinement,
WeakEntity, Window,
};
use gpui::{AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, WeakEntity, Window};
use language::LineEnding;
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
use project::{Project, terminals::TerminalKind};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
env,
path::{Path, PathBuf},
@@ -22,7 +17,6 @@ use std::{
time::{Duration, Instant},
};
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::{Disclosure, Tooltip, prelude::*};
use util::{
get_system_shell, markdown::MarkdownInlineCode, size::format_file_size,
@@ -125,24 +119,14 @@ impl Tool for TerminalTool {
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let working_dir = match working_dir(&input, &project, cx) {
let input_path = Path::new(&input.cd);
let working_dir = match working_dir(&input, &project, input_path, cx) {
Ok(dir) => dir,
Err(err) => return Task::ready(Err(err)).into(),
};
let program = self.determine_shell.clone();
let command = if cfg!(windows) {
format!("$null | & {{{}}}", input.command.replace("\"", "'"))
} else if let Some(cwd) = working_dir
.as_ref()
.and_then(|cwd| cwd.as_os_str().to_str())
{
// Make sure once we're *inside* the shell, we cd into `cwd`
format!("(cd {cwd}; {}) </dev/null", input.command)
} else {
format!("({}) </dev/null", input.command)
};
let args = vec!["-c".into(), command];
let command = format!("({}) </dev/null", input.command);
let args = vec!["-c".into(), command.clone()];
let cwd = working_dir.clone();
let env = match &working_dir {
Some(dir) => project.update(cx, |project, cx| {
@@ -227,21 +211,8 @@ impl Tool for TerminalTool {
}
});
let command_markdown = cx.new(|cx| {
Markdown::new(
format!("```bash\n{}\n```", input.command).into(),
None,
None,
cx,
)
});
let card = cx.new(|cx| {
TerminalToolCard::new(
command_markdown.clone(),
working_dir.clone(),
cx.entity_id(),
)
TerminalToolCard::new(input.command.clone(), working_dir.clone(), cx.entity_id())
});
let output = cx.spawn({
@@ -325,13 +296,19 @@ fn process_content(
} else {
content
};
let content = content.trim();
let is_empty = content.is_empty();
let content = format!("```\n{content}\n```");
let is_empty = content.trim().is_empty();
let content = format!(
"```\n{}{}```",
content,
if content.ends_with('\n') { "" } else { "\n" }
);
let content = if should_truncate {
format!(
"Command output too long. The first {} bytes:\n\n{content}",
"Command output too long. The first {} bytes:\n\n{}",
content.len(),
content,
)
} else {
content
@@ -371,52 +348,47 @@ fn process_content(
fn working_dir(
input: &TerminalToolInput,
project: &Entity<Project>,
input_path: &Path,
cx: &mut App,
) -> Result<Option<PathBuf>> {
let project = project.read(cx);
let cd = &input.cd;
if cd == "." || cd == "" {
// Accept "." or "" as meaning "the one worktree" if we only have one worktree.
if input.cd == "." {
// Accept "." as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
match worktrees.next() {
Some(worktree) => {
if worktrees.next().is_none() {
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
} else {
Err(anyhow!(
if worktrees.next().is_some() {
bail!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.",
))
);
}
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
}
None => Ok(None),
}
} else {
let input_path = Path::new(cd);
if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
return Ok(Some(input_path.into()));
}
} else {
if let Some(worktree) = project.worktree_for_root_name(cd, cx) {
return Ok(Some(worktree.read(cx).abs_path().to_path_buf()));
}
} else if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
bail!("The absolute path must be within one of the project's worktrees");
}
Err(anyhow!(
"`cd` directory {cd:?} was not in any of the project's worktrees."
))
Ok(Some(input_path.into()))
} else {
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
bail!("`cd` directory {:?} not found in the project", input.cd);
};
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
}
}
struct TerminalToolCard {
input_command: Entity<Markdown>,
input_command: String,
working_dir: Option<PathBuf>,
entity_id: EntityId,
exit_status: Option<ExitStatus>,
@@ -432,11 +404,7 @@ struct TerminalToolCard {
}
impl TerminalToolCard {
pub fn new(
input_command: Entity<Markdown>,
working_dir: Option<PathBuf>,
entity_id: EntityId,
) -> Self {
pub fn new(input_command: String, working_dir: Option<PathBuf>, entity_id: EntityId) -> Self {
Self {
input_command,
working_dir,
@@ -459,7 +427,7 @@ impl ToolCard for TerminalToolCard {
fn render(
&mut self,
status: &ToolUseStatus,
window: &mut Window,
_window: &mut Window,
_workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
@@ -603,25 +571,11 @@ impl ToolCard for TerminalToolCard {
.rounded_lg()
.overflow_hidden()
.child(
v_flex()
.p_2()
.gap_0p5()
.bg(header_bg)
.text_xs()
.child(header)
.child(
MarkdownElement::new(
self.input_command.clone(),
markdown_style(window, cx),
)
.code_block_renderer(
markdown::CodeBlockRenderer::Default {
copy_button: false,
copy_button_on_hover: true,
border: false,
},
),
),
v_flex().p_2().gap_0p5().bg(header_bg).child(header).child(
Label::new(self.input_command.clone())
.buffer_font(cx)
.size(LabelSize::Small),
),
)
.when(self.preview_expanded && !should_hide_terminal, |this| {
this.child(
@@ -640,27 +594,6 @@ impl ToolCard for TerminalToolCard {
}
}
fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let buffer_font_size = TextSize::Default.rems(cx);
let mut text_style = window.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
font_features: Some(theme_settings.buffer_font.features.clone()),
font_size: Some(buffer_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});
MarkdownStyle {
base_text_style: text_style.clone(),
selection_background_color: cx.theme().players().local().selection,
..Default::default()
}
}
#[cfg(test)]
mod tests {
use editor::EditorSettings;
@@ -732,8 +665,8 @@ mod tests {
)
});
let output = result.output.await.log_err().unwrap().content;
assert_eq!(output.as_str().unwrap(), "Command executed successfully.");
let output = result.output.await.log_err().map(|output| output.content);
assert_eq!(output, Some("Command executed successfully.".into()));
}
#[gpui::test]
@@ -766,13 +699,12 @@ mod tests {
cx,
);
cx.spawn(async move |_| {
let output = headless_result.output.await.map(|output| output.content);
assert_eq!(
output
.ok()
.and_then(|content| content.as_str().map(ToString::to_string)),
expected
);
let output = headless_result
.output
.await
.log_err()
.map(|output| output.content);
assert_eq!(output, expected);
})
};
@@ -780,7 +712,7 @@ mod tests {
check(
TerminalToolInput {
command: "pwd".into(),
cd: ".".into(),
cd: "project".into(),
},
Some(format!(
"```\n{}\n```",
@@ -795,9 +727,12 @@ mod tests {
check(
TerminalToolInput {
command: "pwd".into(),
cd: "other-project".into(),
cd: ".".into(),
},
None, // other-project is a dir, but *not* a worktree (yet)
Some(format!(
"```\n{}\n```",
tree.path().join("project").display()
)),
cx,
)
})

View File

@@ -3,9 +3,7 @@ use std::{sync::Arc, time::Duration};
use crate::schema::json_schema_for;
use crate::ui::ToolCallCardHeader;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{Future, FutureExt, TryFutureExt};
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
@@ -75,13 +73,9 @@ impl Tool for WebSearchTool {
let search_task = search_task.clone();
async move {
let response = search_task.await.map_err(|err| anyhow!(err))?;
Ok(ToolResultOutput {
content: ToolResultContent::Text(
serde_json::to_string(&response)
.context("Failed to serialize search results")?,
),
output: Some(serde_json::to_value(response)?),
})
serde_json::to_string(&response)
.context("Failed to serialize search results")
.map(Into::into)
}
});
@@ -90,18 +84,6 @@ impl Tool for WebSearchTool {
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
}
}
fn deserialize_card(
self: Arc<Self>,
output: serde_json::Value,
_project: Entity<Project>,
_window: &mut Window,
cx: &mut App,
) -> Option<assistant_tool::AnyToolCard> {
let output = serde_json::from_value::<WebSearchResponse>(output).ok()?;
let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx));
Some(card.into())
}
}
#[derive(RegisterComponent)]

View File

@@ -155,7 +155,6 @@ pub fn notify_if_app_was_updated(cx: &mut App) {
}
cx.emit(DismissEvent);
})
.show_suppress_button(false)
})
},
);

View File

@@ -38,7 +38,6 @@ pub enum Model {
AmazonNovaLite,
AmazonNovaMicro,
AmazonNovaPro,
AmazonNovaPremier,
// AI21 models
AI21J2GrandeInstruct,
AI21J2JumboInstruct,
@@ -73,10 +72,6 @@ pub enum Model {
MistralMixtral8x7BInstructV0,
MistralMistralLarge2402V1,
MistralMistralSmall2402V1,
MistralPixtralLarge2502V1,
// Writer models
PalmyraWriterX5,
PalmyraWriterX4,
#[serde(rename = "custom")]
Custom {
name: String,
@@ -125,7 +120,6 @@ impl Model {
Model::AmazonNovaLite => "amazon.nova-lite-v1:0",
Model::AmazonNovaMicro => "amazon.nova-micro-v1:0",
Model::AmazonNovaPro => "amazon.nova-pro-v1:0",
Model::AmazonNovaPremier => "amazon.nova-premier-v1:0",
Model::DeepSeekR1 => "us.deepseek.r1-v1:0",
Model::AI21J2GrandeInstruct => "ai21.j2-grande-instruct",
Model::AI21J2JumboInstruct => "ai21.j2-jumbo-instruct",
@@ -155,9 +149,6 @@ impl Model {
Model::MistralMixtral8x7BInstructV0 => "mistral.mixtral-8x7b-instruct-v0:1",
Model::MistralMistralLarge2402V1 => "mistral.mistral-large-2402-v1:0",
Model::MistralMistralSmall2402V1 => "mistral.mistral-small-2402-v1:0",
Model::MistralPixtralLarge2502V1 => "mistral.pixtral-large-2502-v1:0",
Model::PalmyraWriterX4 => "writer.palmyra-x4-v1:0",
Model::PalmyraWriterX5 => "writer.palmyra-x5-v1:0",
Self::Custom { name, .. } => name,
}
}
@@ -175,7 +166,6 @@ impl Model {
Self::AmazonNovaLite => "Amazon Nova Lite",
Self::AmazonNovaMicro => "Amazon Nova Micro",
Self::AmazonNovaPro => "Amazon Nova Pro",
Self::AmazonNovaPremier => "Amazon Nova Premier",
Self::DeepSeekR1 => "DeepSeek R1",
Self::AI21J2GrandeInstruct => "AI21 Jurassic2 Grande Instruct",
Self::AI21J2JumboInstruct => "AI21 Jurassic2 Jumbo Instruct",
@@ -205,9 +195,6 @@ impl Model {
Self::MistralMixtral8x7BInstructV0 => "Mistral Mixtral 8x7B Instruct V0",
Self::MistralMistralLarge2402V1 => "Mistral Large 2402 V1",
Self::MistralMistralSmall2402V1 => "Mistral Small 2402 V1",
Self::MistralPixtralLarge2502V1 => "Pixtral Large 25.02 V1",
Self::PalmyraWriterX5 => "Writer Palmyra X5",
Self::PalmyraWriterX4 => "Writer Palmyra X4",
Self::Custom {
display_name, name, ..
} => display_name.as_deref().unwrap_or(name),
@@ -221,11 +208,8 @@ impl Model {
| Self::Claude3Sonnet
| Self::Claude3_5Haiku
| Self::Claude3_7Sonnet => 200_000,
Self::AmazonNovaPremier => 1_000_000,
Self::PalmyraWriterX5 => 1_000_000,
Self::PalmyraWriterX4 => 128_000,
Self::Custom { max_tokens, .. } => *max_tokens,
_ => 128_000,
_ => 200_000,
}
}
@@ -233,7 +217,7 @@ impl Model {
match self {
Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3_5Haiku => 4_096,
Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => 128_000,
Self::Claude3_5SonnetV2 | Self::PalmyraWriterX4 | Self::PalmyraWriterX5 => 8_192,
Self::Claude3_5SonnetV2 => 8_192,
Self::Custom {
max_output_tokens, ..
} => max_output_tokens.unwrap_or(4_096),
@@ -268,10 +252,7 @@ impl Model {
| Self::Claude3_5Haiku => true,
// Amazon Nova models (all support tool use)
Self::AmazonNovaPremier
| Self::AmazonNovaPro
| Self::AmazonNovaLite
| Self::AmazonNovaMicro => true,
Self::AmazonNovaPro | Self::AmazonNovaLite | Self::AmazonNovaMicro => true,
// AI21 Jamba 1.5 models support tool use
Self::AI21Jamba15LargeV1 | Self::AI21Jamba15MiniV1 => true,
@@ -324,11 +305,8 @@ impl Model {
// Models available only in US
(Model::Claude3Opus, "us")
| (Model::Claude3_5Haiku, "us")
| (Model::Claude3_7Sonnet, "us")
| (Model::Claude3_7SonnetThinking, "us")
| (Model::AmazonNovaPremier, "us")
| (Model::MistralPixtralLarge2502V1, "us") => {
| (Model::Claude3_7SonnetThinking, "us") => {
Ok(format!("{}.{}", region_group, model_id))
}
@@ -362,12 +340,6 @@ impl Model {
Ok(format!("{}.{}", region_group, model_id))
}
// Writer models only available in the US
(Model::PalmyraWriterX4, "us") | (Model::PalmyraWriterX5, "us") => {
// They have some goofiness
Ok(format!("{}.{}", region_group, model_id))
}
// Any other combination is not supported
_ => Ok(self.id().into()),
}

View File

@@ -12,7 +12,6 @@ pub struct CallSettings {
/// Configuration of voice calls in Zed.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct CallSettingsContent {
/// Whether the microphone should be muted when joining a channel or a call.
///

View File

@@ -19,7 +19,6 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
anyhow.workspace = true
async-recursion = "0.3"
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
base64.workspace = true
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
collections.workspace = true
@@ -30,7 +29,6 @@ gpui.workspace = true
gpui_tokio.workspace = true
http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
log.workspace = true
paths.workspace = true
parking_lot.workspace = true
@@ -49,7 +47,6 @@ text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http = "0.8"
tokio-native-tls = "0.3"
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
url.workspace = true
util.workspace = true

View File

@@ -1,7 +1,7 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
mod proxy;
mod socks;
pub mod telemetry;
pub mod user;
pub mod zed_urls;
@@ -24,13 +24,13 @@ use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use parking_lot::RwLock;
use postage::watch;
use proxy::connect_proxy_stream;
use rand::prelude::*;
use release_channel::{AppVersion, ReleaseChannel};
use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use socks::connect_socks_proxy_stream;
use std::pin::Pin;
use std::{
any::TypeId,
@@ -49,7 +49,7 @@ use telemetry::Telemetry;
use thiserror::Error;
use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
use util::{ResultExt, TryFutureExt};
pub use rpc::*;
pub use telemetry_events::Event;
@@ -151,19 +151,9 @@ pub fn init(client: &Arc<Client>, cx: &mut App) {
let client = client.clone();
move |_: &SignIn, cx| {
if let Some(client) = client.upgrade() {
cx.spawn(
async move |cx| match client.authenticate_and_connect(true, &cx).await {
ConnectionResult::Timeout => {
log::error!("Initial authentication timed out");
}
ConnectionResult::ConnectionReset => {
log::error!("Initial authentication connection reset");
}
ConnectionResult::Result(r) => {
r.log_err();
}
},
)
cx.spawn(async move |cx| {
client.authenticate_and_connect(true, &cx).log_err().await
})
.detach();
}
}
@@ -668,7 +658,7 @@ impl Client {
state._reconnect_task = None;
}
Status::ConnectionLost => {
let client = self.clone();
let this = self.clone();
state._reconnect_task = Some(cx.spawn(async move |cx| {
#[cfg(any(test, feature = "test-support"))]
let mut rng = StdRng::seed_from_u64(0);
@@ -676,25 +666,10 @@ impl Client {
let mut rng = StdRng::from_entropy();
let mut delay = INITIAL_RECONNECTION_DELAY;
loop {
match client.authenticate_and_connect(true, &cx).await {
ConnectionResult::Timeout => {
log::error!("client connect attempt timed out")
}
ConnectionResult::ConnectionReset => {
log::error!("client connect attempt reset")
}
ConnectionResult::Result(r) => {
if let Err(error) = r {
log::error!("failed to connect: {error}");
} else {
break;
}
}
}
if matches!(*client.status().borrow(), Status::ConnectionError) {
client.set_status(
while let Err(error) = this.authenticate_and_connect(true, &cx).await {
log::error!("failed to connect {}", error);
if matches!(*this.status().borrow(), Status::ConnectionError) {
this.set_status(
Status::ReconnectionError {
next_reconnection: Instant::now() + delay,
},
@@ -852,7 +827,7 @@ impl Client {
self: &Arc<Self>,
try_provider: bool,
cx: &AsyncApp,
) -> ConnectionResult<()> {
) -> anyhow::Result<()> {
let was_disconnected = match *self.status().borrow() {
Status::SignedOut => true,
Status::ConnectionError
@@ -861,14 +836,9 @@ impl Client {
| Status::Reauthenticating { .. }
| Status::ReconnectionError { .. } => false,
Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. } => {
return ConnectionResult::Result(Ok(()));
}
Status::UpgradeRequired => {
return ConnectionResult::Result(
Err(EstablishConnectionError::UpgradeRequired)
.context("client auth and connect"),
);
return Ok(());
}
Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
};
if was_disconnected {
self.set_status(Status::Authenticating, cx);
@@ -892,12 +862,12 @@ impl Client {
Ok(creds) => credentials = Some(creds),
Err(err) => {
self.set_status(Status::ConnectionError, cx);
return ConnectionResult::Result(Err(err));
return Err(err);
}
}
}
_ = status_rx.next().fuse() => {
return ConnectionResult::Result(Err(anyhow!("authentication canceled")));
return Err(anyhow!("authentication canceled"));
}
}
}
@@ -922,10 +892,10 @@ impl Client {
}
futures::select_biased! {
result = self.set_connection(conn, cx).fuse() => ConnectionResult::Result(result.context("client auth and connect")),
result = self.set_connection(conn, cx).fuse() => result,
_ = timeout => {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Timeout
Err(anyhow!("timed out waiting on hello message from server"))
}
}
}
@@ -937,22 +907,22 @@ impl Client {
self.authenticate_and_connect(false, cx).await
} else {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(EstablishConnectionError::Unauthorized).context("client auth and connect"))
Err(EstablishConnectionError::Unauthorized)?
}
}
Err(EstablishConnectionError::UpgradeRequired) => {
self.set_status(Status::UpgradeRequired, cx);
ConnectionResult::Result(Err(EstablishConnectionError::UpgradeRequired).context("client auth and connect"))
Err(EstablishConnectionError::UpgradeRequired)?
}
Err(error) => {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Result(Err(error).context("client auth and connect"))
Err(error)?
}
}
}
_ = &mut timeout => {
self.set_status(Status::ConnectionError, cx);
ConnectionResult::Timeout
Err(anyhow!("timed out trying to establish connection"))
}
}
}
@@ -968,7 +938,10 @@ impl Client {
let peer_id = async {
log::debug!("waiting for server hello");
let message = incoming.next().await.context("no hello message received")?;
let message = incoming
.next()
.await
.ok_or_else(|| anyhow!("no hello message received"))?;
log::debug!("got server hello");
let hello_message_type_name = message.payload_type_name().to_string();
let hello = message
@@ -1156,7 +1129,7 @@ impl Client {
let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap();
let _guard = handle.enter();
match proxy {
Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
Some(proxy) => connect_socks_proxy_stream(&proxy, rpc_host).await?,
None => Box::new(TcpStream::connect(rpc_host).await?),
}
};
@@ -1770,7 +1743,7 @@ mod tests {
status.next().await,
Some(Status::ConnectionError { .. })
));
auth_and_connect.await.into_response().unwrap_err();
auth_and_connect.await.unwrap_err();
// Allow the connection to be established.
let server = FakeServer::for_client(user_id, &client, cx).await;

View File

@@ -1,66 +0,0 @@
//! client proxy
mod http_proxy;
mod socks_proxy;
use anyhow::{Context, Result, anyhow};
use http_client::Url;
use http_proxy::{HttpProxyType, connect_http_proxy_stream, parse_http_proxy};
use socks_proxy::{SocksVersion, connect_socks_proxy_stream, parse_socks_proxy};
pub(crate) async fn connect_proxy_stream(
proxy: &Url,
rpc_host: (&str, u16),
) -> Result<Box<dyn AsyncReadWrite>> {
let Some(((proxy_domain, proxy_port), proxy_type)) = parse_proxy_type(proxy) else {
// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
// SOCKS proxies are often used in contexts where security and privacy are critical,
// so any fallback could expose users to significant risks.
return Err(anyhow!("Parsing proxy url failed"));
};
// Connect to proxy and wrap protocol later
let stream = tokio::net::TcpStream::connect((proxy_domain.as_str(), proxy_port))
.await
.context("Failed to connect to proxy")?;
let proxy_stream = match proxy_type {
ProxyType::SocksProxy(proxy) => connect_socks_proxy_stream(stream, proxy, rpc_host).await?,
ProxyType::HttpProxy(proxy) => {
connect_http_proxy_stream(stream, proxy, rpc_host, &proxy_domain).await?
}
};
Ok(proxy_stream)
}
enum ProxyType<'t> {
SocksProxy(SocksVersion<'t>),
HttpProxy(HttpProxyType<'t>),
}
fn parse_proxy_type<'t>(proxy: &'t Url) -> Option<((String, u16), ProxyType<'t>)> {
let scheme = proxy.scheme();
let host = proxy.host()?.to_string();
let port = proxy.port_or_known_default()?;
let proxy_type = match scheme {
scheme if scheme.starts_with("socks") => {
Some(ProxyType::SocksProxy(parse_socks_proxy(scheme, proxy)))
}
scheme if scheme.starts_with("http") => {
Some(ProxyType::HttpProxy(parse_http_proxy(scheme, proxy)))
}
_ => None,
}?;
Some(((host, port), proxy_type))
}
pub(crate) trait AsyncReadWrite:
tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static
{
}
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static> AsyncReadWrite
for T
{
}

View File

@@ -1,171 +0,0 @@
use anyhow::{Context, Result};
use base64::Engine;
use httparse::{EMPTY_HEADER, Response};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
net::TcpStream,
};
use tokio_native_tls::{TlsConnector, native_tls};
use url::Url;
use super::AsyncReadWrite;
pub(super) enum HttpProxyType<'t> {
HTTP(Option<HttpProxyAuthorization<'t>>),
HTTPS(Option<HttpProxyAuthorization<'t>>),
}
pub(super) struct HttpProxyAuthorization<'t> {
username: &'t str,
password: &'t str,
}
pub(super) fn parse_http_proxy<'t>(scheme: &str, proxy: &'t Url) -> HttpProxyType<'t> {
let auth = proxy.password().map(|password| HttpProxyAuthorization {
username: proxy.username(),
password,
});
if scheme.starts_with("https") {
HttpProxyType::HTTPS(auth)
} else {
HttpProxyType::HTTP(auth)
}
}
pub(crate) async fn connect_http_proxy_stream(
stream: TcpStream,
http_proxy: HttpProxyType<'_>,
rpc_host: (&str, u16),
proxy_domain: &str,
) -> Result<Box<dyn AsyncReadWrite>> {
match http_proxy {
HttpProxyType::HTTP(auth) => http_connect(stream, rpc_host, auth).await,
HttpProxyType::HTTPS(auth) => https_connect(stream, rpc_host, auth, proxy_domain).await,
}
.context("error connecting to http/https proxy")
}
async fn http_connect<T>(
stream: T,
target: (&str, u16),
auth: Option<HttpProxyAuthorization<'_>>,
) -> Result<Box<dyn AsyncReadWrite>>
where
T: AsyncReadWrite,
{
let mut stream = BufStream::new(stream);
let request = make_request(target, auth);
stream.write_all(request.as_bytes()).await?;
stream.flush().await?;
check_response(&mut stream).await?;
Ok(Box::new(stream))
}
async fn https_connect<T>(
stream: T,
target: (&str, u16),
auth: Option<HttpProxyAuthorization<'_>>,
proxy_domain: &str,
) -> Result<Box<dyn AsyncReadWrite>>
where
T: AsyncReadWrite,
{
let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
let stream = tls_connector.connect(proxy_domain, stream).await?;
http_connect(stream, target, auth).await
}
fn make_request(target: (&str, u16), auth: Option<HttpProxyAuthorization<'_>>) -> String {
let (host, port) = target;
let mut request = format!(
"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\nProxy-Connection: Keep-Alive\r\n"
);
if let Some(HttpProxyAuthorization { username, password }) = auth {
let auth =
base64::prelude::BASE64_STANDARD.encode(format!("{username}:{password}").as_bytes());
let auth = format!("Proxy-Authorization: Basic {auth}\r\n");
request.push_str(&auth);
}
request.push_str("\r\n");
request
}
async fn check_response<T>(stream: &mut BufStream<T>) -> Result<()>
where
T: AsyncReadWrite,
{
let response = recv_response(stream).await?;
let mut dummy_headers = [EMPTY_HEADER; MAX_RESPONSE_HEADERS];
let mut parser = Response::new(&mut dummy_headers);
parser.parse(response.as_bytes())?;
match parser.code {
Some(code) => {
if code == 200 {
Ok(())
} else {
Err(anyhow::anyhow!(
"Proxy connection failed with HTTP code: {code}"
))
}
}
None => Err(anyhow::anyhow!(
"Proxy connection failed with no HTTP code: {}",
parser.reason.unwrap_or("Unknown reason")
)),
}
}
const MAX_RESPONSE_HEADER_LENGTH: usize = 4096;
const MAX_RESPONSE_HEADERS: usize = 16;
async fn recv_response<T>(stream: &mut BufStream<T>) -> Result<String>
where
T: AsyncReadWrite,
{
let mut response = String::new();
loop {
if stream.read_line(&mut response).await? == 0 {
return Err(anyhow::anyhow!("End of stream"));
}
if MAX_RESPONSE_HEADER_LENGTH < response.len() {
return Err(anyhow::anyhow!("Maximum response header length exceeded"));
}
if response.ends_with("\r\n\r\n") {
return Ok(response);
}
}
}
#[cfg(test)]
mod tests {
use url::Url;
use super::{HttpProxyAuthorization, HttpProxyType, parse_http_proxy};
#[test]
fn test_parse_http_proxy() {
let proxy = Url::parse("http://proxy.example.com:1080").unwrap();
let scheme = proxy.scheme();
let version = parse_http_proxy(scheme, &proxy);
assert!(matches!(version, HttpProxyType::HTTP(None)))
}
#[test]
fn test_parse_http_proxy_with_auth() {
let proxy = Url::parse("http://username:password@proxy.example.com:1080").unwrap();
let scheme = proxy.scheme();
let version = parse_http_proxy(scheme, &proxy);
assert!(matches!(
version,
HttpProxyType::HTTP(Some(HttpProxyAuthorization {
username: "username",
password: "password"
}))
))
}
}

View File

@@ -1,19 +1,15 @@
//! socks proxy
use anyhow::{Context, Result};
use tokio::net::TcpStream;
use anyhow::{Context, Result, anyhow};
use http_client::Url;
use tokio_socks::tcp::{Socks4Stream, Socks5Stream};
use url::Url;
use super::AsyncReadWrite;
/// Identification to a Socks V4 Proxy
pub(super) struct Socks4Identification<'a> {
struct Socks4Identification<'a> {
user_id: &'a str,
}
/// Authorization to a Socks V5 Proxy
pub(super) struct Socks5Authorization<'a> {
struct Socks5Authorization<'a> {
username: &'a str,
password: &'a str,
}
@@ -22,50 +18,45 @@ pub(super) struct Socks5Authorization<'a> {
///
/// V4 allows idenfication using a user_id
/// V5 allows authorization using a username and password
pub(super) enum SocksVersion<'a> {
enum SocksVersion<'a> {
V4(Option<Socks4Identification<'a>>),
V5(Option<Socks5Authorization<'a>>),
}
pub(super) fn parse_socks_proxy<'t>(scheme: &str, proxy: &'t Url) -> SocksVersion<'t> {
if scheme.starts_with("socks4") {
let identification = match proxy.username() {
"" => None,
username => Some(Socks4Identification { user_id: username }),
};
SocksVersion::V4(identification)
} else {
let authorization = proxy.password().map(|password| Socks5Authorization {
username: proxy.username(),
password,
});
SocksVersion::V5(authorization)
}
}
pub(super) async fn connect_socks_proxy_stream(
stream: TcpStream,
socks_version: SocksVersion<'_>,
pub(crate) async fn connect_socks_proxy_stream(
proxy: &Url,
rpc_host: (&str, u16),
) -> Result<Box<dyn AsyncReadWrite>> {
match socks_version {
let Some((socks_proxy, version)) = parse_socks_proxy(proxy) else {
// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
// SOCKS proxies are often used in contexts where security and privacy are critical,
// so any fallback could expose users to significant risks.
return Err(anyhow!("Parsing proxy url failed"));
};
// Connect to proxy and wrap protocol later
let stream = tokio::net::TcpStream::connect(socks_proxy)
.await
.context("Failed to connect to socks proxy")?;
let socks: Box<dyn AsyncReadWrite> = match version {
SocksVersion::V4(None) => {
let socks = Socks4Stream::connect_with_socket(stream, rpc_host)
.await
.context("error connecting to socks")?;
Ok(Box::new(socks))
Box::new(socks)
}
SocksVersion::V4(Some(Socks4Identification { user_id })) => {
let socks = Socks4Stream::connect_with_userid_and_socket(stream, rpc_host, user_id)
.await
.context("error connecting to socks")?;
Ok(Box::new(socks))
Box::new(socks)
}
SocksVersion::V5(None) => {
let socks = Socks5Stream::connect_with_socket(stream, rpc_host)
.await
.context("error connecting to socks")?;
Ok(Box::new(socks))
Box::new(socks)
}
SocksVersion::V5(Some(Socks5Authorization { username, password })) => {
let socks = Socks5Stream::connect_with_password_and_socket(
@@ -73,9 +64,44 @@ pub(super) async fn connect_socks_proxy_stream(
)
.await
.context("error connecting to socks")?;
Ok(Box::new(socks))
Box::new(socks)
}
}
};
Ok(socks)
}
fn parse_socks_proxy(proxy: &Url) -> Option<((String, u16), SocksVersion<'_>)> {
let scheme = proxy.scheme();
let socks_version = if scheme.starts_with("socks4") {
let identification = match proxy.username() {
"" => None,
username => Some(Socks4Identification { user_id: username }),
};
SocksVersion::V4(identification)
} else if scheme.starts_with("socks") {
let authorization = proxy.password().map(|password| Socks5Authorization {
username: proxy.username(),
password,
});
SocksVersion::V5(authorization)
} else {
return None;
};
let host = proxy.host()?.to_string();
let port = proxy.port_or_known_default()?;
Some(((host, port), socks_version))
}
pub(crate) trait AsyncReadWrite:
tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static
{
}
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static> AsyncReadWrite
for T
{
}
#[cfg(test)]
@@ -87,18 +113,20 @@ mod tests {
#[test]
fn parse_socks4() {
let proxy = Url::parse("socks4://proxy.example.com:1080").unwrap();
let scheme = proxy.scheme();
let version = parse_socks_proxy(scheme, &proxy);
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
assert_eq!(host, "proxy.example.com");
assert_eq!(port, 1080);
assert!(matches!(version, SocksVersion::V4(None)))
}
#[test]
fn parse_socks4_with_identification() {
let proxy = Url::parse("socks4://userid@proxy.example.com:1080").unwrap();
let scheme = proxy.scheme();
let version = parse_socks_proxy(scheme, &proxy);
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
assert_eq!(host, "proxy.example.com");
assert_eq!(port, 1080);
assert!(matches!(
version,
SocksVersion::V4(Some(Socks4Identification { user_id: "userid" }))
@@ -108,18 +136,20 @@ mod tests {
#[test]
fn parse_socks5() {
let proxy = Url::parse("socks5://proxy.example.com:1080").unwrap();
let scheme = proxy.scheme();
let version = parse_socks_proxy(scheme, &proxy);
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
assert_eq!(host, "proxy.example.com");
assert_eq!(port, 1080);
assert!(matches!(version, SocksVersion::V5(None)))
}
#[test]
fn parse_socks5_with_authorization() {
let proxy = Url::parse("socks5://username:password@proxy.example.com:1080").unwrap();
let scheme = proxy.scheme();
let version = parse_socks_proxy(scheme, &proxy);
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
assert_eq!(host, "proxy.example.com");
assert_eq!(port, 1080);
assert!(matches!(
version,
SocksVersion::V5(Some(Socks5Authorization {
@@ -128,4 +158,19 @@ mod tests {
}))
))
}
/// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
/// SOCKS proxies are often used in contexts where security and privacy are critical,
/// so any fallback could expose users to significant risks.
#[tokio::test]
async fn fails_on_bad_proxy() {
// Should fail connecting because http is not a valid Socks proxy scheme
let proxy = Url::parse("http://localhost:2313").unwrap();
let result = connect_socks_proxy_stream(&proxy, ("test", 1080)).await;
match result {
Err(e) => assert_eq!(e.to_string(), "Parsing proxy url failed"),
Ok(_) => panic!("Connecting on bad proxy should fail"),
};
}
}

View File

@@ -137,14 +137,18 @@ pub fn os_version() -> String {
log::error!("Failed to load /etc/os-release, /usr/lib/os-release");
"".to_string()
};
let mut name = "unknown";
let mut version = "unknown";
let mut name = "unknown".to_string();
let mut version = "unknown".to_string();
for line in content.lines() {
match line.split_once('=') {
Some(("ID", val)) => name = val.trim_matches('"'),
Some(("VERSION_ID", val)) => version = val.trim_matches('"'),
_ => {}
if line.starts_with("ID=") {
name = line.trim_start_matches("ID=").trim_matches('"').to_string();
}
if line.starts_with("VERSION_ID=") {
version = line
.trim_start_matches("VERSION_ID=")
.trim_matches('"')
.to_string();
}
}
@@ -218,7 +222,7 @@ impl Telemetry {
cx.background_spawn({
let state = state.clone();
let os_version = os_version();
state.lock().os_version = Some(os_version);
state.lock().os_version = Some(os_version.clone());
async move {
if let Some(tempfile) = File::create(Self::log_file_path()).log_err() {
state.lock().log_file = Some(tempfile);
@@ -365,7 +369,7 @@ impl Telemetry {
telemetry::event!(
"Editor Edited",
duration = duration,
environment = environment,
environment = environment.to_string(),
is_via_ssh = is_via_ssh
);
}
@@ -427,8 +431,9 @@ impl Telemetry {
if state.flush_events_task.is_none() {
let this = self.clone();
let executor = self.executor.clone();
state.flush_events_task = Some(self.executor.spawn(async move {
this.executor.timer(FLUSH_INTERVAL).await;
executor.timer(FLUSH_INTERVAL).await;
this.flush_events().detach();
}));
}
@@ -479,12 +484,12 @@ impl Telemetry {
self: &Arc<Self>,
// We take in the JSON bytes buffer so we can reuse the existing allocation.
mut json_bytes: Vec<u8>,
event_request: &EventRequestBody,
event_request: EventRequestBody,
) -> Result<Request<AsyncBody>> {
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, event_request)?;
serde_json::to_writer(&mut json_bytes, &event_request)?;
let checksum = calculate_json_checksum(&json_bytes).unwrap_or_default();
let checksum = calculate_json_checksum(&json_bytes).unwrap_or("".to_string());
Ok(Request::builder()
.method(Method::POST)
@@ -501,7 +506,7 @@ impl Telemetry {
pub fn flush_events(self: &Arc<Self>) -> Task<()> {
let mut state = self.state.lock();
state.first_event_date_time = None;
let events = mem::take(&mut state.events_queue);
let mut events = mem::take(&mut state.events_queue);
state.flush_events_task.take();
drop(state);
if events.is_empty() {
@@ -514,7 +519,7 @@ impl Telemetry {
let mut json_bytes = Vec::new();
if let Some(file) = &mut this.state.lock().log_file {
for event in &events {
for event in &mut events {
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, event)?;
file.write_all(&json_bytes)?;
@@ -541,7 +546,7 @@ impl Telemetry {
}
};
let request = this.build_request(json_bytes, &request_body)?;
let request = this.build_request(json_bytes, request_body)?;
let response = this.http_client.send(request).await?;
if response.status() != 200 {
log::error!("Failed to send events: HTTP {:?}", response.status());

View File

@@ -107,7 +107,6 @@ impl FakeServer {
client
.authenticate_and_connect(false, &cx.to_async())
.await
.into_response()
.unwrap();
server

View File

@@ -18,8 +18,7 @@ use stripe::{
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
SubscriptionStatus,
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::{ResultExt, maybe};
@@ -430,8 +429,6 @@ enum ManageSubscriptionIntent {
///
/// This will open the Stripe billing portal without putting the user in a specific flow.
ManageSubscription,
/// The user intends to update their payment method.
UpdatePaymentMethod,
/// The user intends to upgrade to Zed Pro.
UpgradeToPro,
/// The user intends to cancel their subscription.
@@ -446,7 +443,6 @@ struct ManageBillingSubscriptionBody {
intent: ManageSubscriptionIntent,
/// The ID of the subscription to manage.
subscription_id: BillingSubscriptionId,
redirect_to: Option<String>,
}
#[derive(Debug, Serialize)]
@@ -544,23 +540,6 @@ async fn manage_billing_subscription(
.map_or(false, |price| price.id == zed_pro_price_id)
});
if is_on_zed_pro_trial {
let payment_methods = PaymentMethod::list(
&stripe_client,
&stripe::ListPaymentMethods {
customer: Some(stripe_subscription.customer.id()),
..Default::default()
},
)
.await?;
let has_payment_method = !payment_methods.data.is_empty();
if !has_payment_method {
return Err(Error::http(
StatusCode::BAD_REQUEST,
"missing payment method".into(),
));
}
// If the user is already on a Zed Pro trial and wants to upgrade to Pro, we just need to end their trial early.
Subscription::update(
&stripe_client,
@@ -610,21 +589,6 @@ async fn manage_billing_subscription(
..Default::default()
})
}
ManageSubscriptionIntent::UpdatePaymentMethod => Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::PaymentMethodUpdate,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
type_: stripe::CreateBillingPortalSessionFlowDataAfterCompletionType::Redirect,
redirect: Some(CreateBillingPortalSessionFlowDataAfterCompletionRedirect {
return_url: format!(
"{}{path}",
app.config.zed_dot_dev_url(),
path = body.redirect_to.unwrap_or_else(|| "/account".to_string())
),
}),
..Default::default()
}),
..Default::default()
}),
ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
@@ -1137,12 +1101,6 @@ async fn handle_customer_subscription_event(
.await?;
}
// When the user's subscription changes, push down any changes to their plan.
rpc_server
.update_plan_for_user(billing_customer.user_id)
.await
.trace_err();
// When the user's subscription changes, we want to refresh their LLM tokens
// to either grant/revoke access.
rpc_server
@@ -1280,7 +1238,7 @@ async fn get_current_usage(
subscription
.kind
.map(Into::into)
.unwrap_or(zed_llm_client::Plan::ZedFree)
.unwrap_or(zed_llm_client::Plan::Free)
});
let model_requests_limit = match plan.model_requests_limit() {

View File

@@ -543,7 +543,7 @@ pub struct MembershipUpdated {
/// The result of setting a member's role.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SetMemberRoleResult {
InviteUpdated(Channel),
MembershipUpdated(MembershipUpdated),

View File

@@ -99,7 +99,7 @@ impl From<SubscriptionKind> for zed_llm_client::Plan {
match value {
SubscriptionKind::ZedPro => Self::ZedPro,
SubscriptionKind::ZedProTrial => Self::ZedProTrial,
SubscriptionKind::ZedFree => Self::ZedFree,
SubscriptionKind::ZedFree => Self::Free,
}
}
}

View File

@@ -1,6 +1,9 @@
use crate::Cents;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{billing_subscription, user};
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::llm::{
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
};
use crate::{Config, db::billing_preference};
use anyhow::{Result, anyhow};
use chrono::{NaiveDateTime, Utc};
@@ -25,12 +28,23 @@ pub struct LlmTokenClaims {
pub is_staff: bool,
pub has_llm_closed_beta_feature_flag: bool,
pub bypass_account_age_check: bool,
pub has_llm_subscription: bool,
#[serde(default)]
pub use_llm_request_queue: bool,
pub max_monthly_spend_in_cents: u32,
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
#[serde(default)]
pub use_new_billing: bool,
pub plan: Plan,
#[serde(default)]
pub has_extended_trial: bool,
pub subscription_period: (NaiveDateTime, NaiveDateTime),
#[serde(default)]
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
#[serde(default)]
pub enable_model_request_overages: bool,
#[serde(default)]
pub model_request_overages_spend_limit_in_cents: u32,
#[serde(default)]
pub can_use_web_search_tool: bool,
}
@@ -42,6 +56,7 @@ impl LlmTokenClaims {
is_staff: bool,
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
has_legacy_llm_subscription: bool,
subscription: Option<billing_subscription::Model>,
system_id: Option<String>,
config: &Config,
@@ -51,23 +66,6 @@ impl LlmTokenClaims {
.as_ref()
.ok_or_else(|| anyhow!("no LLM API secret"))?;
let plan = if is_staff {
Plan::ZedPro
} else {
subscription
.as_ref()
.and_then(|subscription| subscription.kind)
.map_or(Plan::ZedFree, |kind| match kind {
SubscriptionKind::ZedFree => Plan::ZedFree,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
})
};
let subscription_period =
billing_subscription::Model::current_period(subscription, is_staff)
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
let now = Utc::now();
let claims = Self {
iat: now.timestamp() as u64,
@@ -85,13 +83,38 @@ impl LlmTokenClaims {
bypass_account_age_check: feature_flags
.iter()
.any(|flag| flag == "bypass-account-age-check"),
can_use_web_search_tool: true,
can_use_web_search_tool: feature_flags.iter().any(|flag| flag == "assistant2"),
has_llm_subscription: has_legacy_llm_subscription,
max_monthly_spend_in_cents: billing_preferences
.as_ref()
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
preferences.max_monthly_llm_usage_spending_in_cents as u32
}),
custom_llm_monthly_allowance_in_cents: user
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
use_new_billing: feature_flags.iter().any(|flag| flag == "new-billing"),
use_llm_request_queue: feature_flags.iter().any(|flag| flag == "llm-request-queue"),
plan,
plan: if is_staff {
Plan::ZedPro
} else {
subscription
.as_ref()
.and_then(|subscription| subscription.kind)
.map_or(Plan::Free, |kind| match kind {
SubscriptionKind::ZedFree => Plan::Free,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
})
},
has_extended_trial: feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
subscription_period,
subscription_period: billing_subscription::Model::current_period(
subscription,
is_staff,
)
.map(|(start, end)| (start.naive_utc(), end.naive_utc())),
enable_model_request_overages: billing_preferences
.as_ref()
.map_or(false, |preferences| {
@@ -132,6 +155,12 @@ impl LlmTokenClaims {
}
}
}
pub fn free_tier_monthly_spending_limit(&self) -> Cents {
self.custom_llm_monthly_allowance_in_cents
.map(Cents)
.unwrap_or(FREE_TIER_MONTHLY_SPENDING_LIMIT)
}
}
#[derive(Error, Debug)]

View File

@@ -36,7 +36,6 @@ use util::{ResultExt as _, maybe};
const VERSION: &str = env!("CARGO_PKG_VERSION");
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
#[expect(clippy::result_large_err)]
#[tokio::main]
async fn main() -> Result<()> {
if let Err(error) = env::load_dotenv() {

View File

@@ -2,7 +2,6 @@ mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
use crate::{
AppState, Error, Result, auth,
@@ -68,7 +67,7 @@ use std::{
time::{Duration, Instant},
};
use time::OffsetDateTime;
use tokio::sync::{Semaphore, watch};
use tokio::sync::{MutexGuard, Semaphore, watch};
use tower::ServiceBuilder;
use tracing::{
Instrument,
@@ -167,6 +166,42 @@ impl Session {
}
}
pub async fn has_llm_subscription(
&self,
db: &MutexGuard<'_, DbHandle>,
) -> anyhow::Result<bool> {
if self.is_staff() {
return Ok(true);
}
let user_id = self.user_id();
Ok(db.has_active_billing_subscription(user_id).await?)
}
pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
if self.is_staff() {
return Ok(proto::Plan::ZedPro);
}
let user_id = self.user_id();
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
let plan = if let Some(subscription_kind) = subscription_kind {
match subscription_kind {
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
SubscriptionKind::ZedFree => proto::Plan::Free,
}
} else {
proto::Plan::Free
};
Ok(plan)
}
fn user_id(&self) -> UserId {
match &self.principal {
Principal::User(user) => user.id,
@@ -931,32 +966,6 @@ impl Server {
Ok(())
}
pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
let user = self
.app_state
.db
.get_user_by_id(user_id)
.await?
.ok_or_else(|| anyhow!("user not found"))?;
let update_user_plan = make_update_user_plan_message(
&self.app_state.db,
self.app_state.llm_db.clone(),
user_id,
user.admin,
)
.await?;
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
self.peer
.send(connection_id, update_user_plan.clone())
.trace_err();
}
Ok(())
}
pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
let pool = self.connection_pool.lock();
for connection_id in pool.user_connection_ids(user_id) {
@@ -2692,43 +2701,21 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139
}
async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
if is_staff {
return Ok(proto::Plan::ZedPro);
}
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
let db = session.db().await;
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
let plan = if let Some(subscription_kind) = subscription_kind {
match subscription_kind {
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
SubscriptionKind::ZedFree => proto::Plan::Free,
}
} else {
proto::Plan::Free
};
Ok(plan)
}
async fn make_update_user_plan_message(
db: &Arc<Database>,
llm_db: Option<Arc<LlmDatabase>>,
user_id: UserId,
is_staff: bool,
) -> Result<proto::UpdateUserPlan> {
let feature_flags = db.get_user_flags(user_id).await?;
let plan = current_plan(db, user_id, is_staff).await?;
let plan = session.current_plan(&db).await?;
let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
let billing_preferences = db.get_billing_preferences(user_id).await?;
let (subscription_period, usage) = if let Some(llm_db) = llm_db {
let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
let subscription = db.get_active_billing_subscription(user_id).await?;
let subscription_period =
crate::db::billing_subscription::Model::current_period(subscription, is_staff);
let subscription_period = crate::db::billing_subscription::Model::current_period(
subscription,
session.is_staff(),
);
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
llm_db
@@ -2743,92 +2730,92 @@ async fn make_update_user_plan_message(
(None, None)
};
Ok(proto::UpdateUserPlan {
plan: plan.into(),
trial_started_at: billing_customer
.and_then(|billing_customer| billing_customer.trial_started_at)
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
is_usage_based_billing_enabled: if is_staff {
Some(true)
} else {
billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
},
subscription_period: subscription_period.map(|(started_at, ended_at)| {
proto::SubscriptionPeriod {
started_at: started_at.timestamp() as u64,
ended_at: ended_at.timestamp() as u64,
}
}),
usage: usage.map(|usage| {
let plan = match plan {
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
};
let model_requests_limit = match plan.model_requests_limit() {
zed_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == zed_llm_client::Plan::ZedProTrial
&& feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
{
1_000
} else {
limit
};
zed_llm_client::UsageLimit::Limited(limit)
}
zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
};
proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
edit_predictions_usage_amount: usage.edit_predictions as u32,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
}),
}
}),
})
}
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
let db = session.db().await;
let update_user_plan = make_update_user_plan_message(
&db.0,
session.app_state.llm_db.clone(),
user_id,
session.is_staff(),
)
.await?;
session
.peer
.send(session.connection_id, update_user_plan)
.send(
session.connection_id,
proto::UpdateUserPlan {
plan: plan.into(),
trial_started_at: billing_customer
.and_then(|billing_customer| billing_customer.trial_started_at)
.map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
is_usage_based_billing_enabled: if session.is_staff() {
Some(true)
} else {
billing_preferences
.map(|preferences| preferences.model_request_overages_enabled)
},
subscription_period: subscription_period.map(|(started_at, ended_at)| {
proto::SubscriptionPeriod {
started_at: started_at.timestamp() as u64,
ended_at: ended_at.timestamp() as u64,
}
}),
usage: usage.map(|usage| {
let plan = match plan {
proto::Plan::Free => zed_llm_client::Plan::Free,
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
};
let model_requests_limit = match plan.model_requests_limit() {
zed_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == zed_llm_client::Plan::ZedProTrial
&& feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
{
1_000
} else {
limit
};
zed_llm_client::UsageLimit::Limited(limit)
}
zed_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited
}
};
proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(
proto::usage_limit::Limited {
limit: limit as u32,
},
)
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(
proto::usage_limit::Unlimited {},
)
}
}),
}),
edit_predictions_usage_amount: usage.edit_predictions as u32,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(
proto::usage_limit::Limited {
limit: limit as u32,
},
)
}
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(
proto::usage_limit::Unlimited {},
)
}
}),
}),
}
}),
},
)
.trace_err();
Ok(())
@@ -4013,6 +4000,11 @@ async fn get_llm_api_token(
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
if !session.is_staff() && !has_language_models_feature_flag {
Err(anyhow!("permission denied"))?
}
let user_id = session.user_id();
let user = db
@@ -4024,6 +4016,7 @@ async fn get_llm_api_token(
Err(anyhow!("terms of service not accepted"))?
}
let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
let billing_preferences = db.get_billing_preferences(user.id).await?;
@@ -4032,6 +4025,7 @@ async fn get_llm_api_token(
session.is_staff(),
billing_preferences,
&flags,
has_legacy_llm_subscription,
billing_subscription,
session.system_id.clone(),
&session.app_state.config,

View File

@@ -248,8 +248,6 @@ impl StripeBilling {
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {

View File

@@ -36,8 +36,8 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
room.read_with(cx, |room, _| {
let mut remote = room
.remote_participants()
.values()
.map(|participant| participant.user.github_login.clone())
.iter()
.map(|(_, participant)| participant.user.github_login.clone())
.collect::<Vec<_>>();
let mut pending = room
.pending_participants()

View File

@@ -1740,7 +1740,6 @@ async fn test_mutual_editor_inlay_hint_cache_update(
fake_language_server
.request::<lsp::request::InlayHintRefreshRequest>(())
.await
.into_response()
.expect("inlay refresh request failed");
executor.run_until_parked();
@@ -1931,7 +1930,6 @@ async fn test_inlay_hint_refresh_is_forwarded(
fake_language_server
.request::<lsp::request::InlayHintRefreshRequest>(())
.await
.into_response()
.expect("inlay refresh request failed");
executor.run_until_parked();
editor_a.update(cx_a, |editor, _| {

View File

@@ -1253,7 +1253,6 @@ async fn test_calls_on_multiple_connections(
client_b1
.authenticate_and_connect(false, &cx_b1.to_async())
.await
.into_response()
.unwrap();
// User B hangs up, and user A calls them again.
@@ -1634,7 +1633,6 @@ async fn test_project_reconnect(
client_a
.authenticate_and_connect(false, &cx_a.to_async())
.await
.into_response()
.unwrap();
executor.run_until_parked();
@@ -1763,7 +1761,6 @@ async fn test_project_reconnect(
client_b
.authenticate_and_connect(false, &cx_b.to_async())
.await
.into_response()
.unwrap();
executor.run_until_parked();
@@ -4320,7 +4317,6 @@ async fn test_collaborating_with_lsp_progress_updates_and_diagnostics_ordering(
token: lsp::NumberOrString::String("the-disk-based-token".to_string()),
})
.await
.into_response()
.unwrap();
fake_language_server.notify::<lsp::notification::Progress>(&lsp::ProgressParams {
token: lsp::NumberOrString::String("the-disk-based-token".to_string()),
@@ -5703,7 +5699,6 @@ async fn test_contacts(
client_c
.authenticate_and_connect(false, &cx_c.to_async())
.await
.into_response()
.unwrap();
executor.run_until_parked();
@@ -6234,7 +6229,6 @@ async fn test_contact_requests(
client
.authenticate_and_connect(false, &cx.to_async())
.await
.into_response()
.unwrap();
}
}

View File

@@ -581,11 +581,7 @@ async fn test_ssh_collaboration_formatting_with_prettier(
}
#[gpui::test]
async fn test_remote_server_debugger(
cx_a: &mut TestAppContext,
server_cx: &mut TestAppContext,
executor: BackgroundExecutor,
) {
async fn test_remote_server_debugger(cx_a: &mut TestAppContext, server_cx: &mut TestAppContext) {
cx_a.update(|cx| {
release_channel::init(SemanticVersion::default(), cx);
command_palette_hooks::init(cx);
@@ -683,7 +679,7 @@ async fn test_remote_server_debugger(
});
client_ssh.update(cx_a, |a, _| {
a.shutdown_processes(Some(proto::ShutdownRemoteServer {}), executor)
a.shutdown_processes(Some(proto::ShutdownRemoteServer {}))
});
shutdown_session.await.unwrap();

View File

@@ -313,7 +313,6 @@ impl TestServer {
client
.authenticate_and_connect(false, &cx.to_async())
.await
.into_response()
.unwrap();
let client = TestClient {

View File

@@ -42,7 +42,6 @@ futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
language.workspace = true
log.workspace = true
menu.workspace = true
notifications.workspace = true
picker.workspace = true

View File

@@ -1059,7 +1059,7 @@ impl Render for ChatPanel {
.child(
Label::new(format!(
"@{}",
user_being_replied_to.github_login
user_being_replied_to.github_login.clone()
))
.size(LabelSize::Small)
.weight(FontWeight::BOLD),

View File

@@ -378,27 +378,16 @@ impl CollabPanel {
workspace: WeakEntity<Workspace>,
mut cx: AsyncWindowContext,
) -> anyhow::Result<Entity<Self>> {
let serialized_panel = match workspace
.read_with(&cx, |workspace, _| {
CollabPanel::serialization_key(workspace)
})
.ok()
let serialized_panel = cx
.background_spawn(async move { KEY_VALUE_STORE.read_kvp(COLLABORATION_PANEL_KEY) })
.await
.map_err(|_| anyhow::anyhow!("Failed to read collaboration panel from key value store"))
.log_err()
.flatten()
{
Some(serialization_key) => cx
.background_spawn(async move { KEY_VALUE_STORE.read_kvp(&serialization_key) })
.await
.map_err(|_| {
anyhow::anyhow!("Failed to read collaboration panel from key value store")
})
.log_err()
.flatten()
.map(|panel| serde_json::from_str::<SerializedCollabPanel>(&panel))
.transpose()
.log_err()
.flatten(),
None => None,
};
.map(|panel| serde_json::from_str::<SerializedCollabPanel>(&panel))
.transpose()
.log_err()
.flatten();
workspace.update_in(&mut cx, |workspace, window, cx| {
let panel = CollabPanel::new(workspace, window, cx);
@@ -418,30 +407,14 @@ impl CollabPanel {
})
}
fn serialization_key(workspace: &Workspace) -> Option<String> {
workspace
.database_id()
.map(|id| i64::from(id).to_string())
.or(workspace.session_id())
.map(|id| format!("{}-{:?}", COLLABORATION_PANEL_KEY, id))
}
fn serialize(&mut self, cx: &mut Context<Self>) {
let Some(serialization_key) = self
.workspace
.update(cx, |workspace, _| CollabPanel::serialization_key(workspace))
.ok()
.flatten()
else {
return;
};
let width = self.width;
let collapsed_channels = self.collapsed_channels.clone();
self.pending_serialization = cx.background_spawn(
async move {
KEY_VALUE_STORE
.write_kvp(
serialization_key,
COLLABORATION_PANEL_KEY.into(),
serde_json::to_string(&SerializedCollabPanel {
width,
collapsed_channels: Some(
@@ -1490,9 +1463,7 @@ impl CollabPanel {
}
fn cancel(&mut self, _: &Cancel, window: &mut Window, cx: &mut Context<Self>) {
if cx.stop_active_drag(window) {
return;
} else if self.take_editing_state(window, cx) {
if self.take_editing_state(window, cx) {
window.focus(&self.filter_editor.focus_handle(cx));
} else if !self.reset_filter_editor_text(window, cx) {
self.focus_handle.focus(window);
@@ -2254,7 +2225,6 @@ impl CollabPanel {
client
.authenticate_and_connect(true, &cx)
.await
.into_response()
.notify_async_err(cx);
})
.detach()
@@ -3026,12 +2996,10 @@ impl Panel for CollabPanel {
.unwrap_or_else(|| CollaborationPanelSettings::get_global(cx).default_width)
}
fn set_size(&mut self, size: Option<Pixels>, window: &mut Window, cx: &mut Context<Self>) {
fn set_size(&mut self, size: Option<Pixels>, _: &mut Window, cx: &mut Context<Self>) {
self.width = size;
self.serialize(cx);
cx.notify();
cx.defer_in(window, |this, _, cx| {
this.serialize(cx);
});
}
fn icon(&self, _window: &Window, cx: &App) -> Option<ui::IconName> {

View File

@@ -6,8 +6,8 @@ use collections::HashMap;
use db::kvp::KEY_VALUE_STORE;
use futures::StreamExt;
use gpui::{
AnyElement, App, AsyncWindowContext, ClickEvent, Context, CursorStyle, DismissEvent, Element,
Entity, EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment,
AnyElement, App, AsyncWindowContext, Context, CursorStyle, DismissEvent, Element, Entity,
EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, ListAlignment,
ListScrollEvent, ListState, ParentElement, Render, StatefulInteractiveElement, Styled, Task,
WeakEntity, Window, actions, div, img, list, px,
};
@@ -22,6 +22,7 @@ use ui::{
Avatar, Button, Icon, IconButton, IconName, Label, Tab, Tooltip, h_flex, prelude::*, v_flex,
};
use util::{ResultExt, TryFutureExt};
use workspace::SuppressNotification;
use workspace::notifications::{
Notification as WorkspaceNotification, NotificationId, SuppressEvent,
};
@@ -646,20 +647,10 @@ impl Render for NotificationPanel {
let client = client.clone();
window
.spawn(cx, async move |cx| {
match client
client
.authenticate_and_connect(true, &cx)
.await
{
util::ConnectionResult::Timeout => {
log::error!("Connection timeout");
}
util::ConnectionResult::ConnectionReset => {
log::error!("Connection reset");
}
util::ConnectionResult::Result(r) => {
r.log_err();
}
}
.log_err()
.await;
})
.detach()
}
@@ -821,50 +812,32 @@ impl NotificationToast {
}
impl Render for NotificationToast {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let user = self.actor.clone();
let suppress = window.modifiers().shift;
let (close_id, close_icon) = if suppress {
("suppress", IconName::Minimize)
} else {
("close", IconName::Close)
};
h_flex()
.id("notification_panel_toast")
.elevation_3(cx)
.p_2()
.justify_between()
.gap_2()
.children(user.map(|user| Avatar::new(user.avatar_uri.clone())))
.child(Label::new(self.text.clone()))
.on_modifiers_changed(cx.listener(|_, _, _, cx| cx.notify()))
.child(
IconButton::new(close_id, close_icon)
.tooltip(move |window, cx| {
if suppress {
Tooltip::for_action(
"Suppress.\nClose with click.",
&workspace::SuppressNotification,
window,
cx,
)
} else {
Tooltip::for_action(
"Close.\nSuppress with shift-click",
&menu::Cancel,
window,
cx,
)
}
IconButton::new("close", IconName::Close)
.tooltip(|window, cx| Tooltip::for_action("Close", &menu::Cancel, window, cx))
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
)
.child(
IconButton::new("suppress", IconName::SquareMinus)
.tooltip(|window, cx| {
Tooltip::for_action(
"Do not show until restart",
&SuppressNotification,
window,
cx,
)
})
.on_click(cx.listener(move |_, _: &ClickEvent, _, cx| {
if suppress {
cx.emit(SuppressEvent);
} else {
cx.emit(DismissEvent);
}
})),
.on_click(cx.listener(|_, _, _, cx| cx.emit(SuppressEvent))),
)
.on_click(cx.listener(|this, _, window, cx| {
this.focus_notification_panel(window, cx);

View File

@@ -7,11 +7,11 @@ use crate::notifications::collab_notification::CollabNotification;
pub struct CollabNotificationStory;
impl Render for CollabNotificationStory {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
let window_container = |width, height| div().w(px(width)).h(px(height));
Story::container(cx)
.child(Story::title_for::<CollabNotification>(cx))
Story::container()
.child(Story::title_for::<CollabNotification>())
.child(
StorySection::new().child(StoryItem::new(
"Incoming Call Notification",

View File

@@ -28,7 +28,6 @@ pub struct ChatPanelSettings {
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct ChatPanelSettingsContent {
/// When to show the panel button in the status bar.
///
@@ -52,7 +51,6 @@ pub struct NotificationPanelSettings {
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct PanelSettingsContent {
/// Whether to show the panel button in the status bar.
///
@@ -69,7 +67,6 @@ pub struct PanelSettingsContent {
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct MessageEditorSettings {
/// Whether to automatically replace emoji shortcodes with emoji characters.
/// For example: typing `:wave:` gets replaced with `👋`.

View File

@@ -14,7 +14,7 @@ path = "src/component.rs"
[dependencies]
collections.workspace = true
gpui.workspace = true
inventory.workspace = true
linkme.workspace = true
parking_lot.workspace = true
strum.workspace = true
theme.workspace = true

View File

@@ -9,12 +9,13 @@
mod component_layout;
use std::sync::LazyLock;
pub use component_layout::*;
use std::sync::LazyLock;
use collections::HashMap;
use gpui::{AnyElement, App, SharedString, Window};
use linkme::distributed_slice;
use parking_lot::RwLock;
use strum::{Display, EnumString};
@@ -23,27 +24,12 @@ pub fn components() -> ComponentRegistry {
}
pub fn init() {
for f in inventory::iter::<ComponentFn>() {
(f.0)();
let component_fns: Vec<_> = __ALL_COMPONENTS.iter().cloned().collect();
for f in component_fns {
f();
}
}
pub struct ComponentFn(fn());
impl ComponentFn {
pub const fn new(f: fn()) -> Self {
Self(f)
}
}
inventory::collect!(ComponentFn);
/// Private internals for macros.
#[doc(hidden)]
pub mod __private {
pub use inventory;
}
pub fn register_component<T: Component>() {
let id = T::id();
let metadata = ComponentMetadata {
@@ -60,6 +46,9 @@ pub fn register_component<T: Component>() {
data.components.insert(id, metadata);
}
#[distributed_slice]
pub static __ALL_COMPONENTS: [fn()] = [..];
pub static COMPONENT_DATA: LazyLock<RwLock<ComponentRegistry>> =
LazyLock::new(|| RwLock::new(ComponentRegistry::default()));

View File

@@ -0,0 +1,37 @@
[package]
name = "component_preview"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/component_preview.rs"
[features]
default = []
[dependencies]
agent.workspace = true
anyhow.workspace = true
client.workspace = true
collections.workspace = true
component.workspace = true
db.workspace = true
futures.workspace = true
gpui.workspace = true
languages.workspace = true
log.workspace = true
notifications.workspace = true
project.workspace = true
prompt_store.workspace = true
serde.workspace = true
ui.workspace = true
ui_input.workspace = true
util.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
assistant_tool.workspace = true

Some files were not shown because too many files have changed in this diff Show More