Compare commits

..

42 Commits

Author SHA1 Message Date
Antonio Scandurra
587ed1e314 WIP 2025-07-01 12:52:08 +02:00
Conrad Irwin
1cf7a0f97b Rename WIPity WIP
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-06-30 17:48:49 -06:00
Conrad Irwin
f9b43cbd1f Re-merge ZedAgent and Thread
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-06-30 17:43:53 -06:00
Conrad Irwin
dab7ca4a84 WIP
Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-06-30 17:01:05 -06:00
Agus Zubiaga
e061fbefae Replace edit_message + delete_messages with new truncate method 2025-06-30 15:46:59 -03:00
Agus Zubiaga
64d19c44e4 Remove Message is_hidden 2025-06-30 15:05:44 -03:00
Agus Zubiaga
e51a0852e1 Replace insert_invisible_continue_message with send_continue_message 2025-06-30 14:39:42 -03:00
Agus Zubiaga
2ea1488aca Replace insert_user_message fully 2025-06-30 13:22:47 -03:00
Agus Zubiaga
c76361d213 Test new retry 2025-06-30 13:16:39 -03:00
Agus Zubiaga
9d7c94a16e Rename send_to_model2 to send_message and fix is_generating 2025-06-30 11:37:48 -03:00
Agus Zubiaga
2af70370e9 Trigger summary generation from send_to_model2 2025-06-30 11:14:51 -03:00
Agus Zubiaga
7725b95571 Replace more insert_user_message usages 2025-06-30 10:58:08 -03:00
Ben Brandt
be3a295ae4 Refactor tool use deserialization
Extract the tool use deserialization logic from `ZedAgent::new` into a
new `DeserializedToolUse` helper struct, so we don't have to clone
messages
2025-06-30 12:09:12 +02:00
Ben Brandt
269f73ab7c Report tool output in finished status 2025-06-30 11:53:24 +02:00
Ben Brandt
90899465a2 Cleanup from merge and clippy warnings 2025-06-30 11:31:56 +02:00
Ben Brandt
34a2d23134 Merge branch 'main' into split-agent-from-thread 2025-06-30 11:11:29 +02:00
Agus Zubiaga
3e2bcb05fb Start using send_to_model2 in message editor
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 19:35:53 -03:00
Agus Zubiaga
f32af6ab52 Checkpoint: Rendering tool uses
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 19:09:34 -03:00
Agus Zubiaga
eef7c07061 Remove MessageSegment::RedactedThinking
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 18:37:38 -03:00
Agus Zubiaga
b1a7812232 BASE_RETRY_DELAY_SECS -> BASE_RETRY_DELAY
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 18:18:21 -03:00
Agus Zubiaga
2f8fa209bc Test send_to_model2
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-27 18:07:23 -03:00
Max Brunsfeld
5e0f3e0ead Start writing assistant messages + tool calls to thread in ZedAgent
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 13:00:19 -07:00
Agus Zubiaga
8776548b02 Use build_request in tests 2025-06-27 12:46:15 -03:00
Agus Zubiaga
82b243e4ea Add user messages to agent request 2025-06-27 12:42:35 -03:00
Agus Zubiaga
b2434e7fef Checkpoint: Handle all retryable errors 2025-06-27 12:05:24 -03:00
Antonio Scandurra
6036c09c1a Checkpoint
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-27 16:00:08 +02:00
Antonio Scandurra
865970d42b WIP
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-27 15:11:50 +02:00
Antonio Scandurra
b9c4f2c7a8 WIP
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 14:26:15 +02:00
Antonio Scandurra
e458ba2293 WIP
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 13:12:45 +02:00
Antonio Scandurra
04c842a7c2 WIP: actually run tools
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 13:02:40 +02:00
Antonio Scandurra
7a055b4865 WIP: start reworking tool use
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
2025-06-27 12:34:27 +02:00
Antonio Scandurra
9eff1c32af Merge remote-tracking branch 'origin/main' into split-agent-from-thread 2025-06-27 10:40:24 +02:00
Ben Brandt
88b1345595 variable cleanup 2025-06-27 10:09:31 +02:00
Max Brunsfeld
a02a0b9c0a Remove some methods that delegate from ZedAgent to Thread
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-06-26 17:56:13 -07:00
Max Brunsfeld
f35fbbb78f Move ActionLog from ZedAgent to Thread
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-06-26 17:37:22 -07:00
Max Brunsfeld
bdeaddc59d Move checkpoints from agent to thread 2025-06-26 16:39:43 -07:00
Conrad Irwin
d5aa609bee Split thread/agent in ActiveThread
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 16:53:57 -06:00
Conrad Irwin
1f0512cd2f Move summary() into ThreadData. Split thread/agent in tests
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 16:32:51 -06:00
Conrad Irwin
438acc98d6 Move messages -> ThreadData
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 16:09:43 -06:00
Conrad Irwin
5cc016291d Factor id -> ThreadData
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 15:34:35 -06:00
Conrad Irwin
61ab3bcd8e Rename Thread -> Agent
Co-authored-by: Agus Zubiaga <agus@zed.dev>
2025-06-26 15:31:53 -06:00
Max Brunsfeld
03478d5715 Inline ToolUseState
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-06-26 12:30:52 -07:00
424 changed files with 12908 additions and 29079 deletions

View File

@@ -30,7 +30,6 @@ jobs:
run_tests: ${{ steps.filter.outputs.run_tests }}
run_license: ${{ steps.filter.outputs.run_license }}
run_docs: ${{ steps.filter.outputs.run_docs }}
run_nix: ${{ steps.filter.outputs.run_nix }}
runs-on:
- ubuntu-latest
steps:
@@ -65,17 +64,11 @@ jobs:
else
echo "run_docs=false" >> $GITHUB_OUTPUT
fi
if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep -P '^(Cargo.lock|script/.*licenses)') ]]; then
if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep '^Cargo.lock') ]]; then
echo "run_license=true" >> $GITHUB_OUTPUT
else
echo "run_license=false" >> $GITHUB_OUTPUT
fi
NIX_REGEX='^(nix/|flake\.|Cargo\.|rust-toolchain.toml|\.cargo/config.toml)'
if [[ $(git diff --name-only $COMPARE_REV ${{ github.sha }} | grep "$NIX_REGEX") ]]; then
echo "run_nix=true" >> $GITHUB_OUTPUT
else
echo "run_nix=false" >> $GITHUB_OUTPUT
fi
migration_checks:
name: Check Postgres and Protobuf migrations, mergability
@@ -753,10 +746,7 @@ jobs:
nix-build:
name: Build with Nix
uses: ./.github/workflows/nix.yml
needs: [job_spec]
if: github.repository_owner == 'zed-industries' &&
(contains(github.event.pull_request.labels.*.name, 'run-nix') ||
needs.job_spec.outputs.run_nix == 'true')
if: github.repository_owner == 'zed-industries' && contains(github.event.pull_request.labels.*.name, 'run-nix')
secrets: inherit
with:
flake-output: debug

61
Cargo.lock generated
View File

@@ -1911,6 +1911,7 @@ dependencies = [
"serde_json",
"strum 0.27.1",
"thiserror 2.0.12",
"tokio",
"workspace-hack",
]
@@ -2076,7 +2077,7 @@ dependencies = [
[[package]]
name = "blade-graphics"
version = "0.6.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"ash",
"ash-window",
@@ -2109,7 +2110,7 @@ dependencies = [
[[package]]
name = "blade-macros"
version = "0.3.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"proc-macro2",
"quote",
@@ -2119,7 +2120,7 @@ dependencies = [
[[package]]
name = "blade-util"
version = "0.2.0"
source = "git+https://github.com/kvark/blade?rev=416375211bb0b5826b3584dccdb6a43369e499ad#416375211bb0b5826b3584dccdb6a43369e499ad"
source = "git+https://github.com/kvark/blade?rev=e0ec4e720957edd51b945b64dd85605ea54bcfe5#e0ec4e720957edd51b945b64dd85605ea54bcfe5"
dependencies = [
"blade-graphics",
"bytemuck",
@@ -4112,6 +4113,7 @@ dependencies = [
"log",
"node_runtime",
"parking_lot",
"paths",
"proto",
"schemars",
"serde",
@@ -4131,7 +4133,7 @@ dependencies = [
[[package]]
name = "dap-types"
version = "0.0.1"
source = "git+https://github.com/zed-industries/dap-types?rev=7f39295b441614ca9dbf44293e53c32f666897f9#7f39295b441614ca9dbf44293e53c32f666897f9"
source = "git+https://github.com/zed-industries/dap-types?rev=b40956a7f4d1939da67429d941389ee306a3a308#b40956a7f4d1939da67429d941389ee306a3a308"
dependencies = [
"schemars",
"serde",
@@ -4146,22 +4148,16 @@ dependencies = [
"async-trait",
"collections",
"dap",
"dotenvy",
"fs",
"futures 0.3.31",
"gpui",
"json_dotpath",
"language",
"log",
"node_runtime",
"paths",
"reqwest_client",
"serde",
"serde_json",
"settings",
"shlex",
"task",
"tempfile",
"util",
"workspace-hack",
]
@@ -4680,6 +4676,12 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "dotenv"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
[[package]]
name = "dotenvy"
version = "0.15.7"
@@ -4812,7 +4814,6 @@ dependencies = [
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"release_channel",
"rpc",
"schemars",
@@ -4833,7 +4834,6 @@ dependencies = [
"tree-sitter-python",
"tree-sitter-rust",
"tree-sitter-typescript",
"tree-sitter-yaml",
"ui",
"unicode-script",
"unicode-segmentation",
@@ -5114,7 +5114,7 @@ dependencies = [
"collections",
"debug_adapter_extension",
"dirs 4.0.0",
"dotenvy",
"dotenv",
"env_logger 0.11.8",
"extension",
"fs",
@@ -8847,7 +8847,6 @@ dependencies = [
"http_client",
"imara-diff",
"indoc",
"inventory",
"itertools 0.14.0",
"log",
"lsp",
@@ -8946,10 +8945,8 @@ dependencies = [
"aws-credential-types",
"aws_http_client",
"bedrock",
"chrono",
"client",
"collections",
"component",
"copilot",
"credentials_provider",
"deepseek",
@@ -9026,6 +9023,7 @@ dependencies = [
"itertools 0.14.0",
"language",
"lsp",
"picker",
"project",
"release_channel",
"serde_json",
@@ -12261,7 +12259,6 @@ dependencies = [
"language",
"log",
"lsp",
"markdown",
"node_runtime",
"parking_lot",
"pathdiff",
@@ -14056,13 +14053,12 @@ dependencies = [
[[package]]
name = "schemars"
version = "1.0.1"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe8c9d1c68d67dd9f97ecbc6f932b60eb289c5dbddd8aa1405484a8fd2fcd984"
checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615"
dependencies = [
"dyn-clone",
"indexmap",
"ref-cast",
"schemars_derive",
"serde",
"serde_json",
@@ -14070,9 +14066,9 @@ dependencies = [
[[package]]
name = "schemars_derive"
version = "1.0.1"
version = "0.8.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ca9fcb757952f8e8629b9ab066fc62da523c46c2b247b1708a3be06dd82530b"
checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d"
dependencies = [
"proc-macro2",
"quote",
@@ -14571,29 +14567,16 @@ dependencies = [
name = "settings_ui"
version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"command_palette",
"command_palette_hooks",
"component",
"db",
"editor",
"feature_flags",
"fs",
"fuzzy",
"gpui",
"language",
"log",
"menu",
"paths",
"project",
"schemars",
"search",
"serde",
"settings",
"theme",
"tree-sitter-json",
"tree-sitter-rust",
"ui",
"util",
"workspace",
@@ -16027,7 +16010,6 @@ dependencies = [
"futures 0.3.31",
"gpui",
"indexmap",
"inventory",
"log",
"palette",
"parking_lot",
@@ -17350,7 +17332,6 @@ dependencies = [
"rand 0.8.5",
"regex",
"rust-embed",
"schemars",
"serde",
"serde_json",
"serde_json_lenient",
@@ -19948,7 +19929,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.195.0"
version = "0.194.0"
dependencies = [
"activity_indicator",
"agent",
@@ -20146,9 +20127,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
version = "0.8.5"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
dependencies = [
"anyhow",
"serde",

View File

@@ -425,9 +425,9 @@ aws-smithy-runtime-api = { version = "1.7.4", features = ["http-1x", "client"] }
aws-smithy-types = { version = "1.3.0", features = ["http-body-1-x"] }
base64 = "0.22"
bitflags = "2.6.0"
blade-graphics = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-util = { git = "https://github.com/kvark/blade", rev = "416375211bb0b5826b3584dccdb6a43369e499ad" }
blade-graphics = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blade-macros = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blade-util = { git = "https://github.com/kvark/blade", rev = "e0ec4e720957edd51b945b64dd85605ea54bcfe5" }
blake3 = "1.5.3"
bytes = "1.0"
cargo_metadata = "0.19"
@@ -444,12 +444,12 @@ core-video = { version = "0.4.3", features = ["metal"] }
cpal = "0.16"
criterion = { version = "0.5", features = ["html_reports"] }
ctor = "0.4.0"
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "7f39295b441614ca9dbf44293e53c32f666897f9" }
dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "b40956a7f4d1939da67429d941389ee306a3a308" }
dashmap = "6.0"
derive_more = "0.99.17"
dirs = "4.0"
documented = "0.9.1"
dotenvy = "0.15.0"
dotenv = "0.15.0"
ec4rs = "1.1"
emojis = "0.6.1"
env_logger = "0.11"
@@ -480,7 +480,7 @@ json_dotpath = "1.1"
jsonschema = "0.30.0"
jsonwebtoken = "9.3"
jupyter-protocol = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
jupyter-websocket-client = { git = "https://github.com/ConradIrwin/runtimed" ,rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
libc = "0.2"
libsqlite3-sys = { version = "0.30.1", features = ["bundled"] }
linkify = "0.10.0"
@@ -491,7 +491,7 @@ metal = "0.29"
moka = { version = "0.12.10", features = ["sync"] }
naga = { version = "25.0", features = ["wgsl-in"] }
nanoid = "0.4"
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
nbformat = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734" }
nix = "0.29"
num-format = "0.4.4"
objc = "0.2"
@@ -531,7 +531,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
"stream",
] }
rsa = "0.9.6"
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
"async-dispatcher-runtime",
] }
rust-embed = { version = "8.4", features = ["include-exclude"] }
@@ -540,7 +540,7 @@ rustc-hash = "2.1.0"
rustls = { version = "0.23.26" }
rustls-platform-verifier = "0.5.0"
scap = { git = "https://github.com/zed-industries/scap", rev = "08f0a01417505cc0990b9931a37e5120db92e0d0", default-features = false }
schemars = { version = "1.0", features = ["indexmap2"] }
schemars = { version = "0.8", features = ["impl_json_schema", "indexmap2"] }
semver = "1.0"
serde = { version = "1.0", features = ["derive", "rc"] }
serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
zed_llm_client = "= 0.8.5"
zed_llm_client = "0.8.4"
zstd = "0.11"
[workspace.dependencies.async-stripe]

View File

@@ -34,7 +34,7 @@
"ctrl-q": "zed::Quit",
"f4": "debugger::Start",
"shift-f5": "debugger::Stop",
"ctrl-shift-f5": "debugger::RerunSession",
"ctrl-shift-f5": "debugger::Restart",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"ctrl-f11": "debugger::StepInto",
@@ -557,13 +557,6 @@
"ctrl-b": "workspace::ToggleLeftDock",
"ctrl-j": "workspace::ToggleBottomDock",
"ctrl-alt-y": "workspace::CloseAllDocks",
"ctrl-alt-0": "workspace::ResetActiveDockSize",
// For 0px parameter, uses UI font size value.
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
"shift-find": "pane::DeploySearch",
"ctrl-shift-f": "pane::DeploySearch",
"ctrl-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
@@ -605,9 +598,7 @@
// "foo-bar": ["task::Spawn", { "task_name": "MyTask", "reveal_target": "dock" }]
// or by tag:
// "foo-bar": ["task::Spawn", { "task_tag": "MyTag" }],
"f5": "debugger::Rerun",
"ctrl-f4": "workspace::CloseActiveDock",
"ctrl-w": "workspace::CloseActiveDock"
"f5": "debugger::RerunLastSession"
}
},
{
@@ -710,13 +701,6 @@
"pagedown": "editor::ContextMenuLast"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"up": "editor::SignatureHelpPrevious",
"down": "editor::SignatureHelpNext"
}
},
// Custom bindings
{
"bindings": {
@@ -1083,19 +1067,5 @@
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
},
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
}
},
{
"context": "KeymapEditor",
"use_key_equivalents": true,
"bindings": {
"ctrl-f": "search::FocusSearch"
}
}
]

View File

@@ -5,10 +5,10 @@
"bindings": {
"f4": "debugger::Start",
"shift-f5": "debugger::Stop",
"shift-cmd-f5": "debugger::RerunSession",
"shift-cmd-f5": "debugger::Restart",
"f6": "debugger::Pause",
"f7": "debugger::StepOver",
"ctrl-f11": "debugger::StepInto",
"f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"home": "menu::SelectFirst",
"shift-pageup": "menu::SelectFirst",
@@ -624,13 +624,6 @@
"cmd-r": "workspace::ToggleRightDock",
"cmd-j": "workspace::ToggleBottomDock",
"alt-cmd-y": "workspace::CloseAllDocks",
// For 0px parameter, uses UI font size value.
"ctrl-alt-0": "workspace::ResetActiveDockSize",
"ctrl-alt--": ["workspace::DecreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-=": ["workspace::IncreaseActiveDockSize", { "px": 0 }],
"ctrl-alt-)": "workspace::ResetOpenDocksSize",
"ctrl-alt-_": ["workspace::DecreaseOpenDocksSize", { "px": 0 }],
"ctrl-alt-+": ["workspace::IncreaseOpenDocksSize", { "px": 0 }],
"cmd-shift-f": "pane::DeploySearch",
"cmd-shift-h": ["pane::DeploySearch", { "replace_enabled": true }],
"cmd-shift-t": "pane::ReopenClosedItem",
@@ -659,8 +652,7 @@
"cmd-k shift-up": "workspace::SwapPaneUp",
"cmd-k shift-down": "workspace::SwapPaneDown",
"cmd-shift-x": "zed::Extensions",
"f5": "debugger::Rerun",
"cmd-w": "workspace::CloseActiveDock"
"f5": "debugger::RerunLastSession"
}
},
{
@@ -774,13 +766,6 @@
"pagedown": "editor::ContextMenuLast"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"up": "editor::SignatureHelpPrevious",
"down": "editor::SignatureHelpNext"
}
},
// Custom bindings
{
"use_key_equivalents": true,
@@ -1182,19 +1167,5 @@
"ctrl-tab": "pane::ActivateNextItem",
"ctrl-shift-tab": "pane::ActivatePreviousItem"
}
},
{
"context": "MarkdownPreview",
"bindings": {
"pageup": "markdown::MovePageUp",
"pagedown": "markdown::MovePageDown"
}
},
{
"context": "KeymapEditor",
"use_key_equivalents": true,
"bindings": {
"cmd-f": "search::FocusSearch"
}
}
]

View File

@@ -98,13 +98,6 @@
"ctrl-n": "editor::ContextMenuNext"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "Workspace",
"bindings": {

View File

@@ -98,13 +98,6 @@
"ctrl-n": "editor::ContextMenuNext"
}
},
{
"context": "Editor && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "Workspace",
"bindings": {

View File

@@ -210,8 +210,7 @@
"ctrl-w space": "editor::OpenExcerptsSplit",
"ctrl-w g space": "editor::OpenExcerptsSplit",
"ctrl-6": "pane::AlternateFile",
"ctrl-^": "pane::AlternateFile",
".": "vim::Repeat"
"ctrl-^": "pane::AlternateFile"
}
},
{
@@ -220,6 +219,7 @@
"ctrl-[": "editor::Cancel",
"escape": "editor::Cancel",
":": "command_palette::Toggle",
".": "vim::Repeat",
"c": "vim::PushChange",
"shift-c": "vim::ChangeToEndOfLine",
"d": "vim::PushDelete",
@@ -327,7 +327,6 @@
"g shift-r": ["vim::Paste", { "preserve_clipboard": true }],
"g c": "vim::ToggleComments",
"g q": "vim::Rewrap",
"g w": "vim::Rewrap",
"g ?": "vim::ConvertToRot13",
// "g ?": "vim::ConvertToRot47",
"\"": "vim::PushRegister",
@@ -478,13 +477,6 @@
"ctrl-n": "editor::ShowWordCompletions"
}
},
{
"context": "vim_mode == insert && showing_signature_help && !showing_completions",
"bindings": {
"ctrl-p": "editor::SignatureHelpPrevious",
"ctrl-n": "editor::SignatureHelpNext"
}
},
{
"context": "vim_mode == replace",
"bindings": {
@@ -857,25 +849,6 @@
"shift-u": "git::UnstageAll"
}
},
{
"context": "Editor && mode == auto_height && VimControl",
"bindings": {
// TODO: Implement search
"/": null,
"?": null,
"#": null,
"*": null,
"n": null,
"shift-n": null
}
},
{
"context": "GitCommit > Editor && VimControl && vim_mode == normal",
"bindings": {
"ctrl-c": "menu::Cancel",
"escape": "menu::Cancel"
}
},
{
"context": "Editor && edit_prediction",
"bindings": {
@@ -887,7 +860,14 @@
{
"context": "MessageEditor > Editor && VimControl",
"bindings": {
"enter": "agent::Chat"
"enter": "agent::Chat",
// TODO: Implement search
"/": null,
"?": null,
"#": null,
"*": null,
"n": null,
"shift-n": null
}
},
{

View File

@@ -617,8 +617,6 @@
// 3. Mark files with errors and warnings:
// "all"
"show_diagnostics": "all",
// Whether to stick parent directories at top of the project panel.
"sticky_scroll": true,
// Settings related to indent guides in the project panel.
"indent_guides": {
// When to show indent guides in the project panel.
@@ -748,6 +746,8 @@
"default_width": 380
},
"agent": {
// Version of this setting.
"version": "2",
// Whether the agent is enabled.
"enabled": true,
/// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'.
@@ -810,7 +810,6 @@
"edit_file": true,
"fetch": true,
"list_directory": true,
"project_notifications": true,
"move_path": true,
"now": true,
"find_path": true,
@@ -830,7 +829,6 @@
"diagnostics": true,
"fetch": true,
"list_directory": true,
"project_notifications": true,
"now": true,
"find_path": true,
"read_file": true,
@@ -1294,8 +1292,6 @@
// Whether or not selecting text in the terminal will automatically
// copy to the system clipboard.
"copy_on_select": false,
// Whether to keep the text selection after copying it to the clipboard
"keep_selection_on_copy": false,
// Whether to show the terminal button in the status bar
"button": true,
// Any key-value pairs added to this list will be added to the terminal's
@@ -1660,6 +1656,7 @@
// Different settings for specific language models.
"language_models": {
"anthropic": {
"version": "1",
"api_url": "https://api.anthropic.com"
},
"google": {
@@ -1669,6 +1666,7 @@
"api_url": "http://localhost:11434"
},
"openai": {
"version": "1",
"api_url": "https://api.openai.com/v1"
},
"open_router": {
@@ -1786,8 +1784,7 @@
// `socks5h`. `http` will be used when no scheme is specified.
//
// By default no proxy will be used, or Zed will try get proxy settings from
// environment variables. If certain hosts should not be proxied,
// set the `no_proxy` environment variable and provide a comma-separated list.
// environment variables.
//
// Examples:
// - "proxy": "socks5h://localhost:10808"

View File

@@ -31,13 +31,7 @@ use workspace::{StatusItemView, Workspace, item::ItemHandle};
const GIT_OPERATION_DELAY: Duration = Duration::from_millis(0);
actions!(
activity_indicator,
[
/// Displays error messages from language servers in the status bar.
ShowErrorMessage
]
);
actions!(activity_indicator, [ShowErrorMessage]);
pub enum Event {
ShowStatus {

View File

@@ -68,6 +68,7 @@ zstd.workspace = true
[dev-dependencies]
assistant_tools.workspace = true
assistant_tool = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
language = { workspace = true, "features" = ["test-support"] }

View File

@@ -5,13 +5,12 @@ pub mod context_store;
pub mod history_store;
pub mod thread;
pub mod thread_store;
pub mod tool_use;
pub use context::{AgentContext, ContextId, ContextLoadResult};
pub use context_store::ContextStore;
pub use thread::{
LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio,
LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, ThreadError,
ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgentThread,
};
pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore};

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings};
use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName};
use assistant_tool::{Tool, ToolSource, ToolWorkingSet};
use collections::IndexMap;
use convert_case::{Case, Casing};
use fs::Fs;
@@ -72,7 +72,7 @@ impl AgentProfile {
&self.id
}
pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else {
return Vec::new();
};
@@ -81,7 +81,7 @@ impl AgentProfile {
.read(cx)
.tools(cx)
.into_iter()
.filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name()))
.filter(|tool| Self::is_enabled(settings, tool.source(), tool.name()))
.collect()
}
@@ -96,11 +96,16 @@ impl AgentProfile {
fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool {
match source {
ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false),
ToolSource::ContextServer { id } => settings
.context_servers
.get(id.as_ref())
.and_then(|preset| preset.tools.get(name.as_str()).copied())
.unwrap_or(settings.enable_all_context_servers),
ToolSource::ContextServer { id } => {
if settings.enable_all_context_servers {
return true;
}
let Some(preset) = settings.context_servers.get(id.as_ref()) else {
return false;
};
*preset.tools.get(name.as_str()).unwrap_or(&false)
}
}
}
}
@@ -137,7 +142,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -174,7 +179,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -207,7 +212,7 @@ mod tests {
let mut enabled_tools = cx
.read(|cx| profile.enabled_tools(cx))
.into_iter()
.map(|(_, tool)| tool.name())
.map(|tool| tool.name())
.collect::<Vec<_>>();
enabled_tools.sort();
@@ -267,10 +272,10 @@ mod tests {
}
fn default_tool_set(cx: &mut TestAppContext) -> Entity<ToolWorkingSet> {
cx.new(|cx| {
cx.new(|_| {
let mut tool_set = ToolWorkingSet::default();
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")), cx);
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")), cx);
tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")));
tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")));
tool_set
})
}

View File

@@ -1,4 +1,4 @@
use crate::thread::Thread;
use crate::thread::ZedAgentThread;
use assistant_context::AssistantContext;
use assistant_tool::outline;
use collections::HashSet;
@@ -560,7 +560,7 @@ impl Display for FetchedUrlContext {
#[derive(Debug, Clone)]
pub struct ThreadContextHandle {
pub thread: Entity<Thread>,
pub agent: Entity<ZedAgentThread>,
pub context_id: ContextId,
}
@@ -573,23 +573,23 @@ pub struct ThreadContext {
impl ThreadContextHandle {
pub fn eq_for_key(&self, other: &Self) -> bool {
self.thread == other.thread
self.agent == other.agent
}
pub fn hash_for_key<H: Hasher>(&self, state: &mut H) {
self.thread.hash(state)
self.agent.hash(state)
}
pub fn title(&self, cx: &App) -> SharedString {
self.thread.read(cx).summary().or_default()
self.agent.read(cx).summary().or_default()
}
fn load(self, cx: &App) -> Task<Option<(AgentContext, Vec<Entity<Buffer>>)>> {
cx.spawn(async move |cx| {
let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?;
let text = ZedAgentThread::wait_for_detailed_summary_or_text(&self.agent, cx).await?;
let title = self
.thread
.read_with(cx, |thread, _cx| thread.summary().or_default())
.agent
.read_with(cx, |thread, _| thread.summary().or_default())
.ok()?;
let context = AgentContext::Thread(ThreadContext {
title,

View File

@@ -4,7 +4,7 @@ use crate::{
FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
},
thread::{MessageId, Thread, ThreadId},
thread::{MessageId, ThreadId, ZedAgentThread},
thread_store::ThreadStore,
};
use anyhow::{Context as _, Result, anyhow};
@@ -66,8 +66,9 @@ impl ContextStore {
pub fn new_context_for_thread(
&self,
thread: &Thread,
thread: &ZedAgentThread,
exclude_messages_from_id: Option<MessageId>,
_cx: &App,
) -> Vec<AgentContextHandle> {
let existing_context = thread
.messages()
@@ -206,12 +207,15 @@ impl ContextStore {
pub fn add_thread(
&mut self,
thread: Entity<Thread>,
thread: Entity<ZedAgentThread>,
remove_if_exists: bool,
cx: &mut Context<Self>,
) -> Option<AgentContextHandle> {
let context_id = self.next_context_id.post_inc();
let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
let context = AgentContextHandle::Thread(ThreadContextHandle {
agent: thread,
context_id,
});
if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
if remove_if_exists {
@@ -387,7 +391,10 @@ impl ContextStore {
if let Some(thread) = thread.upgrade() {
let context_id = self.next_context_id.post_inc();
self.insert_context(
AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
AgentContextHandle::Thread(ThreadContextHandle {
agent: thread,
context_id,
}),
cx,
);
}
@@ -411,11 +418,11 @@ impl ContextStore {
match &context {
AgentContextHandle::Thread(thread_context) => {
if let Some(thread_store) = self.thread_store.clone() {
thread_context.thread.update(cx, |thread, cx| {
thread_context.agent.update(cx, |thread, cx| {
thread.start_generating_detailed_summary_if_needed(thread_store, cx);
});
self.context_thread_ids
.insert(thread_context.thread.read(cx).id().clone());
.insert(thread_context.agent.read(cx).id().clone());
} else {
return false;
}
@@ -441,7 +448,7 @@ impl ContextStore {
match context {
AgentContextHandle::Thread(thread_context) => {
self.context_thread_ids
.remove(thread_context.thread.read(cx).id());
.remove(thread_context.agent.read(cx).id());
}
AgentContextHandle::TextThread(text_thread_context) => {
if let Some(path) = text_thread_context.context.read(cx).path() {
@@ -570,7 +577,7 @@ pub enum SuggestedContext {
},
Thread {
name: SharedString,
thread: WeakEntity<Thread>,
thread: WeakEntity<ZedAgentThread>,
},
TextThread {
name: SharedString,

View File

@@ -1,3 +0,0 @@
[The following is an auto-generated notification; do not reply]
These files have changed since the last read:

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,12 @@
use crate::{
context_server_tool::ContextServerTool,
thread::{
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, ThreadId, ZedAgentThread,
},
};
use agent_settings::{AgentProfileId, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{Tool, ToolId, ToolWorkingSet};
use assistant_tool::{ToolId, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::ContextServerId;
@@ -400,9 +400,9 @@ impl ThreadStore {
self.threads.iter()
}
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<ZedAgentThread> {
cx.new(|cx| {
Thread::new(
ZedAgentThread::new(
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
@@ -416,9 +416,9 @@ impl ThreadStore {
&mut self,
serialized: SerializedThread,
cx: &mut Context<Self>,
) -> Entity<Thread> {
) -> Entity<ZedAgentThread> {
cx.new(|cx| {
Thread::deserialize(
ZedAgentThread::deserialize(
ThreadId::new(),
serialized,
self.project.clone(),
@@ -436,7 +436,7 @@ impl ThreadStore {
id: &ThreadId,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<Entity<Thread>>> {
) -> Task<Result<Entity<ZedAgentThread>>> {
let id = id.clone();
let database_future = ThreadsDatabase::global_future(cx);
let this = cx.weak_entity();
@@ -449,7 +449,7 @@ impl ThreadStore {
let thread = this.update_in(cx, |this, window, cx| {
cx.new(|cx| {
Thread::deserialize(
ZedAgentThread::deserialize(
id.clone(),
thread,
this.project.clone(),
@@ -466,9 +466,14 @@ impl ThreadStore {
})
}
pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
let (metadata, serialized_thread) =
thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
pub fn save_thread(
&self,
thread: &Entity<ZedAgentThread>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let (metadata, serialized_thread) = thread.update(cx, |thread, cx| {
(thread.id().clone(), thread.serialize(cx))
});
let database_future = ThreadsDatabase::global_future(cx);
cx.spawn(async move |this, cx| {
@@ -537,8 +542,8 @@ impl ThreadStore {
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.update(cx, |tool_working_set, cx| {
tool_working_set.remove(&tool_ids, cx);
tool_working_set.update(cx, |tool_working_set, _| {
tool_working_set.remove(&tool_ids);
});
}
}
@@ -569,17 +574,19 @@ impl ThreadStore {
.log_err()
{
let tool_ids = tool_working_set
.update(cx, |tool_working_set, cx| {
tool_working_set.extend(
response.tools.into_iter().map(|tool| {
Arc::new(ContextServerTool::new(
.update(cx, |tool_working_set, _| {
response
.tools
.into_iter()
.map(|tool| {
log::info!("registering context server tool: {:?}", tool.name);
tool_working_set.insert(Arc::new(ContextServerTool::new(
context_server_store.clone(),
server.id(),
tool,
)) as Arc<dyn Tool>
}),
cx,
)
)))
})
.collect::<Vec<_>>()
})
.log_err();
@@ -698,7 +705,7 @@ impl SerializedThreadV0_1_0 {
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedMessage {
pub id: MessageId,
pub role: Role,
@@ -712,11 +719,9 @@ pub struct SerializedMessage {
pub context: String,
#[serde(default)]
pub creases: Vec<SerializedCrease>,
#[serde(default)]
pub is_hidden: bool,
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
#[serde(rename = "text")]
@@ -734,14 +739,14 @@ pub enum SerializedMessageSegment {
},
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub input: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedToolResult {
pub tool_use_id: LanguageModelToolUseId,
pub is_error: bool,
@@ -799,12 +804,11 @@ impl LegacySerializedMessage {
tool_results: self.tool_results,
context: String::new(),
creases: Vec::new(),
is_hidden: false,
}
}
}
#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SerializedCrease {
pub start: usize,
pub end: usize,
@@ -1103,7 +1107,6 @@ mod tests {
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false
}],
version: SerializedThread::VERSION.to_string(),
initial_project_snapshot: None,
@@ -1136,7 +1139,6 @@ mod tests {
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
SerializedMessage {
id: MessageId(2),
@@ -1152,7 +1154,6 @@ mod tests {
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
SerializedMessage {
id: MessageId(1),
@@ -1169,7 +1170,6 @@ mod tests {
}],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
],
version: SerializedThreadV0_1_0::VERSION.to_string(),
@@ -1201,7 +1201,6 @@ mod tests {
tool_results: vec![],
context: "".to_string(),
creases: vec![],
is_hidden: false
},
SerializedMessage {
id: MessageId(2),
@@ -1222,7 +1221,6 @@ mod tests {
}],
context: "".to_string(),
creases: vec![],
is_hidden: false,
},
],
version: SerializedThread::VERSION.to_string(),

View File

@@ -1,567 +0,0 @@
use crate::{
thread::{MessageId, PromptId, ThreadId},
thread_store::SerializedMessage,
};
use anyhow::Result;
use assistant_tool::{
AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet,
};
use collections::HashMap;
use futures::{FutureExt as _, future::Shared};
use gpui::{App, Entity, SharedString, Task, Window};
use icons::IconName;
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, Role,
};
use project::Project;
use std::sync::Arc;
use util::truncate_lines_to_byte_limit;
#[derive(Debug)]
pub struct ToolUse {
pub id: LanguageModelToolUseId,
pub name: SharedString,
pub ui_text: SharedString,
pub status: ToolUseStatus,
pub input: serde_json::Value,
pub icon: icons::IconName,
pub needs_confirmation: bool,
}
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
impl ToolUseState {
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
Self {
tools,
tool_uses_by_assistant_message: HashMap::default(),
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
tool_use_metadata_by_id: HashMap::default(),
}
}
/// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s.
///
/// Accepts a function to filter the tools that should be used to populate the state.
///
/// If `window` is `None` (e.g., when in headless mode or when running evals),
/// tool cards won't be deserialized
pub fn from_serialized_messages(
tools: Entity<ToolWorkingSet>,
messages: &[SerializedMessage],
project: Entity<Project>,
window: Option<&mut Window>, // None in headless mode
cx: &mut App,
) -> Self {
let mut this = Self::new(tools);
let mut tool_names_by_id = HashMap::default();
let mut window = window;
for message in messages {
match message.role {
Role::Assistant => {
if !message.tool_uses.is_empty() {
let tool_uses = message
.tool_uses
.iter()
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
raw_input: tool_use.input.to_string(),
input: tool_use.input.clone(),
is_input_complete: true,
})
.collect::<Vec<_>>();
tool_names_by_id.extend(
tool_uses
.iter()
.map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
);
this.tool_uses_by_assistant_message
.insert(message.id, tool_uses);
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
log::warn!("no tool name found for tool use: {tool_use_id:?}");
continue;
};
this.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
output: tool_result.output.clone(),
},
);
if let Some(window) = &mut window {
if let Some(tool) = this.tools.read(cx).tool(tool_use, cx) {
if let Some(output) = tool_result.output.clone() {
if let Some(card) = tool.deserialize_card(
output,
project.clone(),
window,
cx,
) {
this.tool_result_cards.insert(tool_use_id, card);
}
}
}
}
}
}
}
Role::System | Role::User => {}
}
}
this
}
pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
let mut cancelled_tool_uses = Vec::new();
self.pending_tool_uses_by_id
.retain(|tool_use_id, tool_use| {
if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) {
return true;
}
let content = "Tool canceled by user".into();
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name: tool_use.name.clone(),
content,
output: None,
is_error: true,
},
);
cancelled_tool_uses.push(tool_use.clone());
false
});
cancelled_tool_uses
}
pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
self.pending_tool_uses_by_id.values().collect()
}
pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
return Vec::new();
};
let mut tool_uses = Vec::new();
for tool_use in tool_uses_for_message.iter() {
let tool_result = self.tool_results.get(&tool_use.id);
let status = (|| {
if let Some(tool_result) = tool_result {
let content = tool_result
.content
.to_str()
.map(|str| str.to_owned().into())
.unwrap_or_default();
return if tool_result.is_error {
ToolUseStatus::Error(content)
} else {
ToolUseStatus::Finished(content)
};
}
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
match pending_tool_use.status {
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
PendingToolUseStatus::NeedsConfirmation { .. } => {
ToolUseStatus::NeedsConfirmation
}
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
PendingToolUseStatus::InputStillStreaming => {
ToolUseStatus::InputStillStreaming
}
}
} else {
ToolUseStatus::Pending
}
})();
let (icon, needs_confirmation) =
if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
(tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
} else {
(IconName::Cog, false)
};
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
),
input: tool_use.input.clone(),
status,
icon,
needs_confirmation,
})
}
tool_uses
}
pub fn tool_ui_label(
&self,
tool_name: &str,
input: &serde_json::Value,
is_input_complete: bool,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
if is_input_complete {
tool.ui_text(input).into()
} else {
tool.still_streaming_ui_text(input).into()
}
} else {
format!("Unknown tool {tool_name:?}").into()
}
}
pub fn tool_results_for_message(
&self,
assistant_message_id: MessageId,
) -> Vec<&LanguageModelToolResult> {
let Some(tool_uses) = self
.tool_uses_by_assistant_message
.get(&assistant_message_id)
else {
return Vec::new();
};
tool_uses
.iter()
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
.collect()
}
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.map_or(false, |results| !results.is_empty())
}
pub fn tool_result(
&self,
tool_use_id: &LanguageModelToolUseId,
) -> Option<&LanguageModelToolResult> {
self.tool_results.get(tool_use_id)
}
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
self.tool_result_cards.get(tool_use_id)
}
pub fn insert_tool_result_card(
&mut self,
tool_use_id: LanguageModelToolUseId,
card: AnyToolCard,
) {
self.tool_result_cards.insert(tool_use_id, card);
}
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
tool_use: LanguageModelToolUse,
metadata: ToolUseMetadata,
cx: &App,
) -> Arc<str> {
let tool_uses = self
.tool_uses_by_assistant_message
.entry(assistant_message_id)
.or_default();
let mut existing_tool_use_found = false;
for existing_tool_use in tool_uses.iter_mut() {
if existing_tool_use.id == tool_use.id {
*existing_tool_use = tool_use.clone();
existing_tool_use_found = true;
}
}
if !existing_tool_use_found {
tool_uses.push(tool_use.clone());
}
let status = if tool_use.is_input_complete {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
};
let ui_text: Arc<str> = self
.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
)
.into();
let may_perform_edits = self
.tools
.read(cx)
.tool(&tool_use.name, cx)
.is_some_and(|tool| tool.may_perform_edits());
self.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
PendingToolUse {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
ui_text: ui_text.clone(),
input: tool_use.input,
may_perform_edits,
status,
},
);
ui_text
}
pub fn run_pending_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: SharedString,
task: Task<()>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.ui_text = ui_text.into();
tool_use.status = PendingToolUseStatus::Running {
_task: task.shared(),
};
}
}
pub fn confirm_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
ui_text: impl Into<Arc<str>>,
input: serde_json::Value,
request: Arc<LanguageModelRequest>,
tool: Arc<dyn Tool>,
) {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
let ui_text = ui_text.into();
tool_use.ui_text = ui_text.clone();
let confirmation = Confirmation {
tool_use_id,
input,
request,
tool,
ui_text,
};
tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
}
}
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<ToolResultOutput>,
configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
telemetry::event!(
"Agent Tool Finished",
model = metadata
.as_ref()
.map(|metadata| metadata.model.telemetry_id()),
model_provider = metadata
.as_ref()
.map(|metadata| metadata.model.provider_id().to_string()),
thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
tool_name,
success = output.is_ok()
);
match output {
Ok(output) => {
let tool_result = output.content;
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id);
// Protect from overly large output
let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() as usize * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX);
let content = match tool_result {
ToolResultContent::Text(text) => {
let text = if text.len() < tool_output_limit {
text
} else {
let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit);
format!(
"Tool result too long. The first {} bytes:\n\n{}",
truncated.len(),
truncated
)
};
LanguageModelToolResultContent::Text(text.into())
}
ToolResultContent::Image(language_model_image) => {
if language_model_image.estimate_tokens() < tool_output_limit {
LanguageModelToolResultContent::Image(language_model_image)
} else {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: "Tool responded with an image that would exceeded the remaining tokens".into(),
is_error: true,
output: None,
},
);
return old_use;
}
}
};
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content,
is_error: false,
output: output.output,
},
);
old_use
}
Err(err) => {
self.tool_results.insert(
tool_use_id.clone(),
LanguageModelToolResult {
tool_use_id: tool_use_id.clone(),
tool_name,
content: LanguageModelToolResultContent::Text(err.to_string().into()),
is_error: true,
output: None,
},
);
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
}
self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
}
}
}
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
self.tool_uses_by_assistant_message
.contains_key(&assistant_message_id)
}
pub fn tool_results(
&self,
assistant_message_id: MessageId,
) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
self.tool_uses_by_assistant_message
.get(&assistant_message_id)
.into_iter()
.flatten()
.map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
}
}
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: LanguageModelToolUseId,
/// The ID of the Assistant message in which the tool use was requested.
#[allow(unused)]
pub assistant_message_id: MessageId,
pub name: Arc<str>,
pub ui_text: Arc<str>,
pub input: serde_json::Value,
pub status: PendingToolUseStatus,
pub may_perform_edits: bool,
}
#[derive(Debug, Clone)]
pub struct Confirmation {
pub tool_use_id: LanguageModelToolUseId,
pub input: serde_json::Value,
pub ui_text: Arc<str>,
pub request: Arc<LanguageModelRequest>,
pub tool: Arc<dyn Tool>,
}
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
InputStillStreaming,
Idle,
NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },
Error(#[allow(unused)] Arc<str>),
}
impl PendingToolUseStatus {
pub fn is_idle(&self) -> bool {
matches!(self, PendingToolUseStatus::Idle)
}
pub fn is_error(&self) -> bool {
matches!(self, PendingToolUseStatus::Error(_))
}
pub fn needs_confirmation(&self) -> bool {
matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
}
}
#[derive(Clone)]
pub struct ToolUseMetadata {
pub model: Arc<dyn LanguageModel>,
pub thread_id: ThreadId,
pub prompt_id: PromptId,
}

View File

@@ -6,10 +6,9 @@ use anyhow::{Result, bail};
use collections::IndexMap;
use gpui::{App, Pixels, SharedString};
use language_model::LanguageModel;
use schemars::{JsonSchema, json_schema};
use schemars::{JsonSchema, schema::Schema};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use std::borrow::Cow;
pub use crate::agent_profile::*;
@@ -50,7 +49,7 @@ pub struct AgentSettings {
pub dock: AgentDockPosition,
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: Option<LanguageModelSelection>,
pub default_model: LanguageModelSelection,
pub inline_assistant_model: Option<LanguageModelSelection>,
pub commit_message_model: Option<LanguageModelSelection>,
pub thread_summary_model: Option<LanguageModelSelection>,
@@ -212,6 +211,7 @@ impl AgentSettingsContent {
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
#[schemars(deny_unknown_fields)]
pub struct AgentSettingsContent {
/// Whether the Agent is enabled.
///
@@ -321,27 +321,29 @@ pub struct LanguageModelSelection {
pub struct LanguageModelProviderSetting(pub String);
impl JsonSchema for LanguageModelProviderSetting {
fn schema_name() -> Cow<'static, str> {
fn schema_name() -> String {
"LanguageModelProviderSetting".into()
}
fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema {
json_schema!({
"enum": [
"anthropic",
"amazon-bedrock",
"google",
"lmstudio",
"ollama",
"openai",
"zed.dev",
"copilot_chat",
"deepseek",
"openrouter",
"mistral",
"vercel"
]
})
fn json_schema(_: &mut schemars::r#gen::SchemaGenerator) -> Schema {
schemars::schema::SchemaObject {
enum_values: Some(vec![
"anthropic".into(),
"amazon-bedrock".into(),
"google".into(),
"lmstudio".into(),
"ollama".into(),
"openai".into(),
"zed.dev".into(),
"copilot_chat".into(),
"deepseek".into(),
"openrouter".into(),
"mistral".into(),
"vercel".into(),
]),
..Default::default()
}
.into()
}
}
@@ -357,6 +359,15 @@ impl From<&str> for LanguageModelProviderSetting {
}
}
impl Default for LanguageModelSelection {
fn default() -> Self {
Self {
provider: LanguageModelProviderSetting("openai".to_string()),
model: "gpt-4".to_string(),
}
}
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentProfileContent {
pub name: Arc<str>,
@@ -400,10 +411,7 @@ impl Settings for AgentSettings {
&mut settings.default_height,
value.default_height.map(Into::into),
);
settings.default_model = value
.default_model
.clone()
.or(settings.default_model.take());
merge(&mut settings.default_model, value.default_model.clone());
settings.inline_assistant_model = value
.inline_assistant_model
.clone()

View File

@@ -96,6 +96,7 @@ zed_llm_client.workspace = true
[dev-dependencies]
assistant_tools.workspace = true
assistant_tool = { workspace = true, "features" = ["test-support"] }
buffer_diff = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }

View File

@@ -1,18 +1,21 @@
use crate::context_picker::{ContextPicker, MentionLink};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::{extract_message_creases, insert_message_creases};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::ui::{
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
};
use crate::{AgentPanel, ModelUsageContext};
use agent::thread::{ToolUseSegment, UserMessageParams};
use agent::{
ContextStore, LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, TextThreadStore,
Thread, ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadSummary,
ThreadError, ThreadEvent, ThreadFeedback, ThreadStore, ThreadSummary, ZedAgentThread,
context::{self, AgentContextHandle, RULES_ICON},
thread::{PendingToolUseStatus, ToolUse},
thread_store::RulesLoadingError,
tool_use::{PendingToolUseStatus, ToolUse},
};
use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
use anyhow::Context as _;
use assistant_tool::ToolUseStatus;
use assistant_tool::{AnyToolCard, ToolUseStatus, ToolWorkingSet};
use audio::{Audio, Sound};
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
@@ -28,13 +31,14 @@ use gpui::{
};
use language::{Buffer, Language, LanguageRegistry};
use language_model::{
LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason,
LanguageModelRequestMessage, LanguageModelToolResultContent, LanguageModelToolUseId,
MessageContent, Role, StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{
HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, PathWithRange,
};
use project::{ProjectEntryId, ProjectItem as _};
use project::{Project, ProjectEntryId, ProjectItem as _};
use rope::Point;
use settings::{Settings as _, SettingsStore, update_settings_file};
use std::ffi::OsStr;
@@ -45,26 +49,26 @@ use std::time::Duration;
use text::ToPoint;
use theme::ThemeSettings;
use ui::{
Banner, Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize,
Tooltip, prelude::*,
Disclosure, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState, TextSize, Tooltip,
prelude::*,
};
use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
use util::{ResultExt as _, debug_panic};
use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
use zed_llm_client::CompletionIntent;
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1;
const RESPONSE_PADDING_X: Pixels = px(19.);
pub struct ActiveThread {
context_store: Entity<ContextStore>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
// thread: Entity<Thread>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
save_thread_task: Option<Task<()>>,
messages: Vec<MessageId>,
list_state: ListState,
@@ -91,7 +95,7 @@ struct RenderedMessage {
segments: Vec<RenderedMessageSegment>,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
struct RenderedToolUse {
label: Entity<Markdown>,
input: Entity<Markdown>,
@@ -161,17 +165,103 @@ impl RenderedMessage {
cx,
)))
}
MessageSegment::RedactedThinking(_) => {}
MessageSegment::ToolUse { .. } => {
todo!()
}
};
}
fn update_tool_call(
&mut self,
segment_index: usize,
segment: &ToolUseSegment,
_tools: &Entity<ToolWorkingSet>,
cx: &mut App,
) {
if let Some(card) = segment.card.clone() {
if self.segments.len() < segment_index {
self.segments.push(RenderedMessageSegment::ToolUseCard(
segment.status.clone(),
card,
))
}
return;
}
if self.segments.len() <= segment_index {
self.segments
.push(RenderedMessageSegment::ToolUseMarkdown(RenderedToolUse {
label: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
input: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
output: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
}))
}
dbg!(&self.segments);
let RenderedMessageSegment::ToolUseMarkdown(rendered) = &self.segments[segment_index]
else {
panic!()
};
// todo!()
// let ui_label = if let Some(tool) = tools.read(cx).tool(segment.name, cx) {
// if segment.is_input_complete {
// tool.ui_text(segment.input).into()
// } else {
// tool.still_streaming_ui_text(segment.input).into()
// }
// } else {
// format!("Unknown tool {:?}", segment.name).into()
// };
rendered.label.update(cx, |this, cx| {
this.replace(segment.name.clone(), cx);
});
rendered.input.update(cx, |this, cx| {
this.replace(
MarkdownCodeBlock {
tag: "json",
text: &serde_json::to_string_pretty(&segment.input).unwrap_or_default(),
}
.to_string(),
cx,
);
});
rendered.output.update(cx, |_this, _cx| {
match &segment.output {
Some(Ok(LanguageModelToolResultContent::Text(_text))) => {
// todo!
}
Some(Ok(LanguageModelToolResultContent::Image(_image))) => {
// todo!
}
Some(Err(_error)) => {
// todo!
}
None => {
// todo!
}
}
});
}
}
#[derive(Debug)]
enum RenderedMessageSegment {
Thinking {
content: Entity<Markdown>,
scroll_handle: ScrollHandle,
},
Text(Entity<Markdown>),
ToolUseCard(ToolUseStatus, AnyToolCard),
ToolUseMarkdown(RenderedToolUse),
}
fn parse_markdown(
@@ -764,7 +854,7 @@ struct EditingMessageState {
impl ActiveThread {
pub fn new(
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
context_store: Entity<ContextStore>,
@@ -774,8 +864,8 @@ impl ActiveThread {
cx: &mut Context<Self>,
) -> Self {
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe_in(&thread, window, Self::handle_thread_event),
cx.observe(&agent, |_, _, cx| cx.notify()),
cx.subscribe_in(&agent, window, Self::handle_thread_event),
cx.subscribe(&thread_store, Self::handle_rules_loading_error),
cx.observe_global::<SettingsStore>(|_, cx| cx.notify()),
];
@@ -787,12 +877,14 @@ impl ActiveThread {
.unwrap()
}
});
let project = agent.read(cx).project().clone();
let mut this = Self {
language_registry,
thread_store,
text_thread_store,
context_store,
thread: thread.clone(),
agent: agent.clone(),
project,
workspace,
save_thread_task: None,
messages: Vec::new(),
@@ -815,7 +907,8 @@ impl ActiveThread {
_load_edited_message_context_task: None,
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
// todo! hold on to thread entity and get messages directly
for message in agent.read(cx).messages().cloned().collect::<Vec<_>>() {
let rendered_message = RenderedMessage::from_segments(
&message.segments,
this.language_registry.clone(),
@@ -823,7 +916,7 @@ impl ActiveThread {
);
this.push_rendered_message(message.id, rendered_message);
for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
for tool_use in agent.read(cx).tool_uses_for_message(message.id, cx) {
this.render_tool_use_markdown(
tool_use.id.clone(),
tool_use.ui_text.clone(),
@@ -837,8 +930,8 @@ impl ActiveThread {
this
}
pub fn thread(&self) -> &Entity<Thread> {
&self.thread
pub fn agent(&self) -> &Entity<ZedAgentThread> {
&self.agent
}
pub fn is_empty(&self) -> bool {
@@ -846,17 +939,17 @@ impl ActiveThread {
}
pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary {
self.thread.read(cx).summary()
self.agent.read(cx).summary()
}
pub fn regenerate_summary(&self, cx: &mut App) {
self.thread.update(cx, |thread, cx| thread.summarize(cx))
self.agent.update(cx, |agent, cx| agent.summarize(cx))
}
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
self.last_error.take();
self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
self.agent.update(cx, |agent, cx| {
agent.cancel_last_completion(Some(window.window_handle()), cx)
})
}
@@ -946,7 +1039,7 @@ impl ActiveThread {
fn handle_thread_event(
&mut self,
_thread: &Entity<Thread>,
_agent: &Entity<ZedAgentThread>,
event: &ThreadEvent,
window: &mut Window,
cx: &mut Context<Self>,
@@ -964,10 +1057,8 @@ impl ActiveThread {
cx.notify();
}
ThreadEvent::CompletionCanceled => {
self.thread.update(cx, |thread, cx| {
thread.project().update(cx, |project, cx| {
project.set_agent_location(None, cx);
})
self.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
self.workspace
.update(cx, |workspace, cx| {
@@ -985,7 +1076,7 @@ impl ActiveThread {
}
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
let used_tools = self.thread.read(cx).used_tools_since_last_user_message();
let used_tools = self.agent.read(cx).used_tools_since_last_user_message(cx);
self.play_notification_sound(window, cx);
self.show_notification(
if used_tools {
@@ -1023,10 +1114,28 @@ impl ActiveThread {
rendered_message.append_thinking(text, cx);
}
}
ThreadEvent::StreamedToolUse2 {
message_id,
segment_index,
} => {
if let Some(rendered_message) = self.rendered_messages_by_id.get_mut(&message_id) {
self.agent.update(cx, |agent, cx| {
if let Some(message) = agent.message(*message_id) {
let MessageSegment::ToolUse(tool_use) =
&message.segments[*segment_index]
else {
debug_panic!("segment index mismatch");
return;
};
let tools = self.agent.read(cx).tools().clone();
rendered_message.update_tool_call(*segment_index, tool_use, &tools, cx);
}
})
}
}
ThreadEvent::MessageAdded(message_id) => {
self.clear_last_error();
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| {
if let Some(rendered_message) = self.agent.update(cx, |agent, cx| {
agent.message(*message_id).map(|message| {
RenderedMessage::from_segments(
&message.segments,
self.language_registry.clone(),
@@ -1041,10 +1150,9 @@ impl ActiveThread {
cx.notify();
}
ThreadEvent::MessageEdited(message_id) => {
self.clear_last_error();
if let Some(index) = self.messages.iter().position(|id| id == message_id) {
if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
thread.message(*message_id).map(|message| {
if let Some(rendered_message) = self.agent.update(cx, |agent, cx| {
agent.message(*message_id).map(|message| {
let mut rendered_message = RenderedMessage {
language_registry: self.language_registry.clone(),
segments: Vec::with_capacity(message.segments.len()),
@@ -1101,7 +1209,7 @@ impl ActiveThread {
tool_use.id.clone(),
tool_use.ui_text.clone(),
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
self.thread
self.agent
.read(cx)
.output_for_tool(&tool_use.id)
.map(|output| output.clone().into())
@@ -1121,7 +1229,7 @@ impl ActiveThread {
tool_use_id.clone(),
ui_text,
invalid_input_json,
self.thread
self.agent
.read(cx)
.output_for_tool(tool_use_id)
.map(|output| output.clone().into())
@@ -1137,7 +1245,7 @@ impl ActiveThread {
tool_use_id.clone(),
ui_text,
"",
self.thread
self.agent
.read(cx)
.output_for_tool(tool_use_id)
.map(|output| output.clone().into())
@@ -1186,7 +1294,7 @@ impl ActiveThread {
return;
}
let title = self.thread.read(cx).summary().unwrap_or("Agent Panel");
let title = self.agent.read(cx).summary().unwrap_or("Agent Panel");
match AgentSettings::get_global(cx).notify_when_agent_waiting {
NotifyWhenAgentWaiting::PrimaryScreen => {
@@ -1297,12 +1405,12 @@ impl ActiveThread {
///
/// Only one task to save the thread will be in flight at a time.
fn save_thread(&mut self, cx: &mut Context<Self>) {
let thread = self.thread.clone();
let agent = self.agent.clone();
self.save_thread_task = Some(cx.spawn(async move |this, cx| {
let task = this
.update(cx, |this, cx| {
this.thread_store
.update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx))
.update(cx, |thread_store, cx| thread_store.save_thread(&agent, cx))
})
.ok();
@@ -1352,7 +1460,7 @@ impl ActiveThread {
Some(self.text_thread_store.downgrade()),
context_picker_menu_handle.clone(),
SuggestContextKind::File,
ModelUsageContext::Thread(self.thread.clone()),
ModelUsageContext::Thread(self.agent.clone()),
window,
cx,
)
@@ -1404,13 +1512,13 @@ impl ActiveThread {
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
state._update_token_count_task.take();
let Some(configured_model) = self.thread.read(cx).configured_model() else {
let Some(configured_model) = self.agent.read(cx).configured_model() else {
state.last_estimated_token_count.take();
return;
};
let editor = state.editor.clone();
let thread = self.thread.clone();
let agent = self.agent.clone();
let message_id = *message_id;
state._update_token_count_task = Some(cx.spawn(async move |this, cx| {
@@ -1422,7 +1530,7 @@ impl ActiveThread {
let token_count = if let Some(task) = cx
.update(|cx| {
let Some(message) = thread.read(cx).message(message_id) else {
let Some(message) = agent.read(cx).message(message_id) else {
log::error!("Message that was being edited no longer exists");
return None;
};
@@ -1554,8 +1662,8 @@ impl ActiveThread {
};
let Some(model) = self
.thread
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
.agent
.update(cx, |agent, cx| agent.get_or_init_configured_model(cx))
else {
return;
};
@@ -1569,12 +1677,13 @@ impl ActiveThread {
let creases = state.editor.update(cx, extract_message_creases);
let new_context = self
.context_store
.read(cx)
.new_context_for_thread(self.thread.read(cx), Some(message_id));
let new_context = self.context_store.read(cx).new_context_for_thread(
self.agent.read(cx),
Some(message_id),
cx,
);
let project = self.thread.read(cx).project().clone();
let project = self.project.clone();
let prompt_store = self.thread_store.read(cx).prompt_store().clone();
let git_store = project.read(cx).git_store().clone();
@@ -1587,32 +1696,24 @@ impl ActiveThread {
futures::future::join(load_context_task, checkpoint).await;
let _ = this
.update_in(cx, |this, window, cx| {
this.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
vec![MessageSegment::Text(edited_text)],
creases,
Some(context.loaded_context),
checkpoint.ok(),
cx,
);
for message_id in this.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
});
this.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.cancel_last_completion(Some(window.window_handle()), cx);
thread.send_to_model(
this.agent.update(cx, |agent, cx| {
agent.truncate(message_id, cx);
agent.send_message(
UserMessageParams {
text: edited_text,
creases,
checkpoint: checkpoint.ok(),
context,
},
model.model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
});
// todo! do we need this?
this._load_edited_message_context_task = None;
cx.notify();
})
.log_err();
@@ -1627,14 +1728,6 @@ impl ActiveThread {
}
}
fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
self.messages
.iter()
.position(|id| *id == message_id)
.map(|index| &self.messages[index + 1..])
.unwrap_or(&[])
}
fn handle_cancel_click(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
self.cancel_editing_message(&menu::Cancel, window, cx);
}
@@ -1655,7 +1748,7 @@ impl ActiveThread {
window: &mut Window,
cx: &mut Context<Self>,
) {
let report = self.thread.update(cx, |thread, cx| {
let report = self.agent.update(cx, |thread, cx| {
thread.report_message_feedback(message_id, feedback, cx)
});
@@ -1714,17 +1807,17 @@ impl ActiveThread {
return;
};
let report_task = self.thread.update(cx, |thread, cx| {
let report_task = self.agent.update(cx, |thread, cx| {
thread.report_message_feedback(message_id, ThreadFeedback::Negative, cx)
});
let comments = editor.read(cx).text(cx);
if !comments.is_empty() {
let thread_id = self.thread.read(cx).id().clone();
let thread_id = self.agent.read(cx).id().clone();
let comments_value = String::from(comments.as_str());
let message_content = self
.thread
.agent
.read(cx)
.message(message_id)
.map(|msg| msg.to_string())
@@ -1800,45 +1893,42 @@ impl ActiveThread {
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
let message_id = self.messages[ix];
let workspace = self.workspace.clone();
let thread = self.thread.read(cx);
let agent = self.agent.read(cx);
let is_first_message = ix == 0;
let is_last_message = ix == self.messages.len() - 1;
let Some(message) = thread.message(message_id) else {
let Some(message) = agent.message(message_id) else {
return Empty.into_any();
};
let is_generating = thread.is_generating();
let is_generating_stale = thread.is_generation_stale().unwrap_or(false);
let is_generating = agent.is_generating();
let is_generating_stale = agent.is_generation_stale().unwrap_or(false);
let loading_dots = (is_generating && is_last_message).then(|| {
h_flex()
.h_8()
.my_3()
.mx_5()
.when(is_generating_stale || message.is_hidden, |this| {
this.child(LoadingLabel::new("").size(LabelSize::Small))
.when(is_generating_stale, |this| {
this.child(AnimatedLabel::new("").size(LabelSize::Small))
})
});
if message.is_hidden {
return div().children(loading_dots).into_any();
}
let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
return Empty.into_any();
};
// Get all the data we need from thread before we start using it in closures
let checkpoint = thread.checkpoint_for_message(message_id);
let configured_model = thread.configured_model().map(|m| m.model);
let added_context = thread
let checkpoint = agent.checkpoint_for_message(message_id);
let configured_model = agent.configured_model().map(|m| m.model);
let added_context = agent
.context_for_message(message_id)
.map(|context| AddedContext::new_attached(context, configured_model.as_ref(), cx))
.collect::<Vec<_>>();
let tool_uses = thread.tool_uses_for_message(message_id, cx);
// let tool_uses = message.segments
let tool_uses = agent.tool_uses_for_message(message_id, cx);
let has_tool_uses = !tool_uses.is_empty();
let editing_message_state = self
@@ -1857,11 +1947,11 @@ impl ActiveThread {
.icon_color(Color::Ignored)
.tooltip(Tooltip::text("Open Thread as Markdown"))
.on_click({
let thread = self.thread.clone();
let agent = self.agent.clone();
let workspace = self.workspace.clone();
move |_, window, cx| {
if let Some(workspace) = workspace.upgrade() {
open_active_thread_as_markdown(thread.clone(), workspace, window, cx)
open_active_thread_as_markdown(agent.clone(), workspace, window, cx)
.detach_and_log_err(cx);
}
}
@@ -1875,7 +1965,10 @@ impl ActiveThread {
this.scroll_to_top(cx);
}));
let show_feedback = thread.is_turn_end(ix);
// For all items that should be aligned with the LLM's response.
const RESPONSE_PADDING_X: Pixels = px(19.);
let show_feedback = self.agent.read(cx).is_turn_end(ix);
let feedback_container = h_flex()
.group("feedback_container")
.mt_1()
@@ -1887,7 +1980,7 @@ impl ActiveThread {
.gap_1p5()
.flex_wrap()
.justify_end();
let feedback_items = match self.thread.read(cx).message_feedback(message_id) {
let feedback_items = match self.agent.read(cx).message_feedback(message_id) {
Some(feedback) => feedback_container
.child(
div().visible_on_hover("feedback_container").child(
@@ -1993,6 +2086,9 @@ impl ActiveThread {
};
let message_is_empty = message.should_display_content();
let message_is_ui_only = message.ui_only;
let message_creases = message.creases.clone();
let role = message.role;
let has_content = !message_is_empty || !added_context.is_empty();
let message_content = has_content.then(|| {
@@ -2035,10 +2131,10 @@ impl ActiveThread {
}
});
let styled_message = if message.ui_only {
let styled_message = if message_is_ui_only {
self.render_ui_notification(message_content, ix, cx)
} else {
match message.role {
match role {
Role::User => {
let colors = cx.theme().colors();
v_flex()
@@ -2143,10 +2239,9 @@ impl ActiveThread {
}),
)
.on_click(cx.listener({
let message_creases = message.creases.clone();
move |this, _, window, cx| {
if let Some(message_text) =
this.thread.read(cx).message(message_id).and_then(|message| {
this.agent.read(cx).message(message_id).and_then(|message| {
message.segments.first().and_then(|segment| {
match segment {
MessageSegment::Text(message_text) => {
@@ -2217,7 +2312,7 @@ impl ActiveThread {
let mut is_pending = false;
let mut error = None;
if let Some(last_restore_checkpoint) =
self.thread.read(cx).last_restore_checkpoint()
self.agent.read(cx).last_restore_checkpoint()
{
if last_restore_checkpoint.message_id() == message_id {
match last_restore_checkpoint {
@@ -2246,7 +2341,7 @@ impl ActiveThread {
.label_size(LabelSize::XSmall)
.disabled(is_pending)
.on_click(cx.listener(move |this, _, _window, cx| {
this.thread.update(cx, |thread, cx| {
this.agent.update(cx, |thread, cx| {
thread
.restore_checkpoint(checkpoint.clone(), cx)
.detach_and_log_err(cx);
@@ -2380,11 +2475,11 @@ impl ActiveThread {
rendered_message: &RenderedMessage,
has_tool_uses: bool,
workspace: WeakEntity<Workspace>,
window: &Window,
cx: &Context<Self>,
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let is_last_message = self.messages.last() == Some(&message_id);
let is_generating = self.thread.read(cx).is_generating();
let is_generating = self.agent.read(cx).is_generating();
let pending_thinking_segment_index = if is_generating && is_last_message && !has_tool_uses {
rendered_message
.segments
@@ -2398,7 +2493,7 @@ impl ActiveThread {
};
let message_role = self
.thread
.agent
.read(cx)
.message(message_id)
.map(|m| m.role)
@@ -2513,6 +2608,23 @@ impl ActiveThread {
}))
.into_any_element()
}
RenderedMessageSegment::ToolUseCard(status, card) => {
card.render(status, window, workspace.clone(), cx)
}
RenderedMessageSegment::ToolUseMarkdown(rendered) => v_flex()
.child(MarkdownElement::new(
rendered.label.clone(),
default_markdown_style(window, cx),
))
.child(MarkdownElement::new(
rendered.input.clone(),
default_markdown_style(window, cx),
))
.child(MarkdownElement::new(
rendered.output.clone(),
default_markdown_style(window, cx),
))
.into_any(), // todo!()
},
),
)
@@ -2535,18 +2647,34 @@ impl ActiveThread {
ix: usize,
cx: &mut Context<Self>,
) -> Stateful<Div> {
let message = div()
.flex_1()
.min_w_0()
.text_size(TextSize::XSmall.rems(cx))
.text_color(cx.theme().colors().text_muted)
.children(message_content);
div()
.id(("message-container", ix))
.py_1()
.px_2p5()
.child(Banner::new().severity(ui::Severity::Warning).child(message))
let colors = cx.theme().colors();
div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
.w_full()
.bg(colors.editor_background)
.rounded_sm()
.child(
h_flex()
.w_full()
.p_2()
.gap_2()
.child(
div().flex_none().child(
Icon::new(IconName::Warning)
.size(IconSize::Small)
.color(Color::Warning),
),
)
.child(
v_flex()
.flex_1()
.min_w_0()
.text_size(TextSize::Small.rems(cx))
.text_color(cx.theme().colors().text_muted)
.children(message_content),
),
),
)
}
fn render_message_thinking_segment(
@@ -2584,7 +2712,7 @@ impl ActiveThread {
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(LoadingLabel::new("Thinking").size(LabelSize::Small)),
.child(AnimatedLabel::new("Thinking").size(LabelSize::Small)),
)
.child(
h_flex()
@@ -2766,7 +2894,7 @@ impl ActiveThread {
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
if let Some(card) = self.agent.read(cx).card_for_tool(&tool_use.id) {
return card.render(&tool_use.status, window, workspace, cx);
}
@@ -3153,7 +3281,7 @@ impl ActiveThread {
.border_color(self.tool_card_border_color(cx))
.rounded_b_lg()
.child(
LoadingLabel::new("Waiting for Confirmation").size(LabelSize::Small)
AnimatedLabel::new("Waiting for Confirmation").size(LabelSize::Small)
)
.child(
h_flex()
@@ -3247,7 +3375,7 @@ impl ActiveThread {
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
let project_context = self.thread.read(cx).project_context();
let project_context = self.agent.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
return div().into_any();
@@ -3371,12 +3499,12 @@ impl ActiveThread {
cx: &mut Context<Self>,
) {
if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self
.thread
.agent
.read(cx)
.pending_tool(&tool_use_id)
.map(|tool_use| tool_use.status.clone())
{
self.thread.update(cx, |thread, cx| {
self.agent.update(cx, |thread, cx| {
if let Some(configured) = thread.get_or_init_configured_model(cx) {
thread.run_tool(
c.tool_use_id.clone(),
@@ -3402,13 +3530,13 @@ impl ActiveThread {
cx: &mut Context<Self>,
) {
let window_handle = window.window_handle();
self.thread.update(cx, |thread, cx| {
self.agent.update(cx, |thread, cx| {
thread.deny_tool_use(tool_use_id, tool_name, Some(window_handle), cx);
});
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let project_context = self.thread.read(cx).project_context();
let project_context = self.agent.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
return;
@@ -3570,7 +3698,7 @@ impl Render for ActiveThread {
}
pub(crate) fn open_active_thread_as_markdown(
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
workspace: Entity<Workspace>,
window: &mut Window,
cx: &mut App,
@@ -3585,7 +3713,7 @@ pub(crate) fn open_active_thread_as_markdown(
let markdown_language = markdown_language_task.await?;
workspace.update_in(cx, |workspace, window, cx| {
let thread = thread.read(cx);
let thread = agent.read(cx);
let markdown = thread.to_markdown(cx)?;
let thread_summary = thread.summary().or_default().to_string();
@@ -3674,7 +3802,7 @@ pub(crate) fn open_context(
AgentContextHandle::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
panel.update(cx, |panel, cx| {
panel.open_thread(thread_context.thread.clone(), window, cx);
panel.open_thread(thread_context.agent.clone(), window, cx);
});
}
}),
@@ -3761,7 +3889,9 @@ fn open_editor_at_position(
#[cfg(test)]
mod tests {
use super::*;
use agent::{MessageSegment, context::ContextLoadResult, thread_store};
use agent::{
MessageSegment, context::ContextLoadResult, thread::UserMessageParams, thread_store,
};
use assistant_tool::{ToolRegistry, ToolWorkingSet};
use editor::EditorSettings;
use fs::FakeFs;
@@ -3776,6 +3906,7 @@ mod tests {
use settings::SettingsStore;
use util::path;
use workspace::CollaboratorId;
use zed_llm_client::CompletionIntent;
#[gpui::test]
async fn test_agent_is_unfollowed_after_cancelling_completion(cx: &mut TestAppContext) {
@@ -3792,13 +3923,12 @@ mod tests {
// Insert user message without any context (empty context vector)
thread.update(cx, |thread, cx| {
thread.insert_user_message(
thread.send_message(
"What is the best way to learn Rust?",
ContextLoadResult::default(),
model.clone(),
None,
vec![],
cx,
);
)
});
// Stream response to user message
@@ -3839,7 +3969,7 @@ mod tests {
registry.set_default_model(
Some(ConfiguredModel {
provider: Arc::new(FakeLanguageModelProvider),
model,
model: model.clone(),
}),
cx,
);
@@ -3853,15 +3983,19 @@ mod tests {
context: None,
}];
let message = thread.update(cx, |thread, cx| {
let message_id = thread.insert_user_message(
"Tell me about @foo.txt",
ContextLoadResult::default(),
let message = thread.update(cx, |agent, cx| {
let message_id = agent.send_message(
UserMessageParams {
text: "Tell me about @foo.txt".to_string(),
creases,
checkpoint: None,
context: ContextLoadResult::default(),
},
model.clone(),
None,
creases,
cx,
);
thread.message(message_id).cloned().unwrap()
agent.message(message_id).cloned().unwrap()
});
active_thread.update_in(cx, |active_thread, window, cx| {
@@ -3953,20 +4087,8 @@ mod tests {
// Insert a user message and start streaming a response
let message = thread.update(cx, |thread, cx| {
let message_id = thread.insert_user_message(
"Hello, how are you?",
ContextLoadResult::default(),
None,
vec![],
cx,
);
thread.advance_prompt_id();
thread.send_to_model(
model.clone(),
CompletionIntent::UserPrompt,
cx.active_window(),
cx,
);
let message_id =
thread.send_message("Hello, how are you?", model.clone(), cx.active_window(), cx);
thread.message(message_id).cloned().unwrap()
});
@@ -4053,7 +4175,7 @@ mod tests {
&mut VisualTestContext,
Entity<ActiveThread>,
Entity<Workspace>,
Entity<Thread>,
Entity<ZedAgentThread>,
Arc<dyn LanguageModel>,
) {
let (workspace, cx) =

View File

@@ -16,9 +16,7 @@ use gpui::{
Focusable, ScrollHandle, Subscription, Task, Transformation, WeakEntity, percentage,
};
use language::LanguageRegistry;
use language_model::{
LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
use notifications::status_toast::{StatusToast, ToastIcon};
use project::{
context_server_store::{ContextServerConfiguration, ContextServerStatus, ContextServerStore},
@@ -26,8 +24,8 @@ use project::{
};
use settings::{Settings, update_settings_file};
use ui::{
ContextMenu, Disclosure, Divider, DividerColor, ElevationIndex, Indicator, PopoverMenu,
Scrollbar, ScrollbarState, Switch, SwitchColor, Tooltip, prelude::*,
ContextMenu, Disclosure, ElevationIndex, Indicator, PopoverMenu, Scrollbar, ScrollbarState,
Switch, SwitchColor, Tooltip, prelude::*,
};
use util::ResultExt as _;
use workspace::Workspace;
@@ -88,14 +86,6 @@ impl AgentConfiguration {
let scroll_handle = ScrollHandle::new();
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
let mut expanded_provider_configurations = HashMap::default();
if LanguageModelRegistry::read_global(cx)
.provider(&ZED_CLOUD_PROVIDER_ID)
.map_or(false, |cloud_provider| cloud_provider.must_accept_terms(cx))
{
expanded_provider_configurations.insert(ZED_CLOUD_PROVIDER_ID, true);
}
let mut this = Self {
fs,
language_registry,
@@ -104,7 +94,7 @@ impl AgentConfiguration {
configuration_views_by_provider: HashMap::default(),
context_server_store,
expanded_context_server_tools: HashMap::default(),
expanded_provider_configurations,
expanded_provider_configurations: HashMap::default(),
tools,
_registry_subscription: registry_subscription,
scroll_handle,
@@ -172,29 +162,19 @@ impl AgentConfiguration {
.unwrap_or(false);
v_flex()
.when(is_expanded, |this| this.mb_2())
.child(
div()
.opacity(0.6)
.px_2()
.child(Divider::horizontal().color(DividerColor::Border)),
)
.py_2()
.gap_1p5()
.border_t_1()
.border_color(cx.theme().colors().border.opacity(0.6))
.child(
h_flex()
.map(|this| {
if is_expanded {
this.mt_2().mb_1()
} else {
this.my_2()
}
})
.w_full()
.gap_1()
.justify_between()
.child(
h_flex()
.id(provider_id_string.clone())
.cursor_pointer()
.px_2()
.py_0p5()
.w_full()
.justify_between()
@@ -257,16 +237,12 @@ impl AgentConfiguration {
)
}),
)
.child(
div()
.px_2()
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
}),
)
.when(is_expanded, |parent| match configuration_view {
Some(configuration_view) => parent.child(configuration_view),
None => parent.child(Label::new(format!(
"No configuration view for {provider_name}",
))),
})
}
fn render_provider_configuration_section(
@@ -276,11 +252,12 @@ impl AgentConfiguration {
let providers = LanguageModelRegistry::read_global(cx).providers();
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.border_b_1()
.border_color(cx.theme().colors().border)
.child(
v_flex()
.p(DynamicSpacing::Base16.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.pb_0()
.mb_2p5()
.gap_0p5()
.child(Headline::new("LLM Providers"))
@@ -289,15 +266,10 @@ impl AgentConfiguration {
.color(Color::Muted),
),
)
.child(
div()
.pl(DynamicSpacing::Base08.rems(cx))
.pr(DynamicSpacing::Base20.rems(cx))
.children(
providers.into_iter().map(|provider| {
self.render_provider_configuration_block(&provider, cx)
}),
),
.children(
providers
.into_iter()
.map(|provider| self.render_provider_configuration_block(&provider, cx)),
)
}
@@ -436,7 +408,7 @@ impl AgentConfiguration {
window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let context_server_ids = self.context_server_store.read(cx).configured_server_ids();
let context_server_ids = self.context_server_store.read(cx).all_server_ids().clone();
v_flex()
.p(DynamicSpacing::Base16.rems(cx))

View File

@@ -379,14 +379,6 @@ impl ConfigureContextServerModal {
};
self.state = State::Waiting;
let existing_server = self.context_server_store.read(cx).get_running_server(&id);
if existing_server.is_some() {
self.context_server_store.update(cx, |store, cx| {
store.stop_server(&id, cx).log_err();
});
}
let wait_for_context_server_task =
wait_for_context_server(&self.context_server_store, id.clone(), cx);
cx.spawn({
@@ -407,21 +399,13 @@ impl ConfigureContextServerModal {
})
.detach();
let settings_changed =
ProjectSettings::get_global(cx).context_servers.get(&id.0) != Some(&settings);
if settings_changed {
// When we write the settings to the file, the context server will be restarted.
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
project_settings.context_servers.insert(id.0, settings);
});
// When we write the settings to the file, the context server will be restarted.
workspace.update(cx, |workspace, cx| {
let fs = workspace.app_state().fs.clone();
update_settings_file::<ProjectSettings>(fs.clone(), cx, |project_settings, _| {
project_settings.context_servers.insert(id.0, settings);
});
} else if let Some(existing_server) = existing_server {
self.context_server_store
.update(cx, |store, cx| store.start_server(existing_server, cx));
}
});
}
fn cancel(&mut self, _: &menu::Cancel, cx: &mut Context<Self>) {

View File

@@ -1,7 +1,8 @@
use crate::{Keep, KeepAll, OpenAgentDiff, Reject, RejectAll};
use agent::{Thread, ThreadEvent};
use agent::{ThreadEvent, ZedAgentThread};
use agent_settings::AgentSettings;
use anyhow::Result;
use assistant_tool::ActionLog;
use buffer_diff::DiffHunkStatus;
use collections::{HashMap, HashSet};
use editor::{
@@ -41,7 +42,8 @@ use zed_actions::assistant::ToggleFocus;
pub struct AgentDiffPane {
multibuffer: Entity<MultiBuffer>,
editor: Entity<Editor>,
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
action_log: Entity<ActionLog>,
focus_handle: FocusHandle,
workspace: WeakEntity<Workspace>,
title: SharedString,
@@ -50,70 +52,71 @@ pub struct AgentDiffPane {
impl AgentDiffPane {
pub fn deploy(
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut App,
) -> Result<Entity<Self>> {
workspace.update(cx, |workspace, cx| {
Self::deploy_in_workspace(thread, workspace, window, cx)
Self::deploy_in_workspace(agent, workspace, window, cx)
})
}
pub fn deploy_in_workspace(
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
workspace: &mut Workspace,
window: &mut Window,
cx: &mut Context<Workspace>,
) -> Entity<Self> {
let existing_diff = workspace
.items_of_type::<AgentDiffPane>(cx)
.find(|diff| diff.read(cx).thread == thread);
.find(|diff| diff.read(cx).agent == agent);
if let Some(existing_diff) = existing_diff {
workspace.activate_item(&existing_diff, true, true, window, cx);
existing_diff
} else {
let agent_diff = cx
.new(|cx| AgentDiffPane::new(thread.clone(), workspace.weak_handle(), window, cx));
let agent_diff =
cx.new(|cx| AgentDiffPane::new(agent.clone(), workspace.weak_handle(), window, cx));
workspace.add_item_to_center(Box::new(agent_diff.clone()), window, cx);
agent_diff
}
}
pub fn new(
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
workspace: WeakEntity<Workspace>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let focus_handle = cx.focus_handle();
let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite));
let action_log = agent.read(cx).action_log();
let project = agent.read(cx).project().clone();
let project = thread.read(cx).project().clone();
let editor = cx.new(|cx| {
let mut editor =
Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
editor.disable_inline_diagnostics();
editor.set_expand_all_diff_hunks(cx);
editor.set_render_diff_hunk_controls(diff_hunk_controls(&thread), cx);
editor.set_render_diff_hunk_controls(diff_hunk_controls(&action_log), cx);
editor.register_addon(AgentDiffAddon);
editor
});
let action_log = thread.read(cx).action_log().clone();
let mut this = Self {
_subscriptions: vec![
cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
this.update_excerpts(window, cx)
}),
cx.subscribe(&thread, |this, _thread, event, cx| {
cx.subscribe(&agent, |this, _thread, event, cx| {
this.handle_thread_event(event, cx)
}),
],
title: SharedString::default(),
action_log,
multibuffer,
editor,
thread,
agent,
focus_handle,
workspace,
};
@@ -123,8 +126,8 @@ impl AgentDiffPane {
}
fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let thread = self.thread.read(cx);
let changed_buffers = thread.action_log().read(cx).changed_buffers(cx);
let agent = self.agent.read(cx);
let changed_buffers = agent.action_log().read(cx).changed_buffers(cx);
let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::<HashSet<_>>();
for (buffer, diff_handle) in changed_buffers {
@@ -211,7 +214,7 @@ impl AgentDiffPane {
}
fn update_title(&mut self, cx: &mut Context<Self>) {
let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes");
let new_title = self.agent.read(cx).summary().unwrap_or("Agent Changes");
if new_title != self.title {
self.title = new_title;
cx.emit(EditorEvent::TitleChanged);
@@ -248,14 +251,14 @@ impl AgentDiffPane {
fn keep(&mut self, _: &Keep, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
keep_edits_in_selection(editor, &snapshot, &self.thread, window, cx);
keep_edits_in_selection(editor, &snapshot, &self.action_log, window, cx);
});
}
fn reject(&mut self, _: &Reject, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
reject_edits_in_selection(editor, &snapshot, &self.thread, window, cx);
reject_edits_in_selection(editor, &snapshot, &self.action_log, window, cx);
});
}
@@ -265,7 +268,7 @@ impl AgentDiffPane {
reject_edits_in_ranges(
editor,
&snapshot,
&self.thread,
&self.action_log,
vec![editor::Anchor::min()..editor::Anchor::max()],
window,
cx,
@@ -274,15 +277,15 @@ impl AgentDiffPane {
}
fn keep_all(&mut self, _: &KeepAll, _window: &mut Window, cx: &mut Context<Self>) {
self.thread
.update(cx, |thread, cx| thread.keep_all_edits(cx));
self.action_log
.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
}
}
fn keep_edits_in_selection(
editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>,
action_log: &Entity<ActionLog>,
window: &mut Window,
cx: &mut Context<Editor>,
) {
@@ -291,13 +294,13 @@ fn keep_edits_in_selection(
.disjoint_anchor_ranges()
.collect::<Vec<_>>();
keep_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx)
keep_edits_in_ranges(editor, buffer_snapshot, &action_log, ranges, window, cx)
}
fn reject_edits_in_selection(
editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>,
action_log: &Entity<ActionLog>,
window: &mut Window,
cx: &mut Context<Editor>,
) {
@@ -305,13 +308,13 @@ fn reject_edits_in_selection(
.selections
.disjoint_anchor_ranges()
.collect::<Vec<_>>();
reject_edits_in_ranges(editor, buffer_snapshot, &thread, ranges, window, cx)
reject_edits_in_ranges(editor, buffer_snapshot, &action_log, ranges, window, cx)
}
fn keep_edits_in_ranges(
editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>,
action_log: &Entity<ActionLog>,
ranges: Vec<Range<editor::Anchor>>,
window: &mut Window,
cx: &mut Context<Editor>,
@@ -326,8 +329,8 @@ fn keep_edits_in_ranges(
for hunk in &diff_hunks_in_ranges {
let buffer = multibuffer.read(cx).buffer(hunk.buffer_id);
if let Some(buffer) = buffer {
thread.update(cx, |thread, cx| {
thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
action_log.update(cx, |action_log, cx| {
action_log.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
});
}
}
@@ -336,7 +339,7 @@ fn keep_edits_in_ranges(
fn reject_edits_in_ranges(
editor: &mut Editor,
buffer_snapshot: &MultiBufferSnapshot,
thread: &Entity<Thread>,
action_log: &Entity<ActionLog>,
ranges: Vec<Range<editor::Anchor>>,
window: &mut Window,
cx: &mut Context<Editor>,
@@ -361,9 +364,9 @@ fn reject_edits_in_ranges(
}
for (buffer, ranges) in ranges_by_buffer {
thread
.update(cx, |thread, cx| {
thread.reject_edits_in_ranges(buffer, ranges, cx)
action_log
.update(cx, |action_log, cx| {
action_log.reject_edits_in_ranges(buffer, ranges, cx)
})
.detach_and_log_err(cx);
}
@@ -461,7 +464,7 @@ impl Item for AgentDiffPane {
}
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes");
let summary = self.agent.read(cx).summary().or_default();
Label::new(format!("Review: {}", summary))
.color(if params.selected {
Color::Default
@@ -511,7 +514,7 @@ impl Item for AgentDiffPane {
where
Self: Sized,
{
Some(cx.new(|cx| Self::new(self.thread.clone(), self.workspace.clone(), window, cx)))
Some(cx.new(|cx| Self::new(self.agent.clone(), self.workspace.clone(), window, cx)))
}
fn is_dirty(&self, cx: &App) -> bool {
@@ -641,8 +644,8 @@ impl Render for AgentDiffPane {
}
}
fn diff_hunk_controls(thread: &Entity<Thread>) -> editor::RenderDiffHunkControlsFn {
let thread = thread.clone();
fn diff_hunk_controls(action_log: &Entity<ActionLog>) -> editor::RenderDiffHunkControlsFn {
let action_log = action_log.clone();
Arc::new(
move |row,
@@ -660,7 +663,7 @@ fn diff_hunk_controls(thread: &Entity<Thread>) -> editor::RenderDiffHunkControls
hunk_range,
is_created_file,
line_height,
&thread,
&action_log,
editor,
window,
cx,
@@ -676,7 +679,7 @@ fn render_diff_hunk_controls(
hunk_range: Range<editor::Anchor>,
is_created_file: bool,
line_height: Pixels,
thread: &Entity<Thread>,
action_log: &Entity<ActionLog>,
editor: &Entity<Editor>,
window: &mut Window,
cx: &mut App,
@@ -711,14 +714,14 @@ fn render_diff_hunk_controls(
)
.on_click({
let editor = editor.clone();
let thread = thread.clone();
let action_log = action_log.clone();
move |_event, window, cx| {
editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
reject_edits_in_ranges(
editor,
&snapshot,
&thread,
&action_log,
vec![hunk_range.start..hunk_range.start],
window,
cx,
@@ -733,14 +736,14 @@ fn render_diff_hunk_controls(
)
.on_click({
let editor = editor.clone();
let thread = thread.clone();
let action_log = action_log.clone();
move |_event, window, cx| {
editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
keep_edits_in_ranges(
editor,
&snapshot,
&thread,
&action_log,
vec![hunk_range.start..hunk_range.start],
window,
cx,
@@ -1114,7 +1117,7 @@ impl Render for AgentDiffToolbar {
let has_pending_edit_tool_use = agent_diff
.read(cx)
.thread
.agent
.read(cx)
.has_pending_edit_tool_uses();
@@ -1187,7 +1190,7 @@ pub enum EditorState {
}
struct WorkspaceThread {
thread: WeakEntity<Thread>,
agent: WeakEntity<ZedAgentThread>,
_thread_subscriptions: [Subscription; 2],
singleton_editors: HashMap<WeakEntity<Buffer>, HashMap<WeakEntity<Editor>, Subscription>>,
_settings_subscription: Subscription,
@@ -1212,7 +1215,7 @@ impl AgentDiff {
pub fn set_active_thread(
workspace: &WeakEntity<Workspace>,
thread: &Entity<Thread>,
thread: &Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut App,
) {
@@ -1224,11 +1227,11 @@ impl AgentDiff {
fn register_active_thread_impl(
&mut self,
workspace: &WeakEntity<Workspace>,
thread: &Entity<Thread>,
agent: &Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let action_log = thread.read(cx).action_log().clone();
let action_log = agent.read(cx).action_log().clone();
let action_log_subscription = cx.observe_in(&action_log, window, {
let workspace = workspace.clone();
@@ -1237,7 +1240,7 @@ impl AgentDiff {
}
});
let thread_subscription = cx.subscribe_in(&thread, window, {
let thread_subscription = cx.subscribe_in(&agent, window, {
let workspace = workspace.clone();
move |this, _thread, event, window, cx| {
this.handle_thread_event(&workspace, event, window, cx)
@@ -1246,7 +1249,7 @@ impl AgentDiff {
if let Some(workspace_thread) = self.workspace_threads.get_mut(&workspace) {
// replace thread and action log subscription, but keep editors
workspace_thread.thread = thread.downgrade();
workspace_thread.agent = agent.downgrade();
workspace_thread._thread_subscriptions = [action_log_subscription, thread_subscription];
self.update_reviewing_editors(&workspace, window, cx);
return;
@@ -1271,7 +1274,7 @@ impl AgentDiff {
self.workspace_threads.insert(
workspace.clone(),
WorkspaceThread {
thread: thread.downgrade(),
agent: agent.downgrade(),
_thread_subscriptions: [action_log_subscription, thread_subscription],
singleton_editors: HashMap::default(),
_settings_subscription: settings_subscription,
@@ -1319,7 +1322,7 @@ impl AgentDiff {
fn register_review_action<T: Action>(
workspace: &mut Workspace,
review: impl Fn(&Entity<Editor>, &Entity<Thread>, &mut Window, &mut App) -> PostReviewState
review: impl Fn(&Entity<Editor>, &Entity<ZedAgentThread>, &mut Window, &mut App) -> PostReviewState
+ 'static,
this: &Entity<AgentDiff>,
) {
@@ -1362,6 +1365,7 @@ impl AgentDiff {
| ThreadEvent::StreamedAssistantText(_, _)
| ThreadEvent::StreamedAssistantThinking(_, _)
| ThreadEvent::StreamedToolUse { .. }
| ThreadEvent::StreamedToolUse2 { .. }
| ThreadEvent::InvalidToolInput { .. }
| ThreadEvent::MissingToolUse { .. }
| ThreadEvent::MessageAdded(_)
@@ -1481,11 +1485,11 @@ impl AgentDiff {
return;
};
let Some(thread) = workspace_thread.thread.upgrade() else {
let Some(agent) = workspace_thread.agent.upgrade() else {
return;
};
let action_log = thread.read(cx).action_log();
let action_log = agent.read(cx).action_log();
let changed_buffers = action_log.read(cx).changed_buffers(cx);
let mut unaffected = self.reviewing_editors.clone();
@@ -1510,7 +1514,7 @@ impl AgentDiff {
multibuffer.add_diff(diff_handle.clone(), cx);
});
let new_state = if thread.read(cx).is_generating() {
let new_state = if agent.read(cx).is_generating() {
EditorState::Generating
} else {
EditorState::Reviewing
@@ -1523,7 +1527,7 @@ impl AgentDiff {
if previous_state.is_none() {
editor.update(cx, |editor, cx| {
editor.start_temporary_diff_override();
editor.set_render_diff_hunk_controls(diff_hunk_controls(&thread), cx);
editor.set_render_diff_hunk_controls(diff_hunk_controls(&action_log), cx);
editor.set_expand_all_diff_hunks(cx);
editor.register_addon(EditorAgentDiffAddon);
});
@@ -1591,22 +1595,22 @@ impl AgentDiff {
return;
};
let Some(WorkspaceThread { thread, .. }) =
let Some(WorkspaceThread { agent, .. }) =
self.workspace_threads.get(&workspace.downgrade())
else {
return;
};
let Some(thread) = thread.upgrade() else {
let Some(agent) = agent.upgrade() else {
return;
};
AgentDiffPane::deploy(thread, workspace.downgrade(), window, cx).log_err();
AgentDiffPane::deploy(agent, workspace.downgrade(), window, cx).log_err();
}
fn keep_all(
editor: &Entity<Editor>,
thread: &Entity<Thread>,
agent: &Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut App,
) -> PostReviewState {
@@ -1615,7 +1619,7 @@ impl AgentDiff {
keep_edits_in_ranges(
editor,
&snapshot,
thread,
&agent.read(cx).action_log(),
vec![editor::Anchor::min()..editor::Anchor::max()],
window,
cx,
@@ -1626,7 +1630,7 @@ impl AgentDiff {
fn reject_all(
editor: &Entity<Editor>,
thread: &Entity<Thread>,
thread: &Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut App,
) -> PostReviewState {
@@ -1635,7 +1639,7 @@ impl AgentDiff {
reject_edits_in_ranges(
editor,
&snapshot,
thread,
&thread.read(cx).action_log(),
vec![editor::Anchor::min()..editor::Anchor::max()],
window,
cx,
@@ -1646,26 +1650,26 @@ impl AgentDiff {
fn keep(
editor: &Entity<Editor>,
thread: &Entity<Thread>,
agent: &Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut App,
) -> PostReviewState {
editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
keep_edits_in_selection(editor, &snapshot, thread, window, cx);
keep_edits_in_selection(editor, &snapshot, &agent.read(cx).action_log(), window, cx);
Self::post_review_state(&snapshot)
})
}
fn reject(
editor: &Entity<Editor>,
thread: &Entity<Thread>,
agent: &Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut App,
) -> PostReviewState {
editor.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
reject_edits_in_selection(editor, &snapshot, thread, window, cx);
reject_edits_in_selection(editor, &snapshot, &agent.read(cx).action_log(), window, cx);
Self::post_review_state(&snapshot)
})
}
@@ -1682,7 +1686,7 @@ impl AgentDiff {
fn review_in_active_editor(
&mut self,
workspace: &mut Workspace,
review: impl Fn(&Entity<Editor>, &Entity<Thread>, &mut Window, &mut App) -> PostReviewState,
review: impl Fn(&Entity<Editor>, &Entity<ZedAgentThread>, &mut Window, &mut App) -> PostReviewState,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<Task<Result<()>>> {
@@ -1696,14 +1700,13 @@ impl AgentDiff {
return None;
}
let WorkspaceThread { thread, .. } =
self.workspace_threads.get(&workspace.weak_handle())?;
let WorkspaceThread { agent, .. } = self.workspace_threads.get(&workspace.weak_handle())?;
let thread = thread.upgrade()?;
let agent = agent.upgrade()?;
if let PostReviewState::AllReviewed = review(&editor, &thread, window, cx) {
if let PostReviewState::AllReviewed = review(&editor, &agent, window, cx) {
if let Some(curr_buffer) = editor.read(cx).buffer().read(cx).as_singleton() {
let changed_buffers = thread.read(cx).action_log().read(cx).changed_buffers(cx);
let changed_buffers = agent.read(cx).action_log().read(cx).changed_buffers(cx);
let mut keys = changed_buffers.keys().cycle();
keys.find(|k| *k == &curr_buffer);
@@ -1801,13 +1804,13 @@ mod tests {
})
.await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let agent = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = agent.read_with(cx, |agent, _| agent.action_log().clone());
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let agent_diff = cx.new_window_entity(|window, cx| {
AgentDiffPane::new(thread.clone(), workspace.downgrade(), window, cx)
AgentDiffPane::new(agent.clone(), workspace.downgrade(), window, cx)
});
let editor = agent_diff.read_with(cx, |diff, _cx| diff.editor.clone());
@@ -1895,7 +1898,7 @@ mod tests {
keep_edits_in_ranges(
editor,
&snapshot,
&thread,
&agent.read(cx).action_log(),
vec![position..position],
window,
cx,
@@ -1966,8 +1969,8 @@ mod tests {
})
.await
.unwrap();
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
let agent = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = agent.read_with(cx, |agent, _| agent.action_log().clone());
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
@@ -1989,7 +1992,7 @@ mod tests {
// Set the active thread
cx.update(|window, cx| {
AgentDiff::set_active_thread(&workspace.downgrade(), &thread, window, cx)
AgentDiff::set_active_thread(&workspace.downgrade(), &agent, window, cx)
});
let buffer1 = project
@@ -2146,7 +2149,7 @@ mod tests {
keep_edits_in_ranges(
editor,
&snapshot,
&thread,
&agent.read(cx).action_log(),
vec![position..position],
window,
cx,

View File

@@ -26,7 +26,7 @@ use crate::{
ui::AgentOnboardingModal,
};
use agent::{
Thread, ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio,
ThreadError, ThreadEvent, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgentThread,
context_store::ContextStore,
history_store::{HistoryEntryId, HistoryStore},
thread_store::{TextThreadStore, ThreadStore},
@@ -41,7 +41,7 @@ use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, Hsla,
Corner, DismissEvent, Entity, EventEmitter, ExternalPaths, FocusHandle, Focusable, FontWeight,
KeyContext, Pixels, Subscription, Task, UpdateGlobal, WeakEntity, linear_color_stop,
linear_gradient, prelude::*, pulsating_between,
};
@@ -59,7 +59,7 @@ use theme::ThemeSettings;
use time::UtcOffset;
use ui::utils::WithRemSize;
use ui::{
Banner, Callout, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu,
Banner, CheckboxWithLabel, ContextMenu, ElevationIndex, KeyBinding, PopoverMenu,
PopoverMenuHandle, ProgressBar, Tab, Tooltip, Vector, VectorName, prelude::*,
};
use util::ResultExt as _;
@@ -72,7 +72,7 @@ use zed_actions::{
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding},
assistant::{OpenRulesLibrary, ToggleFocus},
};
use zed_llm_client::{CompletionIntent, UsageLimit};
use zed_llm_client::UsageLimit;
const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -122,8 +122,8 @@ pub fn init(cx: &mut App) {
workspace.focus_panel::<AgentPanel>(window, cx);
match &panel.read(cx).active_view {
ActiveView::Thread { thread, .. } => {
let thread = thread.read(cx).thread().clone();
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx);
let agent = thread.read(cx).agent().clone();
AgentDiffPane::deploy_in_workspace(agent, workspace, window, cx);
}
ActiveView::TextThread { .. }
| ActiveView::History
@@ -251,9 +251,9 @@ impl ActiveView {
let new_summary = editor.read(cx).text(cx);
thread.update(cx, |thread, cx| {
thread.thread().update(cx, |thread, cx| {
thread.set_summary(new_summary, cx);
});
thread.agent().update(cx, |agent, cx| {
agent.set_summary(new_summary, cx);
})
})
}
EditorEvent::Blurred => {
@@ -274,11 +274,11 @@ impl ActiveView {
cx.notify();
}
}),
cx.subscribe_in(&active_thread.read(cx).thread().clone(), window, {
cx.subscribe_in(&active_thread.read(cx).agent().clone(), window, {
let editor = editor.clone();
move |_, thread, event, window, cx| match event {
move |_, agent, event, window, cx| match event {
ThreadEvent::SummaryGenerated => {
let summary = thread.read(cx).summary().or_default();
let summary = agent.read(cx).summary().or_default();
editor.update(cx, |editor, cx| {
editor.set_text(summary, window, cx);
@@ -524,7 +524,7 @@ impl AgentPanel {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let agent = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone();
let project = workspace.project();
@@ -546,13 +546,13 @@ impl AgentPanel {
prompt_store.clone(),
thread_store.downgrade(),
context_store.downgrade(),
thread.clone(),
agent.clone(),
window,
cx,
)
});
let thread_id = thread.read(cx).id().clone();
let thread_id = agent.read(cx).id().clone();
let history_store = cx.new(|cx| {
HistoryStore::new(
thread_store.clone(),
@@ -566,7 +566,7 @@ impl AgentPanel {
let active_thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
agent.clone(),
thread_store.clone(),
context_store.clone(),
message_editor_context_store.clone(),
@@ -607,7 +607,7 @@ impl AgentPanel {
}
};
AgentDiff::set_active_thread(&workspace, &thread, window, cx);
AgentDiff::set_active_thread(&workspace, &agent, window, cx);
let weak_panel = weak_self.clone();
@@ -649,9 +649,9 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => {
thread
.read(cx)
.thread()
.agent()
.clone()
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx));
.update(cx, |agent, cx| agent.get_or_init_configured_model(cx));
}
ActiveView::TextThread { .. }
| ActiveView::History
@@ -753,7 +753,7 @@ impl AgentPanel {
None
};
let thread = self
let agent = self
.thread_store
.update(cx, |this, cx| this.create_thread(cx));
@@ -786,7 +786,7 @@ impl AgentPanel {
let active_thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
agent.clone(),
self.thread_store.clone(),
self.context_store.clone(),
context_store.clone(),
@@ -806,7 +806,7 @@ impl AgentPanel {
self.prompt_store.clone(),
self.thread_store.downgrade(),
self.context_store.downgrade(),
thread.clone(),
agent.clone(),
window,
cx,
)
@@ -823,7 +823,7 @@ impl AgentPanel {
let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx);
self.set_active_view(thread_view, window, cx);
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx);
AgentDiff::set_active_thread(&self.workspace, &agent, window, cx);
}
fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -971,7 +971,7 @@ impl AgentPanel {
pub(crate) fn open_thread(
&mut self,
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -984,7 +984,7 @@ impl AgentPanel {
let active_thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
agent.clone(),
self.thread_store.clone(),
self.context_store.clone(),
context_store.clone(),
@@ -1003,7 +1003,7 @@ impl AgentPanel {
self.prompt_store.clone(),
self.thread_store.downgrade(),
self.context_store.downgrade(),
thread.clone(),
agent.clone(),
window,
cx,
)
@@ -1012,7 +1012,7 @@ impl AgentPanel {
let thread_view = ActiveView::thread(active_thread.clone(), message_editor, window, cx);
self.set_active_view(thread_view, window, cx);
AgentDiff::set_active_thread(&self.workspace, &thread, window, cx);
AgentDiff::set_active_thread(&self.workspace, &agent, window, cx);
}
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
@@ -1137,10 +1137,10 @@ impl AgentPanel {
) {
match &self.active_view {
ActiveView::Thread { thread, .. } => {
let thread = thread.read(cx).thread().clone();
let agent = thread.read(cx).agent().clone();
self.workspace
.update(cx, |workspace, cx| {
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx)
AgentDiffPane::deploy_in_workspace(agent, workspace, window, cx)
})
.log_err();
}
@@ -1190,7 +1190,7 @@ impl AgentPanel {
match &self.active_view {
ActiveView::Thread { thread, .. } => {
active_thread::open_active_thread_as_markdown(
thread.read(cx).thread().clone(),
thread.read(cx).agent().clone(),
workspace,
window,
cx,
@@ -1228,9 +1228,9 @@ impl AgentPanel {
}
}
pub(crate) fn active_thread(&self, cx: &App) -> Option<Entity<Thread>> {
pub(crate) fn active_thread(&self, cx: &App) -> Option<Entity<ZedAgentThread>> {
match &self.active_view {
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()),
ActiveView::Thread { thread, .. } => Some(thread.read(cx).agent().clone()),
_ => None,
}
}
@@ -1249,23 +1249,16 @@ impl AgentPanel {
return;
};
let thread_state = thread.read(cx).thread().read(cx);
if !thread_state.tool_use_limit_reached() {
let agent_state = thread.read(cx).agent().read(cx);
if !agent_state.tool_use_limit_reached() {
return;
}
let model = thread_state.configured_model().map(|cm| cm.model.clone());
let model = agent_state.configured_model().map(|cm| cm.model.clone());
if let Some(model) = model {
thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, cx| {
thread.insert_invisible_continue_message(cx);
thread.advance_prompt_id();
thread.send_to_model(
model,
CompletionIntent::UserPrompt,
Some(window.window_handle()),
cx,
);
active_thread.agent().update(cx, |agent, cx| {
agent.send_continue_message(model, Some(window.window_handle()), cx);
});
});
} else {
@@ -1284,10 +1277,10 @@ impl AgentPanel {
};
thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
let current_mode = thread.completion_mode();
active_thread.agent().update(cx, |agent, _cx| {
let current_mode = agent.completion_mode();
thread.set_completion_mode(match current_mode {
agent.set_completion_mode(match current_mode {
CompletionMode::Burn => CompletionMode::Normal,
CompletionMode::Normal => CompletionMode::Burn,
});
@@ -1330,7 +1323,7 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => {
let thread = thread.read(cx);
if thread.is_empty() {
let id = thread.thread().read(cx).id().clone();
let id = thread.agent().read(cx).id().clone();
self.history_store.update(cx, |store, cx| {
store.remove_recently_opened_thread(id, cx);
});
@@ -1341,7 +1334,7 @@ impl AgentPanel {
match &new_view {
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
let id = thread.read(cx).thread().read(cx).id().clone();
let id = thread.read(cx).agent().read(cx).id().clone();
store.push_recently_opened_entry(HistoryEntryId::Thread(id), cx);
}),
ActiveView::TextThread { context_editor, .. } => {
@@ -1726,7 +1719,7 @@ impl AgentPanel {
};
let active_thread = match &self.active_view {
ActiveView::Thread { thread, .. } => Some(thread.read(cx).thread().clone()),
ActiveView::Thread { thread, .. } => Some(thread.clone()),
ActiveView::TextThread { .. } | ActiveView::History | ActiveView::Configuration => None,
};
@@ -1761,7 +1754,7 @@ impl AgentPanel {
this.action(
"New From Summary",
Box::new(NewThread {
from_thread_id: Some(thread.id().clone()),
from_thread_id: Some(thread.agent().read(cx).id().clone()),
}),
)
} else {
@@ -1904,14 +1897,14 @@ impl AgentPanel {
return None;
}
let thread = active_thread.thread().read(cx);
let is_generating = thread.is_generating();
let conversation_token_usage = thread.total_token_usage()?;
let agent = active_thread.agent().read(cx);
let is_generating = agent.is_generating();
let conversation_token_usage = agent.total_token_usage(cx)?;
let (total_token_usage, is_estimating) =
if let Some((editing_message_id, unsent_tokens)) = active_thread.editing_message_id() {
let combined = thread
.token_usage_up_to_message(editing_message_id)
let combined = agent
.token_usage_up_to_message(editing_message_id, cx)
.add(unsent_tokens);
(combined, unsent_tokens > 0)
@@ -2022,10 +2015,12 @@ impl AgentPanel {
ActiveView::Thread { thread, .. } => {
let is_using_zed_provider = thread
.read(cx)
.thread()
.agent()
.read(cx)
.configured_model()
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID);
.map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
});
if !is_using_zed_provider {
return false;
@@ -2598,7 +2593,7 @@ impl AgentPanel {
Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => {
parent.child(Banner::new().severity(ui::Severity::Warning).child(
h_flex().w_full().children(provider.render_accept_terms(
LanguageModelProviderTosView::ThreadEmptyState,
LanguageModelProviderTosView::ThreadtEmptyState,
cx,
)),
))
@@ -2620,14 +2615,14 @@ impl AgentPanel {
}
};
let thread = active_thread.read(cx).thread().read(cx);
let agent = active_thread.read(cx).agent().read(cx);
let tool_use_limit_reached = thread.tool_use_limit_reached();
let tool_use_limit_reached = agent.tool_use_limit_reached();
if !tool_use_limit_reached {
return None;
}
let model = thread.configured_model()?.model;
let model = agent.configured_model()?.model;
let focus_handle = self.focus_handle(cx);
@@ -2675,8 +2670,8 @@ impl AgentPanel {
let active_thread = active_thread.clone();
cx.listener(move |this, _, window, cx| {
active_thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
active_thread.agent().update(cx, |agent, _cx| {
agent.set_completion_mode(CompletionMode::Burn);
});
});
this.continue_conversation(window, cx);
@@ -2689,90 +2684,58 @@ impl AgentPanel {
Some(div().px_2().pb_2().child(banner).into_any_element())
}
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
let message = message.into();
IconButton::new("copy", IconName::Copy)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(Tooltip::text("Copy Error Message"))
.on_click(move |_, _, cx| {
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
})
}
fn dismiss_error_button(
&self,
thread: &Entity<ActiveThread>,
cx: &mut Context<Self>,
) -> impl IntoElement {
IconButton::new("dismiss", IconName::Close)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.tooltip(Tooltip::text("Dismiss Error"))
.on_click(cx.listener({
let thread = thread.clone();
move |_, _, _, cx| {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.notify();
}
}))
}
fn upgrade_button(
&self,
thread: &Entity<ActiveThread>,
cx: &mut Context<Self>,
) -> impl IntoElement {
Button::new("upgrade", "Upgrade")
.label_size(LabelSize::Small)
.style(ButtonStyle::Tinted(ui::TintColor::Accent))
.on_click(cx.listener({
let thread = thread.clone();
move |_, _, _, cx| {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.open_url(&zed_urls::account_url(cx));
cx.notify();
}
}))
}
fn error_callout_bg(&self, cx: &Context<Self>) -> Hsla {
cx.theme().status().error.opacity(0.08)
}
fn render_payment_required_error(
&self,
thread: &Entity<ActiveThread>,
cx: &mut Context<Self>,
) -> AnyElement {
const ERROR_MESSAGE: &str =
"You reached your free usage limit. Upgrade to Zed Pro for more prompts.";
const ERROR_MESSAGE: &str = "Free tier exceeded. Subscribe and add payment to continue using Zed LLMs. You'll be billed at cost for tokens used.";
let icon = Icon::new(IconName::XCircle)
.size(IconSize::Small)
.color(Color::Error);
div()
.border_t_1()
.border_color(cx.theme().colors().border)
v_flex()
.gap_0p5()
.child(
Callout::new()
.icon(icon)
.title("Free Usage Exceeded")
.description(ERROR_MESSAGE)
.tertiary_action(self.upgrade_button(thread, cx))
.secondary_action(self.create_copy_button(ERROR_MESSAGE))
.primary_action(self.dismiss_error_button(thread, cx))
.bg_color(self.error_callout_bg(cx)),
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Free Usage Exceeded").weight(FontWeight::MEDIUM)),
)
.into_any_element()
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(ERROR_MESSAGE)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(ERROR_MESSAGE))
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener({
let thread = thread.clone();
move |_, _, _, cx| {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.open_url(&zed_urls::account_url(cx));
cx.notify();
}
})))
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener({
let thread = thread.clone();
move |_, _, _, cx| {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.notify();
}
}))),
)
.into_any()
}
fn render_model_request_limit_reached_error(
@@ -2782,28 +2745,67 @@ impl AgentPanel {
cx: &mut Context<Self>,
) -> AnyElement {
let error_message = match plan {
Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
Plan::ZedProTrial | Plan::Free => "Upgrade to Zed Pro for more prompts.",
Plan::ZedPro => {
"Model request limit reached. Upgrade to usage-based billing for more requests."
}
Plan::ZedProTrial => {
"Model request limit reached. Upgrade to Zed Pro for more requests."
}
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
};
let call_to_action = match plan {
Plan::ZedPro => "Upgrade to usage-based billing",
Plan::ZedProTrial => "Upgrade to Zed Pro",
Plan::Free => "Upgrade to Zed Pro",
};
let icon = Icon::new(IconName::XCircle)
.size(IconSize::Small)
.color(Color::Error);
div()
.border_t_1()
.border_color(cx.theme().colors().border)
v_flex()
.gap_0p5()
.child(
Callout::new()
.icon(icon)
.title("Model Prompt Limit Reached")
.description(error_message)
.tertiary_action(self.upgrade_button(thread, cx))
.secondary_action(self.create_copy_button(error_message))
.primary_action(self.dismiss_error_button(thread, cx))
.bg_color(self.error_callout_bg(cx)),
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new("Model Request Limit Reached").weight(FontWeight::MEDIUM)),
)
.into_any_element()
.child(
div()
.id("error-message")
.max_h_24()
.overflow_y_scroll()
.child(Label::new(error_message)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(error_message))
.child(
Button::new("subscribe", call_to_action).on_click(cx.listener({
let thread = thread.clone();
move |_, _, _, cx| {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.open_url(&zed_urls::account_url(cx));
cx.notify();
}
})),
)
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener({
let thread = thread.clone();
move |_, _, _, cx| {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.notify();
}
}))),
)
.into_any()
}
fn render_error_message(
@@ -2814,24 +2816,40 @@ impl AgentPanel {
cx: &mut Context<Self>,
) -> AnyElement {
let message_with_header = format!("{}\n{}", header, message);
let icon = Icon::new(IconName::XCircle)
.size(IconSize::Small)
.color(Color::Error);
div()
.border_t_1()
.border_color(cx.theme().colors().border)
v_flex()
.gap_0p5()
.child(
Callout::new()
.icon(icon)
.title(header)
.description(message.clone())
.primary_action(self.dismiss_error_button(thread, cx))
.secondary_action(self.create_copy_button(message_with_header))
.bg_color(self.error_callout_bg(cx)),
h_flex()
.gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(Label::new(header).weight(FontWeight::MEDIUM)),
)
.into_any_element()
.child(
div()
.id("error-message")
.max_h_32()
.overflow_y_scroll()
.child(Label::new(message.clone())),
)
.child(
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(message_with_header))
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener({
let thread = thread.clone();
move |_, _, _, cx| {
thread.update(cx, |this, _cx| {
this.clear_last_error();
});
cx.notify();
}
}))),
)
.into_any()
}
fn render_prompt_editor(
@@ -2976,6 +2994,15 @@ impl AgentPanel {
}
}
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
let message = message.into();
IconButton::new("copy", IconName::Copy)
.on_click(move |_, _, cx| {
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
})
.tooltip(Tooltip::text("Copy Error Message"))
}
fn key_context(&self) -> KeyContext {
let mut key_context = KeyContext::new_with_defaults();
key_context.add("AgentPanel");
@@ -3028,8 +3055,8 @@ impl Render for AgentPanel {
match &this.active_view {
ActiveView::Thread { thread, .. } => {
thread.update(cx, |active_thread, cx| {
active_thread.thread().update(cx, |thread, _cx| {
thread.set_completion_mode(CompletionMode::Burn);
active_thread.agent().update(cx, |agent, _cx| {
agent.set_completion_mode(CompletionMode::Burn);
});
});
this.continue_conversation(window, cx);
@@ -3057,9 +3084,18 @@ impl Render for AgentPanel {
thread.clone().into_any_element()
})
.children(self.render_tool_use_limit_reached(window, cx))
.child(h_flex().child(message_editor.clone()))
.when_some(thread.read(cx).last_error(), |this, last_error| {
this.child(
div()
.absolute()
.right_3()
.bottom_12()
.max_w_96()
.py_2()
.px_3()
.elevation_2(cx)
.occlude()
.child(match last_error {
ThreadError::PaymentRequired => {
self.render_payment_required_error(thread, cx)
@@ -3073,7 +3109,6 @@ impl Render for AgentPanel {
.into_any(),
)
})
.child(h_flex().child(message_editor.clone()))
.child(self.render_drag_target(cx)),
ActiveView::History => parent.child(self.history.clone()),
ActiveView::TextThread {

View File

@@ -26,7 +26,7 @@ mod ui;
use std::sync::Arc;
use agent::{Thread, ThreadId};
use agent::{ThreadId, ZedAgentThread};
use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection};
use assistant_slash_command::SlashCommandRegistry;
use client::Client;
@@ -54,88 +54,51 @@ pub use ui::preview::{all_agent_previews, get_agent_preview};
actions!(
agent,
[
/// Creates a new text-based conversation thread.
NewTextThread,
/// Toggles the context picker interface for adding files, symbols, or other context.
ToggleContextPicker,
/// Toggles the navigation menu for switching between threads and views.
ToggleNavigationMenu,
/// Toggles the options menu for agent settings and preferences.
ToggleOptionsMenu,
/// Deletes the recently opened thread from history.
DeleteRecentlyOpenThread,
/// Toggles the profile selector for switching between agent profiles.
ToggleProfileSelector,
/// Removes all added context from the current conversation.
RemoveAllContext,
/// Expands the message editor to full size.
ExpandMessageEditor,
/// Opens the conversation history view.
OpenHistory,
/// Adds a context server to the configuration.
AddContextServer,
/// Removes the currently selected thread.
RemoveSelectedThread,
/// Starts a chat conversation with the agent.
Chat,
/// Starts a chat conversation with follow-up enabled.
ChatWithFollow,
/// Cycles to the next inline assist suggestion.
CycleNextInlineAssist,
/// Cycles to the previous inline assist suggestion.
CyclePreviousInlineAssist,
/// Moves focus up in the interface.
FocusUp,
/// Moves focus down in the interface.
FocusDown,
/// Moves focus left in the interface.
FocusLeft,
/// Moves focus right in the interface.
FocusRight,
/// Removes the currently focused context item.
RemoveFocusedContext,
/// Accepts the suggested context item.
AcceptSuggestedContext,
/// Opens the active thread as a markdown file.
OpenActiveThreadAsMarkdown,
/// Opens the agent diff view to review changes.
OpenAgentDiff,
/// Keeps the current suggestion or change.
Keep,
/// Rejects the current suggestion or change.
Reject,
/// Rejects all suggestions or changes.
RejectAll,
/// Keeps all suggestions or changes.
KeepAll,
/// Follows the agent's suggestions.
Follow,
/// Resets the trial upsell notification.
ResetTrialUpsell,
/// Resets the trial end upsell notification.
ResetTrialEndUpsell,
/// Continues the current thread.
ContinueThread,
/// Continues the thread with burn mode enabled.
ContinueWithBurnMode,
/// Toggles burn mode for faster responses.
ToggleBurnMode,
]
);
/// Creates a new conversation thread, optionally based on an existing thread.
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)]
#[action(namespace = agent)]
#[serde(deny_unknown_fields)]
pub struct NewThread {
#[serde(default)]
from_thread_id: Option<ThreadId>,
}
/// Opens the profile management interface for configuring agent tools and settings.
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
#[action(namespace = agent)]
#[serde(deny_unknown_fields)]
pub struct ManageProfiles {
#[serde(default)]
pub customize_tools: Option<AgentProfileId>,
@@ -151,7 +114,7 @@ impl ManageProfiles {
#[derive(Clone)]
pub(crate) enum ModelUsageContext {
Thread(Entity<Thread>),
Thread(Entity<ZedAgentThread>),
InlineAssistant,
}
@@ -246,7 +209,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {
}
}
let default = settings.default_model.as_ref().map(to_selected_model);
let default = to_selected_model(&settings.default_model);
let inline_assistant = settings
.inline_assistant_model
.as_ref()
@@ -266,7 +229,7 @@ fn update_active_language_model_from_settings(cx: &mut App) {
.collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.select_default_model(default.as_ref(), cx);
registry.select_default_model(Some(&default), cx);
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
registry.select_commit_message_model(commit_message.as_ref(), cx);
registry.select_thread_summary_model(thread_summary.as_ref(), cx);

View File

@@ -426,7 +426,6 @@ impl ContextPicker {
this.add_recent_file(project_path.clone(), window, cx);
})
},
None,
)
}
RecentEntry::Thread(thread) => {
@@ -444,7 +443,6 @@ impl ContextPicker {
.detach_and_log_err(cx);
})
},
None,
)
}
}

View File

@@ -22,7 +22,7 @@ use util::ResultExt as _;
use workspace::Workspace;
use agent::{
Thread,
ZedAgentThread,
context::{AgentContextHandle, AgentContextKey, RULES_ICON},
thread_store::{TextThreadStore, ThreadStore},
};
@@ -449,7 +449,7 @@ impl ContextPickerCompletionProvider {
let context_store = context_store.clone();
let thread_store = thread_store.clone();
window.spawn::<_, Option<_>>(cx, async move |cx| {
let thread: Entity<Thread> = thread_store
let thread: Entity<ZedAgentThread> = thread_store
.update_in(cx, |thread_store, window, cx| {
thread_store.open_thread(&thread_id, window, cx)
})
@@ -686,7 +686,6 @@ impl ContextPickerCompletionProvider {
let mut label = CodeLabel::plain(symbol.name.clone(), None);
label.push_str(" ", None);
label.push_str(&file_name, comment_id);
label.push_str(&format!(" L{}", symbol.range.start.0.row + 1), comment_id);
let new_text = format!("{} ", MentionLink::for_symbol(&symbol.name, &full_path));
let new_text_len = new_text.len();

View File

@@ -18,7 +18,6 @@ use ui::{ListItem, ListItemSpacing, prelude::*};
actions!(
agent,
[
/// Toggles the language model selector dropdown.
#[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
ToggleModelSelector
]
@@ -400,7 +399,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let all_models = self.all_models.clone();
let active_model = (self.get_active_model)(cx);
let current_index = self.selected_index;
let bg_executor = cx.background_executor();
let language_model_registry = LanguageModelRegistry::global(cx);
@@ -442,9 +441,12 @@ impl PickerDelegate for LanguageModelPickerDelegate {
cx.spawn_in(window, async move |this, cx| {
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries = filtered_models.entries();
// Finds the currently selected model in the list
let new_index =
Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
// Preserve selection focus
let new_index = if current_index >= this.delegate.filtered_entries.len() {
0
} else {
current_index
};
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
cx.notify();
})

View File

@@ -9,6 +9,7 @@ use crate::ui::{
MaxModeTooltip,
preview::{AgentPreview, UsageCallout},
};
use agent::thread::UserMessageParams;
use agent::{
context::{AgentContextKey, ContextLoadResult, load_context},
context_store::ContextStoreEvent,
@@ -31,7 +32,7 @@ use gpui::{
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language, Point};
use language::{Buffer, Language};
use language_model::{
ConfiguredModel, LanguageModelRequestMessage, MessageContent, ZED_CLOUD_PROVIDER_ID,
};
@@ -47,7 +48,6 @@ use ui::{
};
use util::ResultExt as _;
use workspace::{CollaboratorId, Workspace};
use zed_llm_client::CompletionIntent;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -58,14 +58,14 @@ use crate::{
ToggleContextPicker, ToggleProfileSelector, register_agent_preview,
};
use agent::{
MessageCrease, Thread, TokenUsageRatio,
MessageCrease, TokenUsageRatio, ZedAgentThread,
context_store::ContextStore,
thread_store::{TextThreadStore, ThreadStore},
};
#[derive(RegisterComponent)]
pub struct MessageEditor {
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
incompatible_tools_state: Entity<IncompatibleToolsState>,
editor: Entity<Editor>,
workspace: WeakEntity<Workspace>,
@@ -156,7 +156,7 @@ impl MessageEditor {
prompt_store: Option<Entity<PromptStore>>,
thread_store: WeakEntity<ThreadStore>,
text_thread_store: WeakEntity<TextThreadStore>,
thread: Entity<Thread>,
agent: Entity<ZedAgentThread>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -182,13 +182,13 @@ impl MessageEditor {
Some(text_thread_store.clone()),
context_picker_menu_handle.clone(),
SuggestContextKind::File,
ModelUsageContext::Thread(thread.clone()),
ModelUsageContext::Thread(agent.clone()),
window,
cx,
)
});
let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(thread.clone(), cx));
let incompatible_tools = cx.new(|cx| IncompatibleToolsState::new(agent.clone(), cx));
let subscriptions = vec![
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
@@ -200,9 +200,7 @@ impl MessageEditor {
// When context changes, reload it for token counting.
let _ = this.reload_context(cx);
}),
cx.observe(&thread.read(cx).action_log().clone(), |_, _, cx| {
cx.notify()
}),
cx.observe(&agent.read(cx).action_log().clone(), |_, _, cx| cx.notify()),
];
let model_selector = cx.new(|cx| {
@@ -210,20 +208,20 @@ impl MessageEditor {
fs.clone(),
model_selector_menu_handle,
editor.focus_handle(cx),
ModelUsageContext::Thread(thread.clone()),
ModelUsageContext::Thread(agent.clone()),
window,
cx,
)
});
let profile_selector =
cx.new(|cx| ProfileSelector::new(fs, thread.clone(), editor.focus_handle(cx), cx));
cx.new(|cx| ProfileSelector::new(fs, agent.clone(), editor.focus_handle(cx), cx));
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
project: agent.read(cx).project().clone(),
user_store,
thread,
agent,
incompatible_tools_state: incompatible_tools.clone(),
workspace,
context_store,
@@ -313,11 +311,11 @@ impl MessageEditor {
return;
}
self.thread.update(cx, |thread, cx| {
self.agent.update(cx, |thread, cx| {
thread.cancel_editing(cx);
});
if self.thread.read(cx).is_generating() {
if self.agent.read(cx).is_generating() {
self.stop_current_and_send_new_message(window, cx);
return;
}
@@ -354,7 +352,7 @@ impl MessageEditor {
fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let Some(ConfiguredModel { model, provider }) = self
.thread
.agent
.update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
else {
return;
@@ -375,7 +373,7 @@ impl MessageEditor {
self.last_estimated_token_count.take();
cx.emit(MessageEditorEvent::EstimatedTokenCount);
let thread = self.thread.clone();
let agent = self.agent.clone();
let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
let context_task = self.reload_context(cx);
@@ -385,24 +383,16 @@ impl MessageEditor {
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
let loaded_context = loaded_context.unwrap_or_default();
thread
agent
.update(cx, |thread, cx| {
thread.insert_user_message(
user_message,
loaded_context,
checkpoint.ok(),
user_message_creases,
cx,
);
})
.log_err();
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(
thread.send_message(
UserMessageParams {
text: user_message,
creases: user_message_creases,
checkpoint: checkpoint.ok(),
context: loaded_context,
},
model,
CompletionIntent::UserPrompt,
Some(window_handle),
cx,
);
@@ -413,11 +403,11 @@ impl MessageEditor {
}
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.thread.update(cx, |thread, cx| {
self.agent.update(cx, |thread, cx| {
thread.cancel_editing(cx);
});
let cancelled = self.thread.update(cx, |thread, cx| {
let cancelled = self.agent.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
});
@@ -459,7 +449,7 @@ impl MessageEditor {
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.edits_expanded = true;
AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
AgentDiffPane::deploy(self.agent.clone(), self.workspace.clone(), window, cx).log_err();
cx.notify();
}
@@ -475,7 +465,7 @@ impl MessageEditor {
cx: &mut Context<Self>,
) {
if let Ok(diff) =
AgentDiffPane::deploy(self.thread.clone(), self.workspace.clone(), window, cx)
AgentDiffPane::deploy(self.agent.clone(), self.workspace.clone(), window, cx)
{
let path_key = multi_buffer::PathKey::for_buffer(&buffer, cx);
diff.update(cx, |diff, cx| diff.move_to_path(path_key, window, cx));
@@ -488,7 +478,7 @@ impl MessageEditor {
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.thread.update(cx, |thread, _cx| {
self.agent.update(cx, |thread, _cx| {
let active_completion_mode = thread.completion_mode();
thread.set_completion_mode(match active_completion_mode {
@@ -499,36 +489,22 @@ impl MessageEditor {
}
fn handle_accept_all(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
if self.thread.read(cx).has_pending_edit_tool_uses() {
if self.agent.read(cx).has_pending_edit_tool_uses() {
return;
}
self.thread.update(cx, |thread, cx| {
thread.keep_all_edits(cx);
});
let action_log = self.agent.read(cx).action_log();
action_log.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
cx.notify();
}
fn handle_reject_all(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
if self.thread.read(cx).has_pending_edit_tool_uses() {
if self.agent.read(cx).has_pending_edit_tool_uses() {
return;
}
// Since there's no reject_all_edits method in the thread API,
// we need to iterate through all buffers and reject their edits
let action_log = self.thread.read(cx).action_log().clone();
let changed_buffers = action_log.read(cx).changed_buffers(cx);
for (buffer, _) in changed_buffers {
self.thread.update(cx, |thread, cx| {
let buffer_snapshot = buffer.read(cx);
let start = buffer_snapshot.anchor_before(Point::new(0, 0));
let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point());
thread
.reject_edits_in_ranges(buffer, vec![start..end], cx)
.detach();
});
}
let action_log = self.agent.read(cx).action_log();
action_log.update(cx, |action_log, cx| action_log.reject_all_edits(cx));
cx.notify();
}
@@ -538,17 +514,13 @@ impl MessageEditor {
_window: &mut Window,
cx: &mut Context<Self>,
) {
if self.thread.read(cx).has_pending_edit_tool_uses() {
if self.agent.read(cx).has_pending_edit_tool_uses() {
return;
}
self.thread.update(cx, |thread, cx| {
let buffer_snapshot = buffer.read(cx);
let start = buffer_snapshot.anchor_before(Point::new(0, 0));
let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point());
thread
.reject_edits_in_ranges(buffer, vec![start..end], cx)
.detach();
let action_log = self.agent.read(cx).action_log();
action_log.update(cx, |action_log, cx| {
action_log.reject_buffer_edits(buffer, cx)
});
cx.notify();
}
@@ -559,21 +531,19 @@ impl MessageEditor {
_window: &mut Window,
cx: &mut Context<Self>,
) {
if self.thread.read(cx).has_pending_edit_tool_uses() {
if self.agent.read(cx).has_pending_edit_tool_uses() {
return;
}
self.thread.update(cx, |thread, cx| {
let buffer_snapshot = buffer.read(cx);
let start = buffer_snapshot.anchor_before(Point::new(0, 0));
let end = buffer_snapshot.anchor_after(buffer_snapshot.max_point());
thread.keep_edits_in_range(buffer, start..end, cx);
let action_log = self.agent.read(cx).action_log();
action_log.update(cx, |action_log, cx| {
action_log.keep_buffer_edits(buffer, cx)
});
cx.notify();
}
fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let thread = self.thread.read(cx);
let thread = self.agent.read(cx);
let model = thread.configured_model();
if !model?.model.supports_burn_mode() {
return None;
@@ -644,7 +614,7 @@ impl MessageEditor {
}
fn render_editor(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
let thread = self.thread.read(cx);
let thread = self.agent.read(cx);
let model = thread.configured_model();
let editor_bg_color = cx.theme().colors().editor_background;
@@ -945,7 +915,7 @@ impl MessageEditor {
let bg_edit_files_disclosure = editor_bg_color.blend(active_color.opacity(0.3));
let is_edit_changes_expanded = self.edits_expanded;
let thread = self.thread.read(cx);
let thread = self.agent.read(cx);
let pending_edits = thread.has_pending_edit_tool_uses();
const EDIT_NOT_READY_TOOLTIP_LABEL: &str = "Wait until file edits are complete.";
@@ -1247,10 +1217,12 @@ impl MessageEditor {
}
fn is_using_zed_provider(&self, cx: &App) -> bool {
self.thread
self.agent
.read(cx)
.configured_model()
.map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
.map_or(false, |model| {
model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
})
}
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {
@@ -1323,7 +1295,7 @@ impl MessageEditor {
Button::new("start-new-thread", "Start New Thread")
.label_size(LabelSize::Small)
.on_click(cx.listener(|this, _, window, cx| {
let from_thread_id = Some(this.thread.read(cx).id().clone());
let from_thread_id = Some(this.agent.read(cx).id().clone());
window.dispatch_action(Box::new(NewThread { from_thread_id }), cx);
})),
);
@@ -1357,10 +1329,11 @@ impl MessageEditor {
fn reload_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
let load_task = cx.spawn(async move |this, cx| {
let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this
.context_store
.read(cx)
.new_context_for_thread(this.thread.read(cx), None);
let new_context = this.context_store.read(cx).new_context_for_thread(
this.agent.read(cx),
None,
cx,
);
load_context(new_context, &this.project, &this.prompt_store, cx)
}) else {
return;
@@ -1392,7 +1365,7 @@ impl MessageEditor {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
let Some(model) = self.thread.read(cx).configured_model() else {
let Some(model) = self.agent.read(cx).configured_model() else {
self.last_estimated_token_count.take();
return;
};
@@ -1597,16 +1570,16 @@ impl Focusable for MessageEditor {
impl Render for MessageEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx);
let token_usage_ratio = thread
.total_token_usage()
let agent = self.agent.read(cx);
let token_usage_ratio = agent
.total_token_usage(cx)
.map_or(TokenUsageRatio::Normal, |total_token_usage| {
total_token_usage.ratio()
});
let burn_mode_enabled = thread.completion_mode() == CompletionMode::Burn;
let burn_mode_enabled = agent.completion_mode() == CompletionMode::Burn;
let action_log = self.thread.read(cx).action_log();
let action_log = agent.action_log();
let changed_buffers = action_log.read(cx).changed_buffers(cx);
let line_height = TextSize::Small.rems(cx).to_pixels(window.rem_size()) * 1.5;
@@ -1689,7 +1662,7 @@ impl AgentPreview for MessageEditor {
let weak_project = project.downgrade();
let context_store = cx.new(|_cx| ContextStore::new(weak_project, None));
let active_thread = active_thread.read(cx);
let thread = active_thread.thread().clone();
let agent = active_thread.agent().clone();
let thread_store = active_thread.thread_store().clone();
let text_thread_store = active_thread.text_thread_store().clone();
@@ -1702,7 +1675,7 @@ impl AgentPreview for MessageEditor {
None,
thread_store.downgrade(),
text_thread_store.downgrade(),
thread,
agent,
window,
cx,
)

View File

@@ -1,6 +1,6 @@
use crate::{ManageProfiles, ToggleProfileSelector};
use agent::{
Thread,
ZedAgentThread,
agent_profile::{AgentProfile, AvailableProfiles},
};
use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles};
@@ -17,7 +17,7 @@ use ui::{
pub struct ProfileSelector {
profiles: AvailableProfiles,
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
thread: Entity<ZedAgentThread>,
menu_handle: PopoverMenuHandle<ContextMenu>,
focus_handle: FocusHandle,
_subscriptions: Vec<Subscription>,
@@ -26,7 +26,7 @@ pub struct ProfileSelector {
impl ProfileSelector {
pub fn new(
fs: Arc<dyn Fs>,
thread: Entity<Thread>,
thread: Entity<ZedAgentThread>,
focus_handle: FocusHandle,
cx: &mut Context<Self>,
) -> Self {

View File

@@ -85,24 +85,16 @@ use assistant_context::{
actions!(
assistant,
[
/// Sends the current message to the assistant.
Assist,
/// Confirms and executes the entered slash command.
ConfirmCommand,
/// Copies code from the assistant's response to the clipboard.
CopyCode,
/// Cycles between user and assistant message roles.
CycleMessageRole,
/// Inserts the selected text into the active editor.
InsertIntoEditor,
/// Quotes the current selection in the assistant conversation.
QuoteSelection,
/// Splits the conversation at the current cursor position.
Split,
]
);
/// Inserts files that were dragged and dropped into the assistant conversation.
#[derive(PartialEq, Clone, Action)]
#[action(namespace = assistant, no_json, no_register)]
pub enum InsertDraggedFiles {

View File

@@ -1,4 +1,4 @@
use agent::{Thread, ThreadEvent};
use agent::{ThreadEvent, ZedAgentThread};
use assistant_tool::{Tool, ToolSource};
use collections::HashMap;
use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
@@ -8,12 +8,12 @@ use ui::prelude::*;
pub struct IncompatibleToolsState {
cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
thread: Entity<Thread>,
thread: Entity<ZedAgentThread>,
_thread_subscription: Subscription,
}
impl IncompatibleToolsState {
pub fn new(thread: Entity<Thread>, cx: &mut Context<Self>) -> Self {
pub fn new(thread: Entity<ZedAgentThread>, cx: &mut Context<Self>) -> Self {
let _tool_working_set_subscription =
cx.subscribe(&thread, |this, _, event, _| match event {
ThreadEvent::ProfileChanged => {
@@ -42,8 +42,8 @@ impl IncompatibleToolsState {
.profile()
.enabled_tools(cx)
.iter()
.filter(|(_, tool)| tool.input_schema(model.tool_input_format()).is_err())
.map(|(_, tool)| tool.clone())
.filter(|tool| tool.input_schema(model.tool_input_format()).is_err())
.cloned()
.collect()
})
}

View File

@@ -1,4 +1,5 @@
mod agent_notification;
mod animated_label;
mod burn_mode_tooltip;
mod context_pill;
mod onboarding_modal;
@@ -6,6 +7,7 @@ pub mod preview;
mod upsell;
pub use agent_notification::*;
pub use animated_label::*;
pub use burn_mode_tooltip::*;
pub use context_pill::*;
pub use onboarding_modal::*;

View File

@@ -1,24 +1,24 @@
use crate::prelude::*;
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
use std::time::Duration;
use ui::prelude::*;
#[derive(IntoElement)]
pub struct LoadingLabel {
pub struct AnimatedLabel {
base: Label,
text: SharedString,
}
impl LoadingLabel {
impl AnimatedLabel {
pub fn new(text: impl Into<SharedString>) -> Self {
let text = text.into();
LoadingLabel {
AnimatedLabel {
base: Label::new(text.clone()),
text,
}
}
}
impl LabelCommon for LoadingLabel {
impl LabelCommon for AnimatedLabel {
fn size(mut self, size: LabelSize) -> Self {
self.base = self.base.size(size);
self
@@ -80,14 +80,14 @@ impl LabelCommon for LoadingLabel {
}
}
impl RenderOnce for LoadingLabel {
impl RenderOnce for AnimatedLabel {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let text = self.text.clone();
self.base
.color(Color::Muted)
.with_animations(
"loading_label",
"animated-label",
vec![
Animation::new(Duration::from_secs(1)),
Animation::new(Duration::from_secs(1)).repeat(),

View File

@@ -488,7 +488,7 @@ impl AddedContext {
parent: None,
tooltip: None,
icon_path: None,
status: if handle.thread.read(cx).is_generating_detailed_summary() {
status: if handle.agent.read(cx).is_generating_detailed_summary() {
ContextStatus::Loading {
message: "Summarizing…".into(),
}
@@ -496,9 +496,9 @@ impl AddedContext {
ContextStatus::Ready
},
render_hover: {
let thread = handle.thread.clone();
let agent = handle.agent.clone();
Some(Rc::new(move |_, cx| {
let text = thread.read(cx).latest_detailed_summary_or_text();
let text = agent.read(cx).latest_detailed_summary_or_text(cx);
ContextPillHover::new_text(text.clone(), cx).into()
}))
},

View File

@@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@@ -356,7 +356,7 @@ pub async fn complete(
.send(request)
.await
.map_err(AnthropicError::HttpSend)?;
let status_code = response.status();
let status = response.status();
let mut body = String::new();
response
.body_mut()
@@ -364,12 +364,12 @@ pub async fn complete(
.await
.map_err(AnthropicError::ReadResponse)?;
if status_code.is_success() {
if status.is_success() {
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
} else {
Err(AnthropicError::HttpResponseError {
status_code,
message: body,
status: status.as_u16(),
body,
})
}
}
@@ -444,7 +444,11 @@ impl RateLimitInfo {
}
Self {
retry_after: parse_retry_after(headers),
retry_after: headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs),
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@@ -453,17 +457,6 @@ impl RateLimitInfo {
}
}
/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
/// seconds). Note that other services might specify an HTTP date or some other format for this
/// header. Returns `None` if the header is not present or cannot be parsed.
pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
headers
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.map(Duration::from_secs)
}
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
Ok(headers
.get(key)
@@ -527,10 +520,6 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
} else if response.status().as_u16() == 529 {
Err(AnthropicError::ServerOverloaded {
retry_after: rate_limits.retry_after,
})
} else if let Some(retry_after) = rate_limits.retry_after {
Err(AnthropicError::RateLimit { retry_after })
} else {
@@ -543,9 +532,10 @@ pub async fn stream_completion_with_rate_limit_info(
match serde_json::from_str::<Event>(&body) {
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
status_code: response.status(),
message: body,
Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
Err(_) => Err(AnthropicError::HttpResponseError {
status: response.status().as_u16(),
body: body,
}),
}
}
@@ -811,19 +801,16 @@ pub enum AnthropicError {
ReadResponse(io::Error),
/// HTTP error response from the API
HttpResponseError {
status_code: StatusCode,
message: String,
},
HttpResponseError { status: u16, body: String },
/// Rate limit exceeded
RateLimit { retry_after: Duration },
/// Server overloaded
ServerOverloaded { retry_after: Option<Duration> },
/// API returned an error response
ApiError(ApiError),
/// Unexpected response format
UnexpectedResponseFormat(String),
}
#[derive(Debug, Serialize, Deserialize, Error)]

View File

@@ -2140,8 +2140,7 @@ impl AssistantContext {
);
}
LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
LanguageModelCompletionEvent::UsageUpdate(_) => {}
LanguageModelCompletionEvent::UsageUpdate(_) => {}
}
});

View File

@@ -5,6 +5,9 @@ edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[features]
test-support = []
[lints]
workspace = true
@@ -22,7 +25,6 @@ gpui.workspace = true
icons.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true

View File

@@ -495,6 +495,10 @@ impl ActionLog {
cx.notify();
}
pub fn keep_buffer_edits(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.keep_edits_in_range(buffer, Anchor::MIN..Anchor::MAX, cx);
}
pub fn keep_edits_in_range(
&mut self,
buffer: Entity<Buffer>,
@@ -555,6 +559,19 @@ impl ActionLog {
}
}
pub fn reject_all_edits(&mut self, cx: &mut Context<Self>) {
let changed_buffers = self.changed_buffers(cx);
for (buffer, _) in changed_buffers {
self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx)
.detach();
}
}
pub fn reject_buffer_edits(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.reject_edits_in_ranges(buffer, vec![Anchor::MIN..Anchor::MAX], cx)
.detach()
}
pub fn reject_edits_in_ranges(
&mut self,
buffer: Entity<Buffer>,

View File

@@ -70,7 +70,7 @@ pub struct ToolResultOutput {
pub output: Option<serde_json::Value>,
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ToolResultContent {
Text(String),
Image(LanguageModelImage),
@@ -135,7 +135,8 @@ pub trait ToolCard: 'static + Sized {
) -> impl IntoElement;
}
#[derive(Clone)]
#[derive(Debug, Clone)]
#[cfg_attr(any(test, feature = "test-support"), derive(PartialEq, Eq))]
pub struct AnyToolCard {
entity: gpui::AnyEntity,
render: fn(

View File

@@ -25,15 +25,10 @@ fn preprocess_json_schema(json: &mut Value) -> Result<()> {
// `additionalProperties` defaults to `false` unless explicitly specified.
// This prevents models from hallucinating tool parameters.
if let Value::Object(obj) = json {
if matches!(obj.get("type"), Some(Value::String(s)) if s == "object") {
if !obj.contains_key("additionalProperties") {
if let Some(Value::String(type_str)) = obj.get("type") {
if type_str == "object" && !obj.contains_key("additionalProperties") {
obj.insert("additionalProperties".to_string(), Value::Bool(false));
}
// OpenAI API requires non-missing `properties`
if !obj.contains_key("properties") {
obj.insert("properties".to_string(), Value::Object(Default::default()));
}
}
}
Ok(())

View File

@@ -1,52 +1,18 @@
use std::{borrow::Borrow, sync::Arc};
use std::sync::Arc;
use collections::{HashMap, IndexMap};
use gpui::App;
use crate::{Tool, ToolRegistry, ToolSource};
use collections::{HashMap, HashSet, IndexMap};
use gpui::{App, SharedString};
use util::debug_panic;
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
/// A unique identifier for a tool within a working set.
#[derive(Clone, PartialEq, Eq, Hash, Default)]
pub struct UniqueToolName(SharedString);
impl Borrow<str> for UniqueToolName {
fn borrow(&self) -> &str {
&self.0
}
}
impl From<String> for UniqueToolName {
fn from(value: String) -> Self {
UniqueToolName(SharedString::new(value))
}
}
impl Into<String> for UniqueToolName {
fn into(self) -> String {
self.0.into()
}
}
impl std::fmt::Debug for UniqueToolName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::fmt::Display for UniqueToolName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_ref())
}
}
/// A working set of tools for use in one instance of the Assistant Panel.
#[derive(Default)]
pub struct ToolWorkingSet {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<UniqueToolName, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
next_tool_id: ToolId,
}
@@ -58,20 +24,16 @@ impl ToolWorkingSet {
.or_else(|| ToolRegistry::global(cx).tool(name))
}
pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc<dyn Tool>)> {
let mut tools = ToolRegistry::global(cx)
.tools()
.into_iter()
.map(|tool| (UniqueToolName(tool.name().into()), tool))
.collect::<Vec<_>>();
tools.extend(self.context_server_tools_by_name.clone());
pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
let mut tools = ToolRegistry::global(cx).tools();
tools.extend(self.context_server_tools_by_id.values().cloned());
tools
}
pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
let mut tools_by_source = IndexMap::default();
for (_, tool) in self.tools(cx) {
for tool in self.tools(cx) {
tools_by_source
.entry(tool.source())
.or_insert_with(Vec::new)
@@ -87,324 +49,27 @@ impl ToolWorkingSet {
tools_by_source
}
pub fn insert(&mut self, tool: Arc<dyn Tool>, cx: &App) -> ToolId {
let tool_id = self.register_tool(tool);
self.tools_changed(cx);
tool_id
}
pub fn extend(&mut self, tools: impl Iterator<Item = Arc<dyn Tool>>, cx: &App) -> Vec<ToolId> {
let ids = tools.map(|tool| self.register_tool(tool)).collect();
self.tools_changed(cx);
ids
}
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed(cx);
}
fn register_tool(&mut self, tool: Arc<dyn Tool>) -> ToolId {
pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
let tool_id = self.next_tool_id;
self.next_tool_id.0 += 1;
self.context_server_tools_by_id
.insert(tool_id, tool.clone());
self.tools_changed();
tool_id
}
fn tools_changed(&mut self, cx: &App) {
self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts(
&self
.context_server_tools_by_id
pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
self.context_server_tools_by_id
.retain(|id, _| !tool_ids_to_remove.contains(id));
self.tools_changed();
}
fn tools_changed(&mut self) {
self.context_server_tools_by_name.clear();
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
.cloned()
.collect::<Vec<_>>(),
&ToolRegistry::global(cx).tools(),
.map(|tool| (tool.name(), tool.clone())),
);
}
}
fn resolve_context_server_tool_name_conflicts(
context_server_tools: &[Arc<dyn Tool>],
native_tools: &[Arc<dyn Tool>],
) -> HashMap<UniqueToolName, Arc<dyn Tool>> {
fn resolve_tool_name(tool: &Arc<dyn Tool>) -> String {
let mut tool_name = tool.name();
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
tool_name
}
const MAX_TOOL_NAME_LENGTH: usize = 64;
let mut duplicated_tool_names = HashSet::default();
let mut seen_tool_names = HashSet::default();
seen_tool_names.extend(native_tools.iter().map(|tool| tool.name()));
for tool in context_server_tools {
let tool_name = resolve_tool_name(tool);
if seen_tool_names.contains(&tool_name) {
debug_assert!(
tool.source() != ToolSource::Native,
"Expected MCP tool but got a native tool: {}",
tool_name
);
duplicated_tool_names.insert(tool_name);
} else {
seen_tool_names.insert(tool_name);
}
}
if duplicated_tool_names.is_empty() {
return context_server_tools
.into_iter()
.map(|tool| (resolve_tool_name(tool).into(), tool.clone()))
.collect();
}
context_server_tools
.into_iter()
.filter_map(|tool| {
let mut tool_name = resolve_tool_name(tool);
if !duplicated_tool_names.contains(&tool_name) {
return Some((tool_name.into(), tool.clone()));
}
match tool.source() {
ToolSource::Native => {
debug_panic!("Expected MCP tool but got a native tool: {}", tool_name);
// Built-in tools always keep their original name
Some((tool_name.into(), tool.clone()))
}
ToolSource::ContextServer { id } => {
// Context server tools are prefixed with the context server ID, and truncated if necessary
tool_name.insert(0, '_');
if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH {
let len = MAX_TOOL_NAME_LENGTH - tool_name.len();
let mut id = id.to_string();
id.truncate(len);
tool_name.insert_str(0, &id);
} else {
tool_name.insert_str(0, &id);
}
tool_name.truncate(MAX_TOOL_NAME_LENGTH);
if seen_tool_names.contains(&tool_name) {
log::error!("Cannot resolve tool name conflict for tool {}", tool.name());
None
} else {
Some((tool_name.into(), tool.clone()))
}
}
}
})
.collect()
}
#[cfg(test)]
mod tests {
use gpui::{AnyWindowHandle, Entity, Task, TestAppContext};
use language_model::{LanguageModel, LanguageModelRequest};
use project::Project;
use crate::{ActionLog, ToolResult};
use super::*;
#[gpui::test]
fn test_unique_tool_names(cx: &mut TestAppContext) {
fn assert_tool(
tool_working_set: &ToolWorkingSet,
unique_name: &str,
expected_name: &str,
expected_source: ToolSource,
cx: &App,
) {
let tool = tool_working_set.tool(unique_name, cx).unwrap();
assert_eq!(tool.name(), expected_name);
assert_eq!(tool.source(), expected_source);
}
let tool_registry = cx.update(ToolRegistry::default_global);
tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native));
tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native));
let mut tool_working_set = ToolWorkingSet::default();
cx.update(|cx| {
tool_working_set.extend(
vec![
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-1".into() },
)) as Arc<dyn Tool>,
Arc::new(TestTool::new(
"tool2",
ToolSource::ContextServer { id: "mcp-2".into() },
)) as Arc<dyn Tool>,
]
.into_iter(),
cx,
);
});
cx.update(|cx| {
assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx);
assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx);
assert_tool(
&tool_working_set,
"mcp-1_tool2",
"tool2",
ToolSource::ContextServer { id: "mcp-1".into() },
cx,
);
assert_tool(
&tool_working_set,
"mcp-2_tool2",
"tool2",
ToolSource::ContextServer { id: "mcp-2".into() },
cx,
);
})
}
#[gpui::test]
fn test_resolve_context_server_tool_name_conflicts() {
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
],
vec![TestTool::new(
"tool3",
ToolSource::ContextServer { id: "mcp-1".into() },
)],
vec!["tool3"],
);
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
],
vec![
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["mcp-1_tool3", "mcp-2_tool3"],
);
assert_resolve_context_server_tool_name_conflicts(
vec![
TestTool::new("tool1", ToolSource::Native),
TestTool::new("tool2", ToolSource::Native),
TestTool::new("tool3", ToolSource::Native),
],
vec![
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }),
TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }),
],
vec!["mcp-1_tool3", "mcp-2_tool3"],
);
// Test deduplication of tools with very long names, in this case the mcp server name should be truncated
assert_resolve_context_server_tool_name_conflicts(
vec![TestTool::new(
"tool-with-very-very-very-long-name",
ToolSource::Native,
)],
vec![TestTool::new(
"tool-with-very-very-very-long-name",
ToolSource::ContextServer {
id: "mcp-with-very-very-very-long-name".into(),
},
)],
vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"],
);
fn assert_resolve_context_server_tool_name_conflicts(
builtin_tools: Vec<TestTool>,
context_server_tools: Vec<TestTool>,
expected: Vec<&'static str>,
) {
let context_server_tools: Vec<Arc<dyn Tool>> = context_server_tools
.into_iter()
.map(|t| Arc::new(t) as Arc<dyn Tool>)
.collect();
let builtin_tools: Vec<Arc<dyn Tool>> = builtin_tools
.into_iter()
.map(|t| Arc::new(t) as Arc<dyn Tool>)
.collect();
let tools =
resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools);
assert_eq!(tools.len(), expected.len());
for (i, (name, _)) in tools.into_iter().enumerate() {
assert_eq!(
name.0.as_ref(),
expected[i],
"Expected '{}' got '{}' at index {}",
expected[i],
name,
i
);
}
}
}
struct TestTool {
name: String,
source: ToolSource,
}
impl TestTool {
fn new(name: impl Into<String>, source: ToolSource) -> Self {
Self {
name: name.into(),
source,
}
}
}
impl Tool for TestTool {
fn name(&self) -> String {
self.name.clone()
}
fn icon(&self) -> icons::IconName {
icons::IconName::Ai
}
fn may_perform_edits(&self) -> bool {
false
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
true
}
fn source(&self) -> ToolSource {
self.source.clone()
}
fn description(&self) -> String {
"Test tool".to_string()
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Test tool".to_string()
}
fn run(
self: Arc<Self>,
_input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
ToolResult {
output: Task::ready(Err(anyhow::anyhow!("No content"))),
card: None,
}
}
}
}

View File

@@ -11,7 +11,6 @@ mod list_directory_tool;
mod move_path_tool;
mod now_tool;
mod open_tool;
mod project_notifications_tool;
mod read_file_tool;
mod schema;
mod templates;
@@ -46,7 +45,6 @@ pub use edit_file_tool::{EditFileMode, EditFileToolInput};
pub use find_path_tool::FindPathToolInput;
pub use grep_tool::{GrepTool, GrepToolInput};
pub use open_tool::OpenTool;
pub use project_notifications_tool::ProjectNotificationsTool;
pub use read_file_tool::{ReadFileTool, ReadFileToolInput};
pub use terminal_tool::TerminalTool;
@@ -63,7 +61,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ListDirectoryTool);
registry.register_tool(NowTool);
registry.register_tool(OpenTool);
registry.register_tool(ProjectNotificationsTool);
registry.register_tool(FindPathTool);
registry.register_tool(ReadFileTool);
registry.register_tool(GrepTool);

View File

@@ -29,7 +29,6 @@ use std::{
path::Path,
str::FromStr,
sync::mpsc,
time::Duration,
};
use util::path;
@@ -1659,14 +1658,12 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
match request().await {
Ok(result) => return Ok(result),
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
Ok(err) => match &err {
LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
| LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
Ok(err) => match err {
LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
eprintln!(
"Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
"Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
);
Timer::after(retry_after + jitter).await;
continue;

View File

@@ -9132,7 +9132,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| lines.sort())
self.manipulate_immutable_lines(window, cx, |lines| lines.sort())
}
pub fn sort_lines_case_insensitive(
@@ -9141,7 +9141,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| {
self.manipulate_immutable_lines(window, cx, |lines| {
lines.sort_by_key(|line| line.to_lowercase())
})
}
@@ -9152,7 +9152,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| {
self.manipulate_immutable_lines(window, cx, |lines| {
let mut seen = HashSet::default();
lines.retain(|line| seen.insert(line.to_lowercase()));
})
@@ -9164,7 +9164,7 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
self.manipulate_lines(window, cx, |lines| {
self.manipulate_immutable_lines(window, cx, |lines| {
let mut seen = HashSet::default();
lines.retain(|line| seen.insert(*line));
})
@@ -9606,20 +9606,20 @@ impl Editor {
}
pub fn reverse_lines(&mut self, _: &ReverseLines, window: &mut Window, cx: &mut Context<Self>) {
self.manipulate_lines(window, cx, |lines| lines.reverse())
self.manipulate_immutable_lines(window, cx, |lines| lines.reverse())
}
pub fn shuffle_lines(&mut self, _: &ShuffleLines, window: &mut Window, cx: &mut Context<Self>) {
self.manipulate_lines(window, cx, |lines| lines.shuffle(&mut thread_rng()))
self.manipulate_immutable_lines(window, cx, |lines| lines.shuffle(&mut thread_rng()))
}
fn manipulate_lines<Fn>(
fn manipulate_lines<M>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
mut callback: Fn,
mut manipulate: M,
) where
Fn: FnMut(&mut Vec<&str>),
M: FnMut(&str) -> LineManipulationResult,
{
self.hide_mouse_cursor(&HideMouseCursorOrigin::TypingAction);
@@ -9652,18 +9652,14 @@ impl Editor {
.text_for_range(start_point..end_point)
.collect::<String>();
let mut lines = text.split('\n').collect_vec();
let LineManipulationResult { new_text, line_count_before, line_count_after} = manipulate(&text);
let lines_before = lines.len();
callback(&mut lines);
let lines_after = lines.len();
edits.push((start_point..end_point, lines.join("\n")));
edits.push((start_point..end_point, new_text));
// Selections must change based on added and removed line count
let start_row =
MultiBufferRow(start_point.row + added_lines as u32 - removed_lines as u32);
let end_row = MultiBufferRow(start_row.0 + lines_after.saturating_sub(1) as u32);
let end_row = MultiBufferRow(start_row.0 + line_count_after.saturating_sub(1) as u32);
new_selections.push(Selection {
id: selection.id,
start: start_row,
@@ -9672,10 +9668,10 @@ impl Editor {
reversed: selection.reversed,
});
if lines_after > lines_before {
added_lines += lines_after - lines_before;
} else if lines_before > lines_after {
removed_lines += lines_before - lines_after;
if line_count_after > line_count_before {
added_lines += line_count_after - line_count_before;
} else if line_count_before > line_count_after {
removed_lines += line_count_before - line_count_after;
}
}
@@ -9720,6 +9716,171 @@ impl Editor {
})
}
fn manipulate_immutable_lines<Fn>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
mut callback: Fn,
) where
Fn: FnMut(&mut Vec<&str>),
{
self.manipulate_lines(window, cx, |text| {
let mut lines: Vec<&str> = text.split('\n').collect();
let line_count_before = lines.len();
callback(&mut lines);
LineManipulationResult {
new_text: lines.join("\n"),
line_count_before,
line_count_after: lines.len(),
}
});
}
fn manipulate_mutable_lines<Fn>(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
mut callback: Fn,
) where
Fn: FnMut(&mut Vec<Cow<'_, str>>),
{
self.manipulate_lines(window, cx, |text| {
let mut lines: Vec<Cow<str>> = text.split('\n').map(Cow::from).collect();
let line_count_before = lines.len();
callback(&mut lines);
LineManipulationResult {
new_text: lines.join("\n"),
line_count_before,
line_count_after: lines.len(),
}
});
}
pub fn convert_indentation_to_spaces(
&mut self,
_: &ConvertIndentationToSpaces,
window: &mut Window,
cx: &mut Context<Self>,
) {
let settings = self.buffer.read(cx).language_settings(cx);
let tab_size = settings.tab_size.get() as usize;
self.manipulate_mutable_lines(window, cx, |lines| {
// Allocates a reasonably sized scratch buffer once for the whole loop
let mut reindented_line = String::with_capacity(MAX_LINE_LEN);
// Avoids recomputing spaces that could be inserted many times
let space_cache: Vec<Vec<char>> = (1..=tab_size)
.map(|n| IndentSize::spaces(n as u32).chars().collect())
.collect();
for line in lines.iter_mut().filter(|line| !line.is_empty()) {
let mut chars = line.as_ref().chars();
let mut col = 0;
let mut changed = false;
while let Some(ch) = chars.next() {
match ch {
' ' => {
reindented_line.push(' ');
col += 1;
}
'\t' => {
// \t are converted to spaces depending on the current column
let spaces_len = tab_size - (col % tab_size);
reindented_line.extend(&space_cache[spaces_len - 1]);
col += spaces_len;
changed = true;
}
_ => {
// If we dont append before break, the character is consumed
reindented_line.push(ch);
break;
}
}
}
if !changed {
reindented_line.clear();
continue;
}
// Append the rest of the line and replace old reference with new one
reindented_line.extend(chars);
*line = Cow::Owned(reindented_line.clone());
reindented_line.clear();
}
});
}
pub fn convert_indentation_to_tabs(
&mut self,
_: &ConvertIndentationToTabs,
window: &mut Window,
cx: &mut Context<Self>,
) {
let settings = self.buffer.read(cx).language_settings(cx);
let tab_size = settings.tab_size.get() as usize;
self.manipulate_mutable_lines(window, cx, |lines| {
// Allocates a reasonably sized buffer once for the whole loop
let mut reindented_line = String::with_capacity(MAX_LINE_LEN);
// Avoids recomputing spaces that could be inserted many times
let space_cache: Vec<Vec<char>> = (1..=tab_size)
.map(|n| IndentSize::spaces(n as u32).chars().collect())
.collect();
for line in lines.iter_mut().filter(|line| !line.is_empty()) {
let mut chars = line.chars();
let mut spaces_count = 0;
let mut first_non_indent_char = None;
let mut changed = false;
while let Some(ch) = chars.next() {
match ch {
' ' => {
// Keep track of spaces. Append \t when we reach tab_size
spaces_count += 1;
changed = true;
if spaces_count == tab_size {
reindented_line.push('\t');
spaces_count = 0;
}
}
'\t' => {
reindented_line.push('\t');
spaces_count = 0;
}
_ => {
// Dont append it yet, we might have remaining spaces
first_non_indent_char = Some(ch);
break;
}
}
}
if !changed {
reindented_line.clear();
continue;
}
// Remaining spaces that didn't make a full tab stop
if spaces_count > 0 {
reindented_line.extend(&space_cache[spaces_count - 1]);
}
// If we consume an extra character that was not indentation, add it back
if let Some(extra_char) = first_non_indent_char {
reindented_line.push(extra_char);
}
// Append the rest of the line and replace old reference with new one
reindented_line.extend(chars);
*line = Cow::Owned(reindented_line.clone());
reindented_line.clear();
}
});
}
pub fn convert_to_upper_case(
&mut self,
_: &ConvertToUpperCase,
@@ -21157,6 +21318,13 @@ pub struct LineHighlight {
pub type_id: Option<TypeId>,
}
struct LineManipulationResult {
pub new_text: String,
pub line_count_before: usize,
pub line_count_after: usize,
}
fn render_diff_hunk_controls(
row: u32,
status: &DiffHunkStatus,

View File

@@ -1,193 +0,0 @@
use crate::schema::json_schema_for;
use anyhow::Result;
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write as _;
use std::sync::Arc;
use ui::IconName;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ProjectUpdatesToolInput {}
pub struct ProjectNotificationsTool;
impl Tool for ProjectNotificationsTool {
fn name(&self) -> String {
"project_notifications".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn may_perform_edits(&self) -> bool {
false
}
fn description(&self) -> String {
include_str!("./project_notifications_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Envelope
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<ProjectUpdatesToolInput>(format)
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Check project notifications".into()
}
fn run(
self: Arc<Self>,
_input: serde_json::Value,
_request: Arc<LanguageModelRequest>,
_project: Entity<Project>,
action_log: Entity<ActionLog>,
_model: Arc<dyn LanguageModel>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let mut stale_files = String::new();
let action_log = action_log.read(cx);
for stale_file in action_log.stale_buffers(cx) {
if let Some(file) = stale_file.read(cx).file() {
writeln!(&mut stale_files, "- {}", file.path().display()).ok();
}
}
let response = if stale_files.is_empty() {
"No new notifications".to_string()
} else {
// NOTE: Changes to this prompt require a symmetric update in the LLM Worker
const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt");
format!("{HEADER}{stale_files}").replace("\r\n", "\n")
};
Task::ready(Ok(response.into())).into()
}
}
#[cfg(test)]
mod tests {
use super::*;
use assistant_tool::ToolResultContent;
use gpui::{AppContext, TestAppContext};
use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use std::sync::Arc;
use util::path;
#[gpui::test]
async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/test"),
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
)
.await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let buffer_path = project
.read_with(cx, |project, cx| {
project.find_project_path("test/code.rs", cx)
})
.unwrap();
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(buffer_path.clone(), cx)
})
.await
.unwrap();
// Start tracking the buffer
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx);
});
// Run the tool before any changes
let tool = Arc::new(ProjectNotificationsTool);
let provider = Arc::new(FakeLanguageModelProvider);
let model: Arc<dyn LanguageModel> = Arc::new(provider.test_model());
let request = Arc::new(LanguageModelRequest::default());
let tool_input = json!({});
let result = cx.update(|cx| {
tool.clone().run(
tool_input.clone(),
request.clone(),
project.clone(),
action_log.clone(),
model.clone(),
None,
cx,
)
});
let response = result.output.await.unwrap();
let response_text = match &response.content {
ToolResultContent::Text(text) => text.clone(),
_ => panic!("Expected text response"),
};
assert_eq!(
response_text.as_str(),
"No new notifications",
"Tool should return 'No new notifications' when no stale buffers"
);
// Modify the buffer (makes it stale)
buffer.update(cx, |buffer, cx| {
buffer.edit([(1..1, "\nChange!\n")], None, cx);
});
// Run the tool again
let result = cx.update(|cx| {
tool.run(
tool_input.clone(),
request.clone(),
project.clone(),
action_log,
model.clone(),
None,
cx,
)
});
// This time the buffer is stale, so the tool should return a notification
let response = result.output.await.unwrap();
let response_text = match &response.content {
ToolResultContent::Text(text) => text.clone(),
_ => panic!("Expected text response"),
};
let expected_content = "[The following is an auto-generated notification; do not reply]\n\nThese files have changed since the last read:\n- code.rs\n";
assert_eq!(
response_text.as_str(),
expected_content,
"Tool should return the stale buffer notification"
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
assistant_tool::init(cx);
});
}
}

View File

@@ -1,3 +0,0 @@
This tool reports which files have been modified by the user since the agent last accessed them.
It serves as a notification mechanism to inform the agent of recent changes. No immediate action is required in response to these updates.

View File

@@ -1,3 +0,0 @@
[The following is an auto-generated notification; do not reply]
These files have changed since the last read:

View File

@@ -1,9 +1,8 @@
use anyhow::Result;
use language_model::LanguageModelToolSchemaFormat;
use schemars::{
JsonSchema, Schema,
generate::SchemaSettings,
transform::{Transform, transform_subschemas},
JsonSchema,
schema::{RootSchema, Schema, SchemaObject},
};
pub fn json_schema_for<T: JsonSchema>(
@@ -14,7 +13,7 @@ pub fn json_schema_for<T: JsonSchema>(
}
fn schema_to_json(
schema: &Schema,
schema: &RootSchema,
format: LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut value = serde_json::to_value(schema)?;
@@ -22,40 +21,58 @@ fn schema_to_json(
Ok(value)
}
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> Schema {
fn root_schema_for<T: JsonSchema>(format: LanguageModelToolSchemaFormat) -> RootSchema {
let mut generator = match format {
LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;
})
.with_transform(ToJsonSchemaSubsetTransform)
.into_generator(),
LanguageModelToolSchemaFormat::JsonSchema => schemars::SchemaGenerator::default(),
LanguageModelToolSchemaFormat::JsonSchemaSubset => {
schemars::r#gen::SchemaSettings::default()
.with(|settings| {
settings.meta_schema = None;
settings.inline_subschemas = true;
settings
.visitors
.push(Box::new(TransformToJsonSchemaSubsetVisitor));
})
.into_generator()
}
};
generator.root_schema_for::<T>()
}
#[derive(Debug, Clone)]
struct ToJsonSchemaSubsetTransform;
struct TransformToJsonSchemaSubsetVisitor;
impl Transform for ToJsonSchemaSubsetTransform {
fn transform(&mut self, schema: &mut Schema) {
impl schemars::visit::Visitor for TransformToJsonSchemaSubsetVisitor {
fn visit_root_schema(&mut self, root: &mut RootSchema) {
schemars::visit::visit_root_schema(self, root)
}
fn visit_schema(&mut self, schema: &mut Schema) {
schemars::visit::visit_schema(self, schema)
}
fn visit_schema_object(&mut self, schema: &mut SchemaObject) {
// Ensure that the type field is not an array, this happens when we use
// Option<T>, the type will be [T, "null"].
if let Some(type_field) = schema.get_mut("type") {
if let Some(types) = type_field.as_array() {
if let Some(first_type) = types.first() {
*type_field = first_type.clone();
if let Some(instance_type) = schema.instance_type.take() {
schema.instance_type = match instance_type {
schemars::schema::SingleOrVec::Single(t) => {
Some(schemars::schema::SingleOrVec::Single(t))
}
schemars::schema::SingleOrVec::Vec(items) => items
.into_iter()
.next()
.map(schemars::schema::SingleOrVec::from),
};
}
// One of is not supported, use anyOf instead.
if let Some(subschema) = schema.subschemas.as_mut() {
if let Some(one_of) = subschema.one_of.take() {
subschema.any_of = Some(one_of);
}
}
// oneOf is not supported, use anyOf instead
if let Some(one_of) = schema.remove("oneOf") {
schema.insert("anyOf".to_string(), one_of);
}
transform_subschemas(self, schema);
schemars::visit::visit_schema_object(self, schema)
}
}

View File

@@ -218,7 +218,7 @@ impl Tool for TerminalTool {
.update(cx, |project, cx| {
project.create_terminal(
TerminalKind::Task(task::SpawnInTerminal {
command: Some(program),
command: program,
args,
cwd,
env,

View File

@@ -28,17 +28,7 @@ use workspace::Workspace;
const SHOULD_SHOW_UPDATE_NOTIFICATION_KEY: &str = "auto-updater-should-show-updated-notification";
const POLL_INTERVAL: Duration = Duration::from_secs(60 * 60);
actions!(
auto_update,
[
/// Checks for available updates.
Check,
/// Dismisses the update error message.
DismissErrorMessage,
/// Opens the release notes for the current version in a browser.
ViewReleaseNotes,
]
);
actions!(auto_update, [Check, DismissErrorMessage, ViewReleaseNotes,]);
#[derive(Serialize)]
struct UpdateRequestBody {

View File

@@ -12,13 +12,7 @@ use workspace::Workspace;
use workspace::notifications::simple_message_notification::MessageNotification;
use workspace::notifications::{NotificationId, show_app_notification};
actions!(
auto_update,
[
/// Opens the release notes for the current version in a new tab.
ViewReleaseNotesLocally
]
);
actions!(auto_update, [ViewReleaseNotesLocally]);
pub fn init(cx: &mut App) {
notify_if_app_was_updated(cx);

View File

@@ -25,4 +25,5 @@ serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
workspace-hack.workspace = true

View File

@@ -1,6 +1,9 @@
mod models;
use anyhow::{Context, Error, Result, anyhow};
use std::collections::HashMap;
use std::pin::Pin;
use anyhow::{Context as _, Error, Result, anyhow};
use aws_sdk_bedrockruntime as bedrock;
pub use aws_sdk_bedrockruntime as bedrock_client;
pub use aws_sdk_bedrockruntime::types::{
@@ -21,10 +24,9 @@ pub use bedrock::types::{
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,
};
use futures::stream::{self, BoxStream};
use futures::stream::{self, BoxStream, Stream};
use serde::{Deserialize, Serialize};
use serde_json::{Number, Value};
use std::collections::HashMap;
use thiserror::Error;
pub use crate::models::*;
@@ -32,59 +34,70 @@ pub use crate::models::*;
pub async fn stream_completion(
client: bedrock::Client,
request: Request,
handle: tokio::runtime::Handle,
) -> Result<BoxStream<'static, Result<BedrockStreamingResponse, BedrockError>>, Error> {
let mut response = bedrock::Client::converse_stream(&client)
.model_id(request.model.clone())
.set_messages(request.messages.into());
handle
.spawn(async move {
let mut response = bedrock::Client::converse_stream(&client)
.model_id(request.model.clone())
.set_messages(request.messages.into());
if let Some(Thinking::Enabled {
budget_tokens: Some(budget_tokens),
}) = request.thinking
{
let thinking_config = HashMap::from([
("type".to_string(), Document::String("enabled".to_string())),
(
"budget_tokens".to_string(),
Document::Number(AwsNumber::PosInt(budget_tokens)),
),
]);
response = response.additional_model_request_fields(Document::Object(HashMap::from([(
"thinking".to_string(),
Document::from(thinking_config),
)])));
}
if let Some(Thinking::Enabled {
budget_tokens: Some(budget_tokens),
}) = request.thinking
{
response =
response.additional_model_request_fields(Document::Object(HashMap::from([(
"thinking".to_string(),
Document::from(HashMap::from([
("type".to_string(), Document::String("enabled".to_string())),
(
"budget_tokens".to_string(),
Document::Number(AwsNumber::PosInt(budget_tokens)),
),
])),
)])));
}
if request
.tools
.as_ref()
.map_or(false, |t| !t.tools.is_empty())
{
response = response.set_tool_config(request.tools);
}
if request.tools.is_some() && !request.tools.as_ref().unwrap().tools.is_empty() {
response = response.set_tool_config(request.tools);
}
let output = response
.send()
.await
.context("Failed to send API request to Bedrock");
let response = response.send().await;
let stream = Box::pin(stream::unfold(
output?.stream,
move |mut stream| async move {
match stream.recv().await {
Ok(Some(output)) => Some((Ok(output), stream)),
Ok(None) => None,
Err(err) => Some((
Err(BedrockError::ClientError(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
))),
stream,
match response {
Ok(output) => {
let stream: Pin<
Box<
dyn Stream<Item = Result<BedrockStreamingResponse, BedrockError>>
+ Send,
>,
> = Box::pin(stream::unfold(output.stream, |mut stream| async move {
match stream.recv().await {
Ok(Some(output)) => Some(({ Ok(output) }, stream)),
Ok(None) => None,
Err(err) => {
Some((
// TODO: Figure out how we can capture Throttling Exceptions
Err(BedrockError::ClientError(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
))),
stream,
))
}
}
}));
Ok(stream)
}
Err(err) => Err(anyhow!(
"{:?}",
aws_sdk_bedrockruntime::error::DisplayErrorContext(err)
)),
}
},
));
Ok(stream)
})
.await
.context("spawning a task")?
}
pub fn aws_document_to_value(document: &Document) -> Value {

View File

@@ -29,7 +29,7 @@ client.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
gpui.workspace = true
language.workspace = true
log.workspace = true
postage.workspace = true

View File

@@ -12,6 +12,7 @@ pub struct CallSettings {
/// Configuration of voice calls in Zed.
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct CallSettingsContent {
/// Whether the microphone should be muted when joining a channel or a call.
///

View File

@@ -81,17 +81,7 @@ pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
actions!(
client,
[
/// Signs in to Zed account.
SignIn,
/// Signs out of Zed account.
SignOut,
/// Reconnects to the collaboration server.
Reconnect
]
);
actions!(client, [SignIn, SignOut, Reconnect]);
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
pub struct ClientSettingsContent {

View File

@@ -35,7 +35,6 @@ dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
futures.workspace = true
gpui = { workspace = true, features = ["screen-capture"] }
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true

View File

@@ -107,7 +107,7 @@ CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("proj
CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id");
CREATE TABLE "project_repositories" (
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
"project_id" INTEGER NOT NULL,
"abs_path" VARCHAR,
"id" INTEGER NOT NULL,
"entry_ids" VARCHAR,
@@ -124,7 +124,7 @@ CREATE TABLE "project_repositories" (
CREATE INDEX "index_project_repositories_on_project_id" ON "project_repositories" ("project_id");
CREATE TABLE "project_repository_statuses" (
"project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
"project_id" INTEGER NOT NULL,
"repository_id" INTEGER NOT NULL,
"repo_path" VARCHAR NOT NULL,
"status" INT8 NOT NULL,

View File

@@ -1,25 +0,0 @@
DELETE FROM project_repositories
WHERE project_id NOT IN (SELECT id FROM projects);
ALTER TABLE project_repositories
ADD CONSTRAINT fk_project_repositories_project_id
FOREIGN KEY (project_id)
REFERENCES projects (id)
ON DELETE CASCADE
NOT VALID;
ALTER TABLE project_repositories
VALIDATE CONSTRAINT fk_project_repositories_project_id;
DELETE FROM project_repository_statuses
WHERE project_id NOT IN (SELECT id FROM projects);
ALTER TABLE project_repository_statuses
ADD CONSTRAINT fk_project_repository_statuses_project_id
FOREIGN KEY (project_id)
REFERENCES projects (id)
ON DELETE CASCADE
NOT VALID;
ALTER TABLE project_repository_statuses
VALIDATE CONSTRAINT fk_project_repository_statuses_project_id;

View File

@@ -1404,9 +1404,6 @@ async fn sync_model_request_usage_with_stripe(
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
log::info!("Stripe usage sync: Starting");
let started_at = Utc::now();
let staff_users = app.db.get_staff_users().await?;
let staff_user_ids = staff_users
.iter()
@@ -1451,10 +1448,6 @@ async fn sync_model_request_usage_with_stripe(
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?;
let usage_meter_count = usage_meters.len();
log::info!("Stripe usage sync: Syncing {usage_meter_count} usage meters");
for (usage_meter, usage) in usage_meters {
maybe!(async {
let Some((billing_customer, billing_subscription)) =
@@ -1511,10 +1504,5 @@ async fn sync_model_request_usage_with_stripe(
.log_err();
}
log::info!(
"Stripe usage sync: Synced {usage_meter_count} usage meters in {:?}",
Utc::now() - started_at
);
Ok(())
}

View File

@@ -4,19 +4,20 @@ mod tables;
#[cfg(test)]
pub mod tests;
use crate::{Error, Result};
use crate::{Error, Result, executor::Executor};
use anyhow::{Context as _, anyhow};
use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use dashmap::DashMap;
use futures::StreamExt;
use project_repository_statuses::StatusKind;
use rand::{Rng, SeedableRng, prelude::StdRng};
use rpc::ExtensionProvides;
use rpc::{
ConnectionId, ExtensionMetadata,
proto::{self},
};
use sea_orm::{
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
entity::prelude::*,
@@ -32,6 +33,7 @@ use std::{
ops::{Deref, DerefMut},
rc::Rc,
sync::Arc,
time::Duration,
};
use time::PrimitiveDateTime;
use tokio::sync::{Mutex, OwnedMutexGuard};
@@ -56,7 +58,6 @@ pub use tables::*;
#[cfg(test)]
pub struct DatabaseTestOptions {
pub executor: gpui::BackgroundExecutor,
pub runtime: tokio::runtime::Runtime,
pub query_failure_probability: parking_lot::Mutex<f64>,
}
@@ -68,6 +69,8 @@ pub struct Database {
pool: DatabaseConnection,
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
projects: DashMap<ProjectId, Arc<Mutex<()>>>,
rng: Mutex<StdRng>,
executor: Executor,
notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
notification_kinds_by_name: HashMap<String, NotificationKindId>,
#[cfg(test)]
@@ -78,15 +81,17 @@ pub struct Database {
// separate files in the `queries` folder.
impl Database {
/// Connects to the database with the given options
pub async fn new(options: ConnectOptions) -> Result<Self> {
pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
sqlx::any::install_default_drivers();
Ok(Self {
options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
rooms: DashMap::with_capacity(16384),
projects: DashMap::with_capacity(16384),
rng: Mutex::new(StdRng::seed_from_u64(0)),
notification_kinds_by_id: HashMap::default(),
notification_kinds_by_name: HashMap::default(),
executor,
#[cfg(test)]
test_options: None,
})
@@ -102,13 +107,48 @@ impl Database {
self.projects.clear();
}
/// Transaction runs things in a transaction. If you want to call other methods
/// and pass the transaction around you need to reborrow the transaction at each
/// call site with: `&*tx`.
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let (tx, result) = self.with_transaction(&f).await?;
let mut i = 0;
loop {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(result) => match tx.commit().await.map_err(Into::into) {
Ok(()) => return Ok(result),
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
self.run(body).await
}
pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let (tx, result) = self.with_weak_transaction(&f).await?;
match result {
Ok(result) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(result),
@@ -134,28 +174,44 @@ impl Database {
Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
{
let body = async {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(Some((room_id, data))) => {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(Some(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
})),
Err(error) => Err(error),
let mut i = 0;
loop {
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(Some((room_id, data))) => {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(Some(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}));
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
}
Ok(None) => match tx.commit().await.map_err(Into::into) {
Ok(()) => return Ok(None),
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
Ok(None) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(None),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
}
i += 1;
}
};
@@ -173,26 +229,38 @@ impl Database {
{
let room_id = Database::room_id_for_project(self, project_id).await?;
let body = async {
let lock = if let Some(room_id) = room_id {
self.rooms.entry(room_id).or_default().clone()
} else {
self.projects.entry(project_id).or_default().clone()
};
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
let mut i = 0;
loop {
let lock = if let Some(room_id) = room_id {
self.rooms.entry(room_id).or_default().clone()
} else {
self.projects.entry(project_id).or_default().clone()
};
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
});
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
@@ -212,22 +280,34 @@ impl Database {
Fut: Send + Future<Output = Result<T>>,
{
let body = async {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
}),
Err(error) => Err(error),
},
Err(error) => {
tx.rollback().await?;
Err(error)
let mut i = 0;
loop {
let lock = self.rooms.entry(room_id).or_default().clone();
let _guard = lock.lock_owned().await;
let (tx, result) = self.with_transaction(&f).await?;
match result {
Ok(data) => match tx.commit().await.map_err(Into::into) {
Ok(()) => {
return Ok(TransactionGuard {
data,
_guard,
_not_send: PhantomData,
});
}
Err(error) => {
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
},
Err(error) => {
tx.rollback().await?;
if !self.retry_on_serialization_error(&error, i).await {
return Err(error);
}
}
}
i += 1;
}
};
@@ -235,6 +315,28 @@ impl Database {
}
async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
{
let tx = self
.pool
.begin_with_config(Some(IsolationLevel::Serializable), None)
.await?;
let mut tx = Arc::new(Some(tx));
let result = f(TransactionHandle(tx.clone())).await;
let tx = Arc::get_mut(&mut tx)
.and_then(|tx| tx.take())
.context("couldn't complete transaction because it's still in use")?;
Ok((tx, result))
}
async fn with_weak_transaction<F, Fut, T>(
&self,
f: &F,
) -> Result<(DatabaseTransaction, Result<T>)>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
@@ -259,13 +361,13 @@ impl Database {
{
#[cfg(test)]
{
use rand::prelude::*;
let test_options = self.test_options.as_ref().unwrap();
test_options.executor.simulate_random_delay().await;
let fail_probability = *test_options.query_failure_probability.lock();
if test_options.executor.rng().gen_bool(fail_probability) {
return Err(anyhow!("simulated query failure"))?;
if let Executor::Deterministic(executor) = &self.executor {
executor.simulate_random_delay().await;
let fail_probability = *test_options.query_failure_probability.lock();
if executor.rng().gen_bool(fail_probability) {
return Err(anyhow!("simulated query failure"))?;
}
}
test_options.runtime.block_on(future)
@@ -276,6 +378,46 @@ impl Database {
future.await
}
}
async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: usize) -> bool {
// If the error is due to a failure to serialize concurrent transactions, then retry
// this transaction after a delay. With each subsequent retry, double the delay duration.
// Also vary the delay randomly in order to ensure different database connections retry
// at different times.
const SLEEPS: [f32; 10] = [10., 20., 40., 80., 160., 320., 640., 1280., 2560., 5120.];
if is_serialization_error(error) && prev_attempt_count < SLEEPS.len() {
let base_delay = SLEEPS[prev_attempt_count];
let randomized_delay = base_delay * self.rng.lock().await.gen_range(0.5..=2.0);
log::warn!(
"retrying transaction after serialization error. delay: {} ms.",
randomized_delay
);
self.executor
.sleep(Duration::from_millis(randomized_delay as u64))
.await;
true
} else {
false
}
}
}
fn is_serialization_error(error: &Error) -> bool {
const SERIALIZATION_FAILURE_CODE: &str = "40001";
match error {
Error::Database(
DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
| DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
) if error
.as_database_error()
.and_then(|error| error.code())
.as_deref()
== Some(SERIALIZATION_FAILURE_CODE) =>
{
true
}
_ => false,
}
}
/// A handle to a [`DatabaseTransaction`].

View File

@@ -20,7 +20,7 @@ impl Database {
&self,
params: &CreateBillingCustomerParams,
) -> Result<billing_customer::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let customer = billing_customer::Entity::insert(billing_customer::ActiveModel {
user_id: ActiveValue::set(params.user_id),
stripe_customer_id: ActiveValue::set(params.stripe_customer_id.clone()),
@@ -40,7 +40,7 @@ impl Database {
id: BillingCustomerId,
params: &UpdateBillingCustomerParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_customer::Entity::update(billing_customer::ActiveModel {
id: ActiveValue::set(id),
user_id: params.user_id.clone(),
@@ -61,7 +61,7 @@ impl Database {
&self,
id: BillingCustomerId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::Id.eq(id))
.one(&*tx)
@@ -75,7 +75,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::UserId.eq(user_id))
.one(&*tx)
@@ -89,7 +89,7 @@ impl Database {
&self,
stripe_customer_id: &str,
) -> Result<Option<billing_customer::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_customer::Entity::find()
.filter(billing_customer::Column::StripeCustomerId.eq(stripe_customer_id))
.one(&*tx)

View File

@@ -22,7 +22,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_preference::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_preference::Entity::find()
.filter(billing_preference::Column::UserId.eq(user_id))
.one(&*tx)
@@ -37,7 +37,7 @@ impl Database {
user_id: UserId,
params: &CreateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
user_id: ActiveValue::set(user_id),
max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
@@ -65,7 +65,7 @@ impl Database {
user_id: UserId,
params: &UpdateBillingPreferencesParams,
) -> Result<billing_preference::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let preferences = billing_preference::Entity::update_many()
.set(billing_preference::ActiveModel {
max_monthly_llm_usage_spending_in_cents: params

View File

@@ -35,7 +35,7 @@ impl Database {
&self,
params: &CreateBillingSubscriptionParams,
) -> Result<billing_subscription::Model> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
billing_customer_id: ActiveValue::set(params.billing_customer_id),
kind: ActiveValue::set(params.kind),
@@ -64,7 +64,7 @@ impl Database {
id: BillingSubscriptionId,
params: &UpdateBillingSubscriptionParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
billing_subscription::Entity::update(billing_subscription::ActiveModel {
id: ActiveValue::set(id),
billing_customer_id: params.billing_customer_id.clone(),
@@ -90,7 +90,7 @@ impl Database {
&self,
id: BillingSubscriptionId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find_by_id(id)
.one(&*tx)
.await?)
@@ -103,7 +103,7 @@ impl Database {
&self,
stripe_subscription_id: &str,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.filter(
billing_subscription::Column::StripeSubscriptionId.eq(stripe_subscription_id),
@@ -118,7 +118,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -152,7 +152,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Vec<billing_subscription::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let subscriptions = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(billing_customer::Column::UserId.eq(user_id))
@@ -169,7 +169,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -201,7 +201,7 @@ impl Database {
&self,
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
self.weak_transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
@@ -236,7 +236,7 @@ impl Database {
/// Returns the count of the active billing subscriptions for the user with the specified ID.
pub async fn count_active_billing_subscriptions(&self, user_id: UserId) -> Result<usize> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let count = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.filter(

View File

@@ -501,8 +501,10 @@ impl Database {
/// Returns all channels for the user with the given ID.
pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
self.transaction(|tx| async move { self.get_user_channels(user_id, None, true, &tx).await })
.await
self.weak_transaction(
|tx| async move { self.get_user_channels(user_id, None, true, &tx).await },
)
.await
}
/// Returns all channels for the user with the given ID that are descendants

View File

@@ -15,7 +15,7 @@ impl Database {
user_b_busy: bool,
}
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let user_a_participant = Alias::new("user_a_participant");
let user_b_participant = Alias::new("user_b_participant");
let mut db_contacts = contact::Entity::find()
@@ -91,7 +91,7 @@ impl Database {
/// Returns whether the given user is a busy (on a call).
pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let participant = room_participant::Entity::find()
.filter(room_participant::Column::UserId.eq(user_id))
.one(&*tx)

View File

@@ -9,7 +9,7 @@ pub enum ContributorSelector {
impl Database {
/// Retrieves the GitHub logins of all users who have signed the CLA.
pub async fn get_contributors(&self) -> Result<Vec<String>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryGithubLogin {
GithubLogin,
@@ -32,7 +32,7 @@ impl Database {
&self,
selector: &ContributorSelector,
) -> Result<Option<DateTime>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = match selector {
ContributorSelector::GitHubUserId { github_user_id } => {
user::Column::GithubUserId.eq(*github_user_id)
@@ -69,7 +69,7 @@ impl Database {
github_user_created_at: DateTimeUtc,
initial_channel_id: Option<ChannelId>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let user = self
.update_or_create_user_by_github_account_tx(
github_login,

View File

@@ -8,7 +8,7 @@ impl Database {
model: &str,
digests: &[Vec<u8>],
) -> Result<HashMap<Vec<u8>, Vec<f32>>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let embeddings = {
let mut db_embeddings = embedding::Entity::find()
.filter(
@@ -52,7 +52,7 @@ impl Database {
model: &str,
embeddings: &HashMap<Vec<u8>, Vec<f32>>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| {
let now_offset_datetime = OffsetDateTime::now_utc();
let retrieved_at =
@@ -78,7 +78,7 @@ impl Database {
}
pub async fn purge_old_embeddings(&self) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
embedding::Entity::delete_many()
.filter(
embedding::Column::RetrievedAt

View File

@@ -15,7 +15,7 @@ impl Database {
max_schema_version: i32,
limit: usize,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut condition = Condition::all()
.add(
extension::Column::LatestVersion
@@ -43,7 +43,7 @@ impl Database {
ids: &[&str],
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extensions = extension::Entity::find()
.filter(extension::Column::ExternalId.is_in(ids.iter().copied()))
.all(&*tx)
@@ -123,7 +123,7 @@ impl Database {
&self,
extension_id: &str,
) -> Result<Vec<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let condition = extension::Column::ExternalId
.eq(extension_id)
.into_condition();
@@ -162,7 +162,7 @@ impl Database {
extension_id: &str,
constraints: Option<&ExtensionVersionConstraints>,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.one(&*tx)
@@ -187,7 +187,7 @@ impl Database {
extension_id: &str,
version: &str,
) -> Result<Option<ExtensionMetadata>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let extension = extension::Entity::find()
.filter(extension::Column::ExternalId.eq(extension_id))
.filter(extension_version::Column::Version.eq(version))
@@ -204,7 +204,7 @@ impl Database {
}
pub async fn get_known_extension_versions(&self) -> Result<HashMap<String, Vec<String>>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let mut extension_external_ids_by_id = HashMap::default();
let mut rows = extension::Entity::find().stream(&*tx).await?;
@@ -242,7 +242,7 @@ impl Database {
&self,
versions_by_extension_id: &HashMap<&str, Vec<NewExtensionVersion>>,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
for (external_id, versions) in versions_by_extension_id {
if versions.is_empty() {
continue;
@@ -349,7 +349,7 @@ impl Database {
}
pub async fn record_extension_download(&self, extension: &str, version: &str) -> Result<bool> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryId {
Id,

View File

@@ -13,7 +13,7 @@ impl Database {
&self,
params: &CreateProcessedStripeEventParams,
) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
processed_stripe_event::Entity::insert(processed_stripe_event::ActiveModel {
stripe_event_id: ActiveValue::set(params.stripe_event_id.clone()),
stripe_event_type: ActiveValue::set(params.stripe_event_type.clone()),
@@ -35,7 +35,7 @@ impl Database {
&self,
event_id: &str,
) -> Result<Option<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find_by_id(event_id)
.one(&*tx)
.await?)
@@ -48,7 +48,7 @@ impl Database {
&self,
event_ids: &[&str],
) -> Result<Vec<processed_stripe_event::Model>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
Ok(processed_stripe_event::Entity::find()
.filter(
processed_stripe_event::Column::StripeEventId.is_in(event_ids.iter().copied()),

View File

@@ -112,7 +112,7 @@ impl Database {
}
pub async fn delete_project(&self, project_id: ProjectId) -> Result<()> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
project::Entity::delete_by_id(project_id).exec(&*tx).await?;
Ok(())
})

View File

@@ -80,7 +80,7 @@ impl Database {
&self,
user_id: UserId,
) -> Result<Option<proto::IncomingCall>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
let pending_participant = room_participant::Entity::find()
.filter(
room_participant::Column::UserId

View File

@@ -142,50 +142,6 @@ impl Database {
}
}
loop {
let delete_query = Query::delete()
.from_table(project_repository_statuses::Entity)
.and_where(
Expr::tuple([Expr::col((
project_repository_statuses::Entity,
project_repository_statuses::Column::ProjectId,
))
.into()])
.in_subquery(
Query::select()
.columns([(
project_repository_statuses::Entity,
project_repository_statuses::Column::ProjectId,
)])
.from(project_repository_statuses::Entity)
.inner_join(
project::Entity,
Expr::col((project::Entity, project::Column::Id)).equals((
project_repository_statuses::Entity,
project_repository_statuses::Column::ProjectId,
)),
)
.and_where(project::Column::HostConnectionServerId.ne(server_id))
.limit(10000)
.to_owned(),
),
)
.to_owned();
let statement = Statement::from_sql_and_values(
tx.get_database_backend(),
delete_query
.to_string(sea_orm::sea_query::PostgresQueryBuilder)
.as_str(),
vec![],
);
let result = tx.execute(statement).await?;
if result.rows_affected() == 0 {
break;
}
}
Ok(())
})
.await

View File

@@ -382,7 +382,7 @@ impl Database {
/// Returns the active flags for the user.
pub async fn get_user_flags(&self, user: UserId) -> Result<Vec<String>> {
self.transaction(|tx| async move {
self.weak_transaction(|tx| async move {
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryAs {
Flag,

View File

@@ -17,15 +17,11 @@ use crate::migrations::run_database_migrations;
use super::*;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use rand::prelude::*;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::{
sync::{
Arc,
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
},
time::Duration,
use std::sync::{
Arc,
atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
};
pub struct TestDb {
@@ -45,7 +41,9 @@ impl TestDb {
let mut db = runtime.block_on(async {
let mut options = ConnectOptions::new(url);
options.max_connections(5);
let mut db = Database::new(options).await.unwrap();
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
.await
.unwrap();
let sql = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations.sqlite/20221109000000_test_schema.sql"
@@ -62,7 +60,6 @@ impl TestDb {
});
db.test_options = Some(DatabaseTestOptions {
executor,
runtime,
query_failure_probability: parking_lot::Mutex::new(0.0),
});
@@ -96,7 +93,9 @@ impl TestDb {
options
.max_connections(5)
.idle_timeout(Duration::from_secs(0));
let mut db = Database::new(options).await.unwrap();
let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
run_database_migrations(db.options(), migrations_path)
.await
@@ -106,7 +105,6 @@ impl TestDb {
});
db.test_options = Some(DatabaseTestOptions {
executor,
runtime,
query_failure_probability: parking_lot::Mutex::new(0.0),
});

View File

@@ -49,7 +49,7 @@ async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
db.save_embeddings(model, &embeddings).await.unwrap();
// Reach into the DB and change the retrieved at to be > 60 days
db.transaction(|tx| {
db.weak_transaction(|tx| {
let digest = digest.clone();
async move {
let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));

View File

@@ -285,7 +285,7 @@ impl AppState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let mut db_options = db::ConnectOptions::new(config.database_url.clone());
db_options.max_connections(config.database_max_connections);
let mut db = Database::new(db_options).await?;
let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config

View File

@@ -59,7 +59,7 @@ async fn main() -> Result<()> {
let config = envy::from_env::<Config>().expect("error loading config");
let db_options = db::ConnectOptions::new(config.database_url.clone());
let mut db = Database::new(db_options).await?;
let mut db = Database::new(db_options, Executor::Production).await?;
db.initialize_notification_kinds().await?;
collab::seed::seed(&config, &db, false).await?;
@@ -253,7 +253,7 @@ async fn main() -> Result<()> {
async fn setup_app_database(config: &Config) -> Result<()> {
let db_options = db::ConnectOptions::new(config.database_url.clone());
let mut db = Database::new(db_options).await?;
let mut db = Database::new(db_options, Executor::Production).await?;
let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| {
#[cfg(feature = "sqlite")]

View File

@@ -4591,13 +4591,14 @@ async fn test_formatting_buffer(
cx_a.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
Formatter::External {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
vec![Formatter::External {
command: "awk".into(),
arguments: Some(
vec!["{sub(/two/,\"{buffer_path}\")}1".to_string()].into(),
),
},
}]
.into(),
)));
});
});
@@ -4698,8 +4699,8 @@ async fn test_prettier_formatting_buffer(
cx_b.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
Formatter::LanguageServer { name: None },
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
vec![Formatter::LanguageServer { name: None }].into(),
)));
file.defaults.prettier = Some(PrettierSettings {
allowed: true,
@@ -4821,7 +4822,7 @@ async fn test_definition(
);
let definitions_1 = project_b
.update(cx_b, |p, cx| p.definitions(&buffer_b, 23, cx))
.update(cx_b, |p, cx| p.definition(&buffer_b, 23, cx))
.await
.unwrap();
cx_b.read(|cx| {
@@ -4852,7 +4853,7 @@ async fn test_definition(
);
let definitions_2 = project_b
.update(cx_b, |p, cx| p.definitions(&buffer_b, 33, cx))
.update(cx_b, |p, cx| p.definition(&buffer_b, 33, cx))
.await
.unwrap();
cx_b.read(|cx| {
@@ -4889,7 +4890,7 @@ async fn test_definition(
);
let type_definitions = project_b
.update(cx_b, |p, cx| p.type_definitions(&buffer_b, 7, cx))
.update(cx_b, |p, cx| p.type_definition(&buffer_b, 7, cx))
.await
.unwrap();
cx_b.read(|cx| {
@@ -5057,7 +5058,7 @@ async fn test_references(
lsp_response_tx
.unbounded_send(Err(anyhow!("can't find references")))
.unwrap();
assert_eq!(references.await.unwrap(), []);
references.await.unwrap_err();
// User is informed that the request is no longer pending.
executor.run_until_parked();
@@ -5641,7 +5642,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it(
let definitions;
let buffer_b2;
if rng.r#gen() {
definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx));
definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
(buffer_b2, _) = project_b
.update(cx_b, |p, cx| {
p.open_buffer_with_lsp((worktree_id, "b.rs"), cx)
@@ -5655,7 +5656,7 @@ async fn test_open_buffer_while_getting_definition_pointing_to_it(
})
.await
.unwrap();
definitions = project_b.update(cx_b, |p, cx| p.definitions(&buffer_b1, 23, cx));
definitions = project_b.update(cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
}
let definitions = definitions.await.unwrap();

View File

@@ -838,7 +838,7 @@ impl RandomizedTest for ProjectCollaborationTest {
.map(|_| Ok(()))
.boxed(),
LspRequestKind::Definition => project
.definitions(&buffer, offset, cx)
.definition(&buffer, offset, cx)
.map_ok(|_| ())
.boxed(),
LspRequestKind::Highlights => project

View File

@@ -505,8 +505,8 @@ async fn test_ssh_collaboration_formatting_with_prettier(
cx_b.update(|cx| {
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |file| {
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList::Single(
Formatter::LanguageServer { name: None },
file.defaults.formatter = Some(SelectedFormatter::List(FormatterList(
vec![Formatter::LanguageServer { name: None }].into(),
)));
file.defaults.prettier = Some(PrettierSettings {
allowed: true,

View File

@@ -30,13 +30,7 @@ use workspace::{
};
use workspace::{item::Dedup, notifications::NotificationId};
actions!(
collab,
[
/// Copies a link to the current position in the channel buffer.
CopyLink
]
);
actions!(collab, [CopyLink]);
pub fn init(cx: &mut App) {
workspace::FollowableViewRegistry::register::<ChannelView>(cx)

View File

@@ -71,13 +71,7 @@ struct SerializedChatPanel {
width: Option<Pixels>,
}
actions!(
chat_panel,
[
/// Toggles focus on the chat panel.
ToggleFocus
]
);
actions!(chat_panel, [ToggleFocus]);
impl ChatPanel {
pub fn new(

View File

@@ -44,25 +44,15 @@ use workspace::{
actions!(
collab_panel,
[
/// Toggles focus on the collaboration panel.
ToggleFocus,
/// Removes the selected channel or contact.
Remove,
/// Opens the context menu for the selected item.
Secondary,
/// Collapses the selected channel in the tree view.
CollapseSelectedChannel,
/// Expands the selected channel in the tree view.
ExpandSelectedChannel,
/// Starts moving a channel to a new location.
StartMoveChannel,
/// Moves the selected item to the current location.
MoveSelected,
/// Inserts a space character in the filter input.
InsertSpace,
/// Moves the selected channel up in the list.
MoveChannelUp,
/// Moves the selected channel down in the list.
MoveChannelDown,
]
);

View File

@@ -17,13 +17,9 @@ use workspace::{ModalView, notifications::DetachAndPromptErr};
actions!(
channel_modal,
[
/// Selects the next control in the channel modal.
SelectNextControl,
/// Toggles between invite members and manage members mode.
ToggleMode,
/// Toggles admin status for the selected member.
ToggleMemberAdmin,
/// Removes the selected member from the channel.
RemoveMember
]
);

View File

@@ -74,13 +74,7 @@ pub struct NotificationPresenter {
pub can_navigate: bool,
}
actions!(
notification_panel,
[
/// Toggles focus on the notification panel.
ToggleFocus
]
);
actions!(notification_panel, [ToggleFocus]);
pub fn init(cx: &mut App) {
cx.observe_new(|workspace: &mut Workspace, _, _| {

View File

@@ -28,6 +28,7 @@ pub struct ChatPanelSettings {
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct ChatPanelSettingsContent {
/// When to show the panel button in the status bar.
///
@@ -51,6 +52,7 @@ pub struct NotificationPanelSettings {
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct PanelSettingsContent {
/// Whether to show the panel button in the status bar.
///
@@ -67,6 +69,7 @@ pub struct PanelSettingsContent {
}
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
#[schemars(deny_unknown_fields)]
pub struct MessageEditorSettings {
/// Whether to automatically replace emoji shortcodes with emoji characters.
/// For example: typing `:wave:` gets replaced with `👋`.

View File

@@ -41,7 +41,7 @@ pub struct CommandPalette {
/// Removes subsequent whitespace characters and double colons from the query.
///
/// This improves the likelihood of a match by either humanized name or keymap-style name.
pub fn normalize_action_query(input: &str) -> String {
fn normalize_query(input: &str) -> String {
let mut result = String::with_capacity(input.len());
let mut last_char = None;
@@ -297,7 +297,7 @@ impl PickerDelegate for CommandPaletteDelegate {
let mut commands = self.all_commands.clone();
let hit_counts = self.hit_counts();
let executor = cx.background_executor().clone();
let query = normalize_action_query(query.as_str());
let query = normalize_query(query.as_str());
async move {
commands.sort_by_key(|action| {
(
@@ -311,17 +311,29 @@ impl PickerDelegate for CommandPaletteDelegate {
.enumerate()
.map(|(ix, command)| StringMatchCandidate::new(ix, &command.name))
.collect::<Vec<_>>();
let matches = fuzzy::match_strings(
&candidates,
&query,
true,
true,
10000,
&Default::default(),
executor,
)
.await;
let matches = if query.is_empty() {
candidates
.into_iter()
.enumerate()
.map(|(index, candidate)| StringMatch {
candidate_id: index,
string: candidate.string,
positions: Vec::new(),
score: 0.0,
})
.collect()
} else {
fuzzy::match_strings(
&candidates,
&query,
true,
true,
10000,
&Default::default(),
executor,
)
.await
};
tx.send((commands, matches)).await.log_err();
}
@@ -410,8 +422,8 @@ impl PickerDelegate for CommandPaletteDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let matching_command = self.matches.get(ix)?;
let command = self.commands.get(matching_command.candidate_id)?;
let r#match = self.matches.get(ix)?;
let command = self.commands.get(r#match.candidate_id)?;
Some(
ListItem::new(ix)
.inset(true)
@@ -424,7 +436,7 @@ impl PickerDelegate for CommandPaletteDelegate {
.justify_between()
.child(HighlightedLabel::new(
command.name.clone(),
matching_command.positions.clone(),
r#match.positions.clone(),
))
.children(KeyBinding::for_action_in(
&*command.action,
@@ -500,28 +512,19 @@ mod tests {
#[test]
fn test_normalize_query() {
assert_eq!(normalize_query("editor: backspace"), "editor: backspace");
assert_eq!(normalize_query("editor: backspace"), "editor: backspace");
assert_eq!(normalize_query("editor: backspace"), "editor: backspace");
assert_eq!(
normalize_action_query("editor: backspace"),
"editor: backspace"
);
assert_eq!(
normalize_action_query("editor: backspace"),
"editor: backspace"
);
assert_eq!(
normalize_action_query("editor: backspace"),
"editor: backspace"
);
assert_eq!(
normalize_action_query("editor::GoToDefinition"),
normalize_query("editor::GoToDefinition"),
"editor:GoToDefinition"
);
assert_eq!(
normalize_action_query("editor::::GoToDefinition"),
normalize_query("editor::::GoToDefinition"),
"editor:GoToDefinition"
);
assert_eq!(
normalize_action_query("editor: :GoToDefinition"),
normalize_query("editor: :GoToDefinition"),
"editor: :GoToDefinition"
);
}

View File

@@ -61,7 +61,7 @@ impl RenderOnce for ComponentExample {
12.0,
12.0,
))
.shadow_xs()
.shadow_sm()
.child(self.element),
)
.into_any_element()

View File

@@ -46,17 +46,11 @@ pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_an
actions!(
copilot,
[
/// Requests a code completion suggestion from Copilot.
Suggest,
/// Cycles to the next Copilot suggestion.
NextSuggestion,
/// Cycles to the previous Copilot suggestion.
PreviousSuggestion,
/// Reinstalls the Copilot language server.
Reinstall,
/// Signs in to GitHub Copilot.
SignIn,
/// Signs out of GitHub Copilot.
SignOut
]
);

View File

@@ -528,7 +528,6 @@ impl CopilotChat {
pub async fn stream_completion(
request: Request,
is_user_initiated: bool,
mut cx: AsyncApp,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let this = cx
@@ -563,14 +562,7 @@ impl CopilotChat {
};
let api_url = configuration.api_url_from_endpoint(&token.api_endpoint);
stream_completion(
client.clone(),
token.api_key,
api_url.into(),
request,
is_user_initiated,
)
.await
stream_completion(client.clone(), token.api_key, api_url.into(), request).await
}
pub fn set_configuration(
@@ -705,7 +697,6 @@ async fn stream_completion(
api_key: String,
completion_url: Arc<str>,
request: Request,
is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
let is_vision_request = request.messages.iter().any(|message| match message {
ChatMessage::User { content }
@@ -716,8 +707,6 @@ async fn stream_completion(
_ => false,
});
let request_initiator = if is_user_initiated { "user" } else { "agent" };
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(completion_url.as_ref())
@@ -730,8 +719,7 @@ async fn stream_completion(
)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.header("Copilot-Integration-Id", "vscode-chat")
.header("X-Initiator", request_initiator);
.header("Copilot-Integration-Id", "vscode-chat");
if is_vision_request {
request_builder =

View File

@@ -38,6 +38,7 @@ language.workspace = true
log.workspace = true
node_runtime.workspace = true
parking_lot.workspace = true
paths.workspace = true
proto.workspace = true
schemars.workspace = true
serde.workspace = true

View File

@@ -10,12 +10,11 @@ use gpui::{AsyncApp, SharedString};
pub use http_client::{HttpClient, github::latest_github_release};
use language::{LanguageName, LanguageToolchainStore};
use node_runtime::NodeRuntime;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::WorktreeId;
use smol::fs::File;
use std::{
borrow::{Borrow, Cow},
borrow::Borrow,
ffi::OsStr,
fmt::Debug,
net::Ipv4Addr,
@@ -48,10 +47,7 @@ pub trait DapDelegate: Send + Sync + 'static {
async fn shell_env(&self) -> collections::HashMap<String, String>;
}
#[derive(
Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize, JsonSchema,
)]
#[serde(transparent)]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
pub struct DebugAdapterName(pub SharedString);
impl Deref for DebugAdapterName {
@@ -265,13 +261,12 @@ pub struct GithubRepo {
}
pub async fn download_adapter_from_github(
adapter_name: &str,
adapter_name: DebugAdapterName,
github_version: AdapterVersion,
file_type: DownloadedFileType,
base_path: &Path,
delegate: &dyn DapDelegate,
) -> Result<PathBuf> {
let adapter_path = base_path.join(adapter_name);
let adapter_path = paths::debug_adapters_dir().join(&adapter_name.as_ref());
let version_path = adapter_path.join(format!("{}_{}", adapter_name, github_version.tag_name));
let fs = delegate.fs();
@@ -285,7 +280,11 @@ pub async fn download_adapter_from_github(
.context("Failed creating adapter path")?;
}
log::debug!("Downloading {} from {}", adapter_name, &github_version.url);
log::debug!(
"Downloading adapter {} from {}",
adapter_name,
&github_version.url,
);
delegate.output_to_console(format!("Downloading from {}...", github_version.url));
let mut response = delegate
@@ -370,7 +369,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
}
}
fn dap_schema(&self) -> Cow<'static, serde_json::Value>;
fn dap_schema(&self) -> serde_json::Value;
fn label_for_child_session(&self, _args: &StartDebuggingRequestArguments) -> Option<String> {
None
@@ -396,8 +395,8 @@ impl DebugAdapter for FakeAdapter {
DebugAdapterName(Self::ADAPTER_NAME.into())
}
fn dap_schema(&self) -> Cow<'static, serde_json::Value> {
Cow::Owned(serde_json::Value::Null)
fn dap_schema(&self) -> serde_json::Value {
serde_json::Value::Null
}
async fn request_kind(

View File

@@ -4,10 +4,12 @@ use collections::FxHashMap;
use gpui::{App, Global, SharedString};
use language::LanguageName;
use parking_lot::RwLock;
use task::{DebugRequest, DebugScenario, SpawnInTerminal, TaskTemplate};
use task::{
AdapterSchema, AdapterSchemas, DebugRequest, DebugScenario, SpawnInTerminal, TaskTemplate,
};
use crate::adapters::{DebugAdapter, DebugAdapterName};
use std::{borrow::Cow, collections::BTreeMap, sync::Arc};
use std::{collections::BTreeMap, sync::Arc};
/// Given a user build configuration, locator creates a fill-in debug target ([DebugScenario]) on behalf of the user.
#[async_trait]
@@ -61,13 +63,19 @@ impl DapRegistry {
.and_then(|adapter| adapter.adapter_language_name())
}
pub async fn adapter_schemas(&self) -> Vec<(SharedString, Cow<'static, serde_json::Value>)> {
self.0
.read()
.adapters
.iter()
.map(|(name, adapter)| (name.0.clone(), adapter.dap_schema()))
.collect()
pub async fn adapters_schema(&self) -> task::AdapterSchemas {
let mut schemas = AdapterSchemas(vec![]);
let adapters = self.0.read().adapters.clone();
for (name, adapter) in adapters.into_iter() {
schemas.0.push(AdapterSchema {
adapter: name.into(),
schema: adapter.dap_schema(),
});
}
schemas
}
pub fn locators(&self) -> FxHashMap<SharedString, Arc<dyn DapLocator>> {

View File

@@ -12,12 +12,6 @@ test-support = [
"task/test-support",
"util/test-support",
]
update-schemas = [
"dep:node_runtime",
"dep:reqwest_client",
"dep:settings",
"dep:tempfile",
]
[lints]
workspace = true
@@ -26,19 +20,12 @@ workspace = true
path = "src/dap_adapters.rs"
doctest = false
[[bin]]
name = "update-schemas"
path = "src/update_schemas.rs"
required-features = ["update-schemas"]
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
collections.workspace = true
dap.workspace = true
dotenvy.workspace = true
futures.workspace = true
fs.workspace = true
gpui.workspace = true
json_dotpath.workspace = true
language.workspace = true
@@ -51,11 +38,6 @@ task.workspace = true
util.workspace = true
workspace-hack.workspace = true
node_runtime = { workspace = true, optional = true }
reqwest_client = { workspace = true, optional = true }
settings = { workspace = true, optional = true }
tempfile = { workspace = true, optional = true }
[dev-dependencies]
dap = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -1,253 +0,0 @@
{
"allOf": [
{
"if": { "properties": { "request": { "const": "launch" } }, "required": ["request"] },
"then": {
"properties": {
"program": { "description": "Path to the program to debug.", "type": "string" },
"cargo": {
"description": "Cargo invocation parameters.",
"type": "object",
"properties": {
"args": { "description": "Cargo command line arguments.", "type": "array", "default": [] },
"env": {
"description": "Additional environment variables passed to Cargo.",
"type": "object",
"patternProperties": { ".*": { "type": "string" } },
"default": {}
},
"cwd": { "description": "Cargo working directory.", "type": "string" },
"problemMatcher": {
"description": "Problem matcher(s) to apply to Cargo output.",
"type": ["string", "array"]
},
"filter": {
"description": "Filter applied to compilation artifacts.",
"type": "object",
"properties": { "name": { "type": "string" }, "kind": { "type": "string" } }
}
},
"required": ["args"],
"additionalProperties": false,
"defaultSnippets": [
{
"label": "Library unit tests",
"body": { "args": ["test", "--no-run", "--lib"], "filter": { "kind": "lib" } }
},
{ "label": "Executable", "body": { "args": ["build", "--bin=${1:<name>}"], "filter": { "kind": "bin" } } }
]
},
"args": { "description": "Program arguments.", "type": ["array", "string"] },
"cwd": { "description": "Program working directory.", "type": "string" },
"env": {
"description": "Additional environment variables.",
"type": "object",
"patternProperties": { ".*": { "type": "string" } }
},
"envFile": { "description": "File to read the environment variables from.", "type": "string" },
"stdio": {
"description": "Destination for stdio streams: null = send to debugger console or a terminal, \"<path>\" = attach to a file/tty/fifo.",
"type": ["null", "string", "array", "object"],
"default": null
},
"terminal": {
"description": "Terminal type to use.",
"type": "string",
"enum": ["integrated", "external", "console"],
"enumDescriptions": [
"Use integrated terminal in VSCode.",
"Use external terminal window.",
"Use VScode Debug Console for stdout and stderr. Stdin will be unavailable."
],
"default": "integrated"
},
"console": {
"description": "Terminal type to use. (This setting is a compatibility alias of 'terminal'.)",
"type": "string",
"enum": ["integratedTerminal", "externalTerminal", "internalConsole"],
"enumDescriptions": [
"Use integrated terminal in VSCode.",
"Use external terminal window.",
"Use VScode Debug Console for stdout and stderr. Stdin will be unavailable."
]
},
"stopOnEntry": {
"description": "Automatically stop debuggee after launch.",
"type": "boolean",
"default": false
},
"initCommands": {
"description": "Initialization commands executed upon debugger startup.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"targetCreateCommands": {
"description": "Commands that create the debug target.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"preRunCommands": {
"description": "Commands executed just before the program is launched.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"processCreateCommands": {
"description": "Commands that create the debuggee process.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"postRunCommands": {
"description": "Commands executed just after the program has been launched.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"preTerminateCommands": {
"description": "Commands executed just before the debuggee is terminated or disconnected from.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"exitCommands": {
"description": "Commands executed at the end of debugging session.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"expressions": {
"description": "The default evaluator type used for expressions.",
"type": "string",
"enum": ["simple", "python", "native"]
},
"sourceMap": {
"description": "Source path remapping between the build machine and the local machine. Each item is a pair of remote and local path prefixes.",
"type": "object",
"patternProperties": { ".*": { "type": ["string", "null"] } },
"default": {}
},
"relativePathBase": {
"description": "Base directory used for resolution of relative source paths. Defaults to \"${ZED_WORKTREE_ROOT}\".",
"type": "string"
},
"sourceLanguages": {
"description": "A list of source languages to enable language-specific features for.",
"type": "array",
"default": []
},
"reverseDebugging": {
"description": "Enable reverse debugging (Requires reverse execution support in the debug server, see User's Manual for details).",
"type": "boolean",
"default": false
},
"breakpointMode": {
"description": "Specifies how source breakpoints should be set.",
"type": "string",
"enum": ["path", "file"],
"enumDescriptions": [
"Resolve locations using full source file path.",
"Resolve locations using file name only."
]
}
},
"anyOf": [{ "required": ["program"] }, { "required": ["targetCreateCommands"] }, { "required": ["cargo"] }]
}
},
{
"if": { "properties": { "request": { "const": "attach" } }, "required": ["request"] },
"then": {
"properties": {
"program": { "description": "Path to the program to attach to.", "type": "string" },
"pid": {
"description": "Process id to attach to.",
"type": ["integer", "string"],
"default": "${command:pickMyProcess}"
},
"stopOnEntry": {
"description": "Automatically stop debuggee after attach.",
"type": "boolean",
"default": false
},
"waitFor": {
"description": "Wait for the process to launch (MacOS only).",
"type": "boolean",
"default": false
},
"initCommands": {
"description": "Initialization commands executed upon debugger startup.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"targetCreateCommands": {
"description": "Commands that create the debug target.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"preRunCommands": {
"description": "Commands executed just before the program is launched.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"processCreateCommands": {
"description": "Commands that create the debuggee process.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"postRunCommands": {
"description": "Commands executed just after the program has been launched.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"exitCommands": {
"description": "Commands executed at the end of debugging session.",
"type": "array",
"items": { "type": "string" },
"default": []
},
"expressions": {
"description": "The default evaluator type used for expressions.",
"type": "string",
"enum": ["simple", "python", "native"]
},
"sourceMap": {
"description": "Source path remapping between the build machine and the local machine. Each item is a pair of remote and local path prefixes.",
"type": "object",
"patternProperties": { ".*": { "type": ["string", "null"] } },
"default": {}
},
"relativePathBase": {
"description": "Base directory used for resolution of relative source paths. Defaults to \"${ZED_WORKTREE_ROOT}\".",
"type": "string"
},
"sourceLanguages": {
"description": "A list of source languages to enable language-specific features for.",
"type": "array",
"default": []
},
"reverseDebugging": {
"description": "Enable reverse debugging (Requires reverse execution support in the debug server, see User's Manual for details).",
"type": "boolean",
"default": false
},
"breakpointMode": {
"description": "Specifies how source breakpoints should be set.",
"type": "string",
"enum": ["path", "file"],
"enumDescriptions": [
"Resolve locations using full source file path.",
"Resolve locations using file name only."
]
}
}
}
}
]
}

View File

@@ -1,281 +0,0 @@
{
"allOf": [
{
"if": { "properties": { "request": { "const": "launch" } }, "required": ["request"] },
"then": {
"properties": {
"args": {
"default": [],
"description": "Command line arguments passed to the program. For string type arguments, it will pass through the shell as is, and therefore all shell variable expansions will apply. But for the array type, the values will be shell-escaped.",
"items": { "type": "string" },
"anyOf": [
{ "default": "${command:pickArgs}", "enum": ["${command:pickArgs}"] },
{ "type": ["array", "string"] }
]
},
"autoReload": {
"default": {},
"description": "Configures automatic reload of code on edit.",
"properties": {
"enable": { "default": false, "description": "Automatically reload code on edit.", "type": "boolean" },
"exclude": {
"default": [
"**/.git/**",
"**/.metadata/**",
"**/__pycache__/**",
"**/node_modules/**",
"**/site-packages/**"
],
"description": "Glob patterns of paths to exclude from auto reload.",
"items": { "type": "string" },
"type": "array"
},
"include": {
"default": ["**/*.py", "**/*.pyw"],
"description": "Glob patterns of paths to include in auto reload.",
"items": { "type": "string" },
"type": "array"
}
},
"type": "object"
},
"console": {
"default": "integratedTerminal",
"description": "Where to launch the debug target: internal console, integrated terminal, or external terminal.",
"enum": ["externalTerminal", "integratedTerminal", "internalConsole"]
},
"cwd": {
"default": "${ZED_WORKTREE_ROOT}",
"description": "Absolute path to the working directory of the program being debugged. Default is the root directory of the file (leave empty).",
"type": "string"
},
"debugAdapterPath": {
"description": "Path (fully qualified) to the Python debug adapter executable.",
"type": "string"
},
"autoStartBrowser": {
"default": false,
"description": "Open external browser to launch the application",
"type": "boolean"
},
"django": { "default": false, "description": "Django debugging.", "type": "boolean" },
"env": {
"additionalProperties": { "type": "string" },
"default": {},
"description": "Environment variables defined as a key value pair. Property ends up being the Environment Variable and the value of the property ends up being the value of the Env Variable.",
"type": "object"
},
"envFile": {
"default": "${ZED_WORKTREE_ROOT}/.env",
"description": "Absolute path to a file containing environment variable definitions.",
"type": "string"
},
"gevent": {
"default": false,
"description": "Enable debugging of gevent monkey-patched code.",
"type": "boolean"
},
"jinja": {
"default": null,
"description": "Jinja template debugging (e.g. Flask).",
"enum": [false, null, true]
},
"justMyCode": { "default": true, "description": "Debug only user-written code.", "type": "boolean" },
"logToFile": {
"default": false,
"description": "Enable logging of debugger events to a log file. This file can be found in the debugpy extension install folder.",
"type": "boolean"
},
"module": { "default": "", "description": "Name of the module to be debugged.", "type": "string" },
"pathMappings": {
"default": [],
"items": {
"label": "Path mapping",
"properties": {
"localRoot": { "default": "${ZED_WORKTREE_ROOT}", "label": "Local source root.", "type": "string" },
"remoteRoot": { "default": "", "label": "Remote source root.", "type": "string" }
},
"required": ["localRoot", "remoteRoot"],
"type": "object"
},
"label": "Path mappings.",
"type": "array"
},
"program": { "default": "${file}", "description": "Absolute path to the program.", "type": "string" },
"purpose": {
"default": [],
"description": "Tells extension to use this configuration for test debugging, or when using debug-in-terminal command.",
"items": {
"enum": ["debug-test", "debug-in-terminal"],
"enumDescriptions": [
"Use this configuration while debugging tests using test view or test debug commands.",
"Use this configuration while debugging a file using debug in terminal button in the editor."
]
},
"type": "array"
},
"pyramid": { "default": false, "description": "Whether debugging Pyramid applications.", "type": "boolean" },
"python": {
"default": "${command:python.interpreterPath}",
"description": "Absolute path to the Python interpreter executable; overrides workspace configuration if set.",
"type": "string"
},
"pythonArgs": {
"default": [],
"description": "Command-line arguments passed to the Python interpreter. To pass arguments to the debug target, use \"args\".",
"items": { "type": "string" },
"type": "array"
},
"redirectOutput": { "default": true, "description": "Redirect output.", "type": "boolean" },
"showReturnValue": {
"default": true,
"description": "Show return value of functions when stepping.",
"type": "boolean"
},
"stopOnEntry": { "default": false, "description": "Automatically stop after launch.", "type": "boolean" },
"subProcess": {
"default": false,
"description": "Whether to enable Sub Process debugging.",
"type": "boolean"
},
"sudo": {
"default": false,
"description": "Running debug program under elevated permissions (on Unix).",
"type": "boolean"
},
"guiEventLoop": {
"default": "matplotlib",
"description": "The GUI event loop that's going to run. Possible values: \"matplotlib\", \"wx\", \"qt\", \"none\", or a custom function that'll be imported and run.",
"type": "string"
},
"consoleName": {
"default": "Python Debug Console",
"description": "Display name of the debug console or terminal",
"type": "string"
},
"clientOS": { "default": null, "description": "OS that VS code is using.", "enum": ["windows", null, "unix"] }
}
}
},
{
"if": { "properties": { "request": { "const": "attach" } }, "required": ["request"] },
"then": {
"properties": {
"autoReload": {
"default": {},
"description": "Configures automatic reload of code on edit.",
"properties": {
"enable": { "default": false, "description": "Automatically reload code on edit.", "type": "boolean" },
"exclude": {
"default": [
"**/.git/**",
"**/.metadata/**",
"**/__pycache__/**",
"**/node_modules/**",
"**/site-packages/**"
],
"description": "Glob patterns of paths to exclude from auto reload.",
"items": { "type": "string" },
"type": "array"
},
"include": {
"default": ["**/*.py", "**/*.pyw"],
"description": "Glob patterns of paths to include in auto reload.",
"items": { "type": "string" },
"type": "array"
}
},
"type": "object"
},
"connect": {
"label": "Attach by connecting to debugpy over a socket.",
"properties": {
"host": {
"default": "127.0.0.1",
"description": "Hostname or IP address to connect to.",
"type": "string"
},
"port": { "description": "Port to connect to.", "type": ["number", "string"] }
},
"required": ["port"],
"type": "object"
},
"debugAdapterPath": {
"description": "Path (fully qualified) to the python debug adapter executable.",
"type": "string"
},
"django": { "default": false, "description": "Django debugging.", "type": "boolean" },
"jinja": {
"default": null,
"description": "Jinja template debugging (e.g. Flask).",
"enum": [false, null, true]
},
"justMyCode": {
"default": true,
"description": "If true, show and debug only user-written code. If false, show and debug all code, including library calls.",
"type": "boolean"
},
"listen": {
"label": "Attach by listening for incoming socket connection from debugpy",
"properties": {
"host": {
"default": "127.0.0.1",
"description": "Hostname or IP address of the interface to listen on.",
"type": "string"
},
"port": { "description": "Port to listen on.", "type": ["number", "string"] }
},
"required": ["port"],
"type": "object"
},
"logToFile": {
"default": false,
"description": "Enable logging of debugger events to a log file. This file can be found in the debugpy extension install folder.",
"type": "boolean"
},
"pathMappings": {
"default": [],
"items": {
"label": "Path mapping",
"properties": {
"localRoot": { "default": "${ZED_WORKTREE_ROOT}", "label": "Local source root.", "type": "string" },
"remoteRoot": { "default": "", "label": "Remote source root.", "type": "string" }
},
"required": ["localRoot", "remoteRoot"],
"type": "object"
},
"label": "Path mappings.",
"type": "array"
},
"processId": {
"anyOf": [
{
"default": "${command:pickProcess}",
"description": "Use process picker to select a process to attach, or Process ID as integer.",
"enum": ["${command:pickProcess}"]
},
{ "description": "ID of the local process to attach to.", "type": "integer" }
]
},
"redirectOutput": { "default": true, "description": "Redirect output.", "type": "boolean" },
"showReturnValue": {
"default": true,
"description": "Show return value of functions when stepping.",
"type": "boolean"
},
"subProcess": {
"default": false,
"description": "Whether to enable Sub Process debugging",
"type": "boolean"
},
"consoleName": {
"default": "Python Debug Console",
"description": "Display name of the debug console or terminal",
"type": "string"
},
"clientOS": { "default": null, "description": "OS that VS code is using.", "enum": ["windows", null, "unix"] }
}
}
}
]
}

Some files were not shown because too many files have changed in this diff Show More