Compare commits
65 Commits
fix-python
...
edit-diffs
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d97b1fcd7b | ||
|
|
1ca4a011c6 | ||
|
|
fa4abaf56e | ||
|
|
f9c729a7b1 | ||
|
|
3c8b7caa2e | ||
|
|
4fc96e9453 | ||
|
|
3f7abfbfe4 | ||
|
|
0d8f77b5de | ||
|
|
cb79420773 | ||
|
|
79cf3f5e93 | ||
|
|
d27eec8c58 | ||
|
|
c641209341 | ||
|
|
48a716fcb5 | ||
|
|
25956c49c1 | ||
|
|
4efabe17dd | ||
|
|
320abe9b22 | ||
|
|
9a9f2e71ca | ||
|
|
609895d95f | ||
|
|
3467b80595 | ||
|
|
da2d8bd845 | ||
|
|
6267a147ba | ||
|
|
aceecec6bf | ||
|
|
f3f2c6d811 | ||
|
|
41cffa64b0 | ||
|
|
b486e32f05 | ||
|
|
222d4a2546 | ||
|
|
1eb948654a | ||
|
|
35da1502e1 | ||
|
|
1d98b33ae0 | ||
|
|
4e8ecfc0c4 | ||
|
|
134a0563c2 | ||
|
|
3f4d4af080 | ||
|
|
68ec1d724c | ||
|
|
102ea6ac79 | ||
|
|
5d3718df2d | ||
|
|
f1f5d602fc | ||
|
|
60624d81ba | ||
|
|
91755b2db1 | ||
|
|
e34fee55a0 | ||
|
|
dad6067e18 | ||
|
|
5619a3e618 | ||
|
|
06ad45ce08 | ||
|
|
7e6387052f | ||
|
|
d6196d72c1 | ||
|
|
303036e333 | ||
|
|
d0634bbf2b | ||
|
|
714a60e9e9 | ||
|
|
b09eb4b683 | ||
|
|
8afca164cf | ||
|
|
bf2284019a | ||
|
|
706be9bc06 | ||
|
|
b65aeedfef | ||
|
|
01ccb732be | ||
|
|
3313348769 | ||
|
|
bf052b56a4 | ||
|
|
0b17b59305 | ||
|
|
e63a14721e | ||
|
|
931891414c | ||
|
|
9292ab94b6 | ||
|
|
5fb549699b | ||
|
|
88ea23d7da | ||
|
|
df509bbc20 | ||
|
|
38fcadf948 | ||
|
|
e5cbac1373 | ||
|
|
53375434cf |
28
.github/workflows/run_agent_eval_daily.yml
vendored
Normal file
28
.github/workflows/run_agent_eval_daily.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Run Eval Daily
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
CARGO_INCREMENTAL: 0
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
jobs:
|
||||
run_eval:
|
||||
name: Run Eval
|
||||
if: github.repository_owner == 'zed-industries'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repo
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
clean: false
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Run cargo eval
|
||||
run: cargo run -p eval
|
||||
137
Cargo.lock
generated
137
Cargo.lock
generated
@@ -324,7 +324,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"thiserror 2.0.12",
|
||||
"workspace-hack",
|
||||
]
|
||||
@@ -567,7 +567,7 @@ dependencies = [
|
||||
"settings",
|
||||
"smallvec",
|
||||
"smol",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"telemetry_events",
|
||||
"text",
|
||||
"theme",
|
||||
@@ -702,8 +702,11 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_tool",
|
||||
"buffer_diff",
|
||||
"chrono",
|
||||
"collections",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"html_to_markdown",
|
||||
@@ -711,6 +714,7 @@ dependencies = [
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"multi_buffer",
|
||||
"open",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
@@ -721,9 +725,11 @@ dependencies = [
|
||||
"ui",
|
||||
"unindent",
|
||||
"util",
|
||||
"web_search",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"worktree",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1881,7 +1887,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"workspace-hack",
|
||||
@@ -3028,7 +3034,7 @@ dependencies = [
|
||||
"settings",
|
||||
"sha2",
|
||||
"sqlx",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"subtle",
|
||||
"supermaven_api",
|
||||
"telemetry_events",
|
||||
@@ -3360,7 +3366,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"task",
|
||||
"theme",
|
||||
"ui",
|
||||
@@ -4477,7 +4483,7 @@ dependencies = [
|
||||
"optfield",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
@@ -4886,6 +4892,7 @@ dependencies = [
|
||||
"collections",
|
||||
"context_server",
|
||||
"dap",
|
||||
"dirs 5.0.1",
|
||||
"env_logger 0.11.8",
|
||||
"extension",
|
||||
"fs",
|
||||
@@ -4907,9 +4914,11 @@ dependencies = [
|
||||
"serde",
|
||||
"settings",
|
||||
"shellexpand 2.1.2",
|
||||
"telemetry",
|
||||
"toml 0.8.20",
|
||||
"unindent",
|
||||
"util",
|
||||
"uuid",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -5119,7 +5128,7 @@ dependencies = [
|
||||
"serde",
|
||||
"settings",
|
||||
"smallvec",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"telemetry",
|
||||
"theme",
|
||||
"ui",
|
||||
@@ -5970,7 +5979,7 @@ dependencies = [
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"telemetry",
|
||||
"theme",
|
||||
"time",
|
||||
@@ -6063,7 +6072,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -6169,7 +6178,7 @@ dependencies = [
|
||||
"slotmap",
|
||||
"smallvec",
|
||||
"smol",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"sum_tree",
|
||||
"taffy",
|
||||
"thiserror 2.0.12",
|
||||
@@ -6817,7 +6826,7 @@ name = "icons"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -7085,16 +7094,16 @@ dependencies = [
|
||||
"paths",
|
||||
"pretty_assertions",
|
||||
"serde",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.9.0"
|
||||
version = "2.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
|
||||
checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.15.2",
|
||||
@@ -7671,7 +7680,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"smol",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"telemetry_events",
|
||||
"thiserror 2.0.12",
|
||||
"util",
|
||||
@@ -7731,7 +7740,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"theme",
|
||||
"thiserror 2.0.12",
|
||||
"tiktoken-rs",
|
||||
@@ -7739,6 +7748,7 @@ dependencies = [
|
||||
"ui",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7954,9 +7964,9 @@ checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa"
|
||||
|
||||
[[package]]
|
||||
name = "libmimalloc-sys"
|
||||
version = "0.1.42"
|
||||
version = "0.1.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec9d6fac27761dabcd4ee73571cdb06b7022dc99089acbe5435691edffaac0f4"
|
||||
checksum = "6b20daca3a4ac14dbdc753c5e90fc7b490a48a9131daed3c9a9ced7b2defd37b"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
@@ -8627,9 +8637,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "mimalloc"
|
||||
version = "0.1.46"
|
||||
version = "0.1.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "995942f432bbb4822a7e9c3faa87a695185b0d09273ba85f097b54f4e458f2af"
|
||||
checksum = "03cb1f88093fe50061ca1195d336ffec131347c7b833db31f9ab62a2d1b7925f"
|
||||
dependencies = [
|
||||
"libmimalloc-sys",
|
||||
]
|
||||
@@ -8703,7 +8713,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -9550,7 +9560,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -12129,7 +12139,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"tracing",
|
||||
"util",
|
||||
"workspace-hack",
|
||||
@@ -12657,7 +12667,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"thiserror 2.0.12",
|
||||
"time",
|
||||
"tracing",
|
||||
@@ -13253,9 +13263,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.15.0"
|
||||
version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9"
|
||||
checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
@@ -13702,7 +13712,7 @@ dependencies = [
|
||||
"settings",
|
||||
"simplelog",
|
||||
"story",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"theme",
|
||||
"title_bar",
|
||||
"ui",
|
||||
@@ -13784,7 +13794,16 @@ version = "0.26.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
|
||||
dependencies = [
|
||||
"strum_macros",
|
||||
"strum_macros 0.26.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32"
|
||||
dependencies = [
|
||||
"strum_macros 0.27.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -13800,6 +13819,19 @@ dependencies = [
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strum_macros"
|
||||
version = "0.27.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.6.1"
|
||||
@@ -14415,7 +14447,7 @@ dependencies = [
|
||||
"serde_json_lenient",
|
||||
"serde_repr",
|
||||
"settings",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"thiserror 2.0.12",
|
||||
"util",
|
||||
"uuid",
|
||||
@@ -14449,7 +14481,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"simplelog",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"theme",
|
||||
"vscode_theme",
|
||||
"workspace-hack",
|
||||
@@ -15450,7 +15482,7 @@ dependencies = [
|
||||
"settings",
|
||||
"smallvec",
|
||||
"story",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"theme",
|
||||
"ui_macros",
|
||||
"util",
|
||||
@@ -16583,6 +16615,36 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web_search"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"gpui",
|
||||
"serde",
|
||||
"workspace-hack",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web_search_providers"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"feature_flags",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"http_client",
|
||||
"language_model",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"web_search",
|
||||
"workspace-hack",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-root-certs"
|
||||
version = "0.26.8"
|
||||
@@ -17621,7 +17683,7 @@ dependencies = [
|
||||
"settings",
|
||||
"smallvec",
|
||||
"sqlez",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"task",
|
||||
"telemetry",
|
||||
"tempfile",
|
||||
@@ -17766,7 +17828,7 @@ dependencies = [
|
||||
"sqlx-macros-core",
|
||||
"sqlx-postgres",
|
||||
"sqlx-sqlite",
|
||||
"strum",
|
||||
"strum 0.26.3",
|
||||
"subtle",
|
||||
"syn 1.0.109",
|
||||
"syn 2.0.100",
|
||||
@@ -18138,7 +18200,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.183.0"
|
||||
version = "0.184.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -18261,6 +18323,8 @@ dependencies = [
|
||||
"uuid",
|
||||
"vim",
|
||||
"vim_mode_setting",
|
||||
"web_search",
|
||||
"web_search_providers",
|
||||
"welcome",
|
||||
"windows 0.61.1",
|
||||
"winresource",
|
||||
@@ -18325,12 +18389,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_llm_client"
|
||||
version = "0.4.1"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bf21350eced858d129840589158a8f6895c4fa4327ae56dd8c7d6a98495bed4"
|
||||
checksum = "57a5e1b5b3ace3fb55292a4c14036723bb8a01fac4aeaa3c2b63b51228412f94"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum 0.27.1",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
|
||||
@@ -165,6 +165,8 @@ members = [
|
||||
"crates/util_macros",
|
||||
"crates/vim",
|
||||
"crates/vim_mode_setting",
|
||||
"crates/web_search",
|
||||
"crates/web_search_providers",
|
||||
"crates/welcome",
|
||||
"crates/workspace",
|
||||
"crates/worktree",
|
||||
@@ -370,6 +372,8 @@ util = { path = "crates/util" }
|
||||
util_macros = { path = "crates/util_macros" }
|
||||
vim = { path = "crates/vim" }
|
||||
vim_mode_setting = { path = "crates/vim_mode_setting" }
|
||||
web_search = { path = "crates/web_search" }
|
||||
web_search_providers = { path = "crates/web_search_providers" }
|
||||
welcome = { path = "crates/welcome" }
|
||||
workspace = { path = "crates/workspace" }
|
||||
worktree = { path = "crates/worktree" }
|
||||
@@ -601,7 +605,7 @@ wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.4"
|
||||
zed_llm_client = "0.5.0"
|
||||
zstd = "0.11"
|
||||
metal = "0.29"
|
||||
|
||||
|
||||
@@ -652,7 +652,8 @@
|
||||
"path_search": true,
|
||||
"read_file": true,
|
||||
"regex_search": true,
|
||||
"thinking": true
|
||||
"thinking": true,
|
||||
"web_search": true
|
||||
}
|
||||
},
|
||||
"write": {
|
||||
@@ -678,7 +679,8 @@
|
||||
"regex_search": true,
|
||||
"rename": true,
|
||||
"symbol_info": true,
|
||||
"thinking": true
|
||||
"thinking": true,
|
||||
"web_search": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -5,11 +5,12 @@ use crate::thread::{
|
||||
ThreadEvent, ThreadFeedback,
|
||||
};
|
||||
use crate::thread_store::{RulesLoadingError, ThreadStore};
|
||||
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
|
||||
use crate::tool_use::{PendingToolUseStatus, ToolUse};
|
||||
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
|
||||
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
|
||||
use anyhow::Context as _;
|
||||
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
|
||||
use assistant_tool::ToolUseStatus;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::scroll::Autoscroll;
|
||||
use editor::{Editor, EditorElement, EditorStyle, MultiBuffer};
|
||||
@@ -766,10 +767,11 @@ impl ActiveThread {
|
||||
self.thread.read(cx).summary_or_default()
|
||||
}
|
||||
|
||||
pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
|
||||
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
|
||||
self.last_error.take();
|
||||
self.thread
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(cx))
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_last_completion(Some(window.window_handle()), cx)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn last_error(&self) -> Option<ThreadError> {
|
||||
@@ -943,8 +945,8 @@ impl ActiveThread {
|
||||
&tool_use.input,
|
||||
self.thread
|
||||
.read(cx)
|
||||
.tool_result(&tool_use.id)
|
||||
.map(|result| result.content.clone().into())
|
||||
.output_for_tool(&tool_use.id)
|
||||
.map(|output| output.clone().into())
|
||||
.unwrap_or("".into()),
|
||||
cx,
|
||||
);
|
||||
@@ -1142,7 +1144,7 @@ impl ActiveThread {
|
||||
fn confirm_editing_message(
|
||||
&mut self,
|
||||
_: &menu::Confirm,
|
||||
_: &mut Window,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let Some((message_id, state)) = self.editing_message.take() else {
|
||||
@@ -1171,7 +1173,12 @@ impl ActiveThread {
|
||||
}
|
||||
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.send_to_model(model.model, RequestKind::Chat, cx)
|
||||
thread.send_to_model(
|
||||
model.model,
|
||||
RequestKind::Chat,
|
||||
Some(window.window_handle()),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
cx.notify();
|
||||
}
|
||||
@@ -2279,12 +2286,15 @@ impl ActiveThread {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement + use<> {
|
||||
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
|
||||
return card.render(&tool_use.status, window, cx);
|
||||
}
|
||||
|
||||
let is_open = self
|
||||
.expanded_tool_uses
|
||||
.get(&tool_use.id)
|
||||
.copied()
|
||||
.unwrap_or_default();
|
||||
|
||||
let is_status_finished = matches!(&tool_use.status, ToolUseStatus::Finished(_));
|
||||
|
||||
let fs = self
|
||||
@@ -2375,6 +2385,7 @@ impl ActiveThread {
|
||||
open_markdown_link(text, workspace.clone(), window, cx);
|
||||
}
|
||||
})
|
||||
.into_any_element()
|
||||
}),
|
||||
)),
|
||||
),
|
||||
@@ -2431,6 +2442,7 @@ impl ActiveThread {
|
||||
open_markdown_link(text, workspace.clone(), window, cx);
|
||||
}
|
||||
})
|
||||
.into_any_element()
|
||||
})),
|
||||
),
|
||||
),
|
||||
@@ -2761,7 +2773,7 @@ impl ActiveThread {
|
||||
)
|
||||
})
|
||||
}
|
||||
})
|
||||
}).into_any_element()
|
||||
}
|
||||
|
||||
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
|
||||
@@ -2825,7 +2837,7 @@ impl ActiveThread {
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
_: &ClickEvent,
|
||||
_window: &mut Window,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self
|
||||
@@ -2841,6 +2853,7 @@ impl ActiveThread {
|
||||
c.input.clone(),
|
||||
&c.messages,
|
||||
c.tool.clone(),
|
||||
Some(window.window_handle()),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
@@ -2852,11 +2865,12 @@ impl ActiveThread {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
_: &ClickEvent,
|
||||
_window: &mut Window,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let window_handle = window.window_handle();
|
||||
self.thread.update(cx, |thread, cx| {
|
||||
thread.deny_tool_use(tool_use_id, tool_name, cx);
|
||||
thread.deny_tool_use(tool_use_id, tool_name, Some(window_handle), cx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
|
||||
use anyhow::Result;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use collections::HashSet;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::{
|
||||
Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
|
||||
actions::{GoToHunk, GoToPreviousHunk},
|
||||
@@ -355,16 +355,24 @@ impl AgentDiff {
|
||||
self.update_selection(&diff_hunks_in_ranges, window, cx);
|
||||
}
|
||||
|
||||
let mut ranges_by_buffer = HashMap::default();
|
||||
for hunk in &diff_hunks_in_ranges {
|
||||
let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
|
||||
if let Some(buffer) = buffer {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.reject_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
ranges_by_buffer
|
||||
.entry(buffer.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(hunk.buffer_range.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for (buffer, ranges) in ranges_by_buffer {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.reject_edits_in_ranges(buffer, ranges, cx)
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn update_selection(
|
||||
|
||||
@@ -9,11 +9,14 @@ use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use collections::HashMap;
|
||||
use context_server::manager::ContextServerManager;
|
||||
use fs::Fs;
|
||||
use gpui::{Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, Subscription};
|
||||
use gpui::{
|
||||
Action, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Subscription,
|
||||
};
|
||||
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use ui::{
|
||||
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Switch, Tooltip, prelude::*,
|
||||
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
|
||||
Switch, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use zed_actions::ExtensionCategoryFilter;
|
||||
@@ -31,6 +34,8 @@ pub struct AssistantConfiguration {
|
||||
expanded_context_server_tools: HashMap<Arc<str>, bool>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
_registry_subscription: Subscription,
|
||||
scroll_handle: ScrollHandle,
|
||||
scrollbar_state: ScrollbarState,
|
||||
}
|
||||
|
||||
impl AssistantConfiguration {
|
||||
@@ -60,6 +65,9 @@ impl AssistantConfiguration {
|
||||
},
|
||||
);
|
||||
|
||||
let scroll_handle = ScrollHandle::new();
|
||||
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
|
||||
|
||||
let mut this = Self {
|
||||
fs,
|
||||
focus_handle,
|
||||
@@ -68,6 +76,8 @@ impl AssistantConfiguration {
|
||||
expanded_context_server_tools: HashMap::default(),
|
||||
tools,
|
||||
_registry_subscription: registry_subscription,
|
||||
scroll_handle,
|
||||
scrollbar_state,
|
||||
};
|
||||
this.build_provider_configuration_views(window, cx);
|
||||
this
|
||||
@@ -109,7 +119,7 @@ pub enum AssistantConfigurationEvent {
|
||||
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
|
||||
|
||||
impl AssistantConfiguration {
|
||||
fn render_provider_configuration(
|
||||
fn render_provider_configuration_block(
|
||||
&mut self,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -164,7 +174,7 @@ impl AssistantConfiguration {
|
||||
.p(DynamicSpacing::Base08.rems(cx))
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.border_color(cx.theme().colors().border)
|
||||
.rounded_sm()
|
||||
.map(|parent| match configuration_view {
|
||||
Some(configuration_view) => parent.child(configuration_view),
|
||||
@@ -175,6 +185,33 @@ impl AssistantConfiguration {
|
||||
)
|
||||
}
|
||||
|
||||
fn render_provider_configuration_section(
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.gap_4()
|
||||
.flex_1()
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers").size(HeadlineSize::Small))
|
||||
.child(
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.children(
|
||||
providers
|
||||
.into_iter()
|
||||
.map(|provider| self.render_provider_configuration_block(&provider, cx)),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_command_permission(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let always_allow_tool_actions = AssistantSettings::get_global(cx).always_allow_tool_actions;
|
||||
|
||||
@@ -182,6 +219,7 @@ impl AssistantConfiguration {
|
||||
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.child(Headline::new("General Settings").size(HeadlineSize::Small))
|
||||
@@ -233,6 +271,7 @@ impl AssistantConfiguration {
|
||||
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.pr(DynamicSpacing::Base20.rems(cx))
|
||||
.gap_2()
|
||||
.flex_1()
|
||||
.child(
|
||||
@@ -426,39 +465,51 @@ impl AssistantConfiguration {
|
||||
|
||||
impl Render for AssistantConfiguration {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
|
||||
v_flex()
|
||||
.id("assistant-configuration")
|
||||
.key_context("AgentConfiguration")
|
||||
.track_focus(&self.focus_handle(cx))
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.relative()
|
||||
.size_full()
|
||||
.overflow_y_scroll()
|
||||
.child(self.render_command_permission(cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_context_servers_section(cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.pb_8()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.child(
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.mt_1()
|
||||
.gap_6()
|
||||
.flex_1()
|
||||
.child(
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(Headline::new("LLM Providers").size(HeadlineSize::Small))
|
||||
.child(
|
||||
Label::new("Add at least one provider to use AI-powered features.")
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.children(
|
||||
providers
|
||||
.into_iter()
|
||||
.map(|provider| self.render_provider_configuration(&provider, cx)),
|
||||
),
|
||||
.id("assistant-configuration-content")
|
||||
.track_scroll(&self.scroll_handle)
|
||||
.size_full()
|
||||
.overflow_y_scroll()
|
||||
.child(self.render_command_permission(cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_context_servers_section(cx))
|
||||
.child(Divider::horizontal().color(DividerColor::Border))
|
||||
.child(self.render_provider_configuration_section(cx)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("assistant-configuration-scrollbar")
|
||||
.occlude()
|
||||
.absolute()
|
||||
.right(px(3.))
|
||||
.top_0()
|
||||
.bottom_0()
|
||||
.pb_6()
|
||||
.w(px(12.))
|
||||
.cursor_default()
|
||||
.on_mouse_move(cx.listener(|_, _, _window, cx| {
|
||||
cx.notify();
|
||||
cx.stop_propagation()
|
||||
}))
|
||||
.on_hover(|_, _window, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_any_mouse_down(|_, _window, cx| {
|
||||
cx.stop_propagation();
|
||||
})
|
||||
.on_scroll_wheel(cx.listener(|_, _, _window, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
.children(Scrollbar::vertical(self.scrollbar_state.clone())),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ use language_model_selector::ToggleModelSelector;
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::PromptBuilder;
|
||||
use proto::Plan;
|
||||
use settings::{Settings, update_settings_file};
|
||||
use time::UtcOffset;
|
||||
use ui::{
|
||||
@@ -336,14 +337,9 @@ impl AssistantPanel {
|
||||
&self.thread_store
|
||||
}
|
||||
|
||||
fn cancel(
|
||||
&mut self,
|
||||
_: &editor::actions::Cancel,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
fn cancel(&mut self, _: &editor::actions::Cancel, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
|
||||
}
|
||||
|
||||
fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -1449,6 +1445,9 @@ impl AssistantPanel {
|
||||
ThreadError::MaxMonthlySpendReached => {
|
||||
self.render_max_monthly_spend_reached_error(cx)
|
||||
}
|
||||
ThreadError::ModelRequestLimitReached { plan } => {
|
||||
self.render_model_request_limit_reached_error(plan, cx)
|
||||
}
|
||||
ThreadError::Message { header, message } => {
|
||||
self.render_error_message(header, message, cx)
|
||||
}
|
||||
@@ -1551,6 +1550,67 @@ impl AssistantPanel {
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_model_request_limit_reached_error(
|
||||
&self,
|
||||
plan: Plan,
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
let error_message = match plan {
|
||||
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
|
||||
Plan::ZedPro => {
|
||||
"Model request limit reached. Upgrade to usage-based billing for more requests."
|
||||
}
|
||||
};
|
||||
let call_to_action = match plan {
|
||||
Plan::Free => "Upgrade to Zed Pro",
|
||||
Plan::ZedPro => "Upgrade to usage-based billing",
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.items_center()
|
||||
.child(Icon::new(IconName::XCircle).color(Color::Error))
|
||||
.child(Label::new("Model Request Limit Reached").weight(FontWeight::MEDIUM)),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("error-message")
|
||||
.max_h_24()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(error_message)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.child(
|
||||
Button::new("subscribe", call_to_action).on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.open_url(&zed_urls::account_url(cx));
|
||||
cx.notify();
|
||||
},
|
||||
)),
|
||||
)
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
this.clear_last_error();
|
||||
});
|
||||
|
||||
cx.notify();
|
||||
},
|
||||
))),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn render_error_message(
|
||||
&self,
|
||||
header: SharedString,
|
||||
|
||||
@@ -263,6 +263,7 @@ impl MessageEditor {
|
||||
let context_store = self.context_store.clone();
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
|
||||
let window_handle = window.window_handle();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let checkpoint = checkpoint.await.ok();
|
||||
@@ -297,7 +298,7 @@ impl MessageEditor {
|
||||
// Send to model after summaries are done
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send_to_model(model, request_kind, cx);
|
||||
thread.send_to_model(model, request_kind, Some(window_handle), cx);
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
@@ -305,9 +306,9 @@ impl MessageEditor {
|
||||
}
|
||||
|
||||
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let cancelled = self
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
|
||||
let cancelled = self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_last_completion(Some(window.window_handle()), cx)
|
||||
});
|
||||
|
||||
if cancelled {
|
||||
self.set_editor_is_expanded(false, cx);
|
||||
|
||||
@@ -6,24 +6,27 @@ use std::time::Instant;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt, StreamExt as _};
|
||||
use git::repository::DiffType;
|
||||
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
|
||||
Role, StopReason, TokenUsage,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use project::Project;
|
||||
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
|
||||
use prompt_store::PromptBuilder;
|
||||
use proto::Plan;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
@@ -630,6 +633,14 @@ impl Thread {
|
||||
self.tool_use.tool_result(id)
|
||||
}
|
||||
|
||||
pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
|
||||
Some(&self.tool_use.tool_result(id)?.content)
|
||||
}
|
||||
|
||||
pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
|
||||
self.tool_use.tool_result_card(id).cloned()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_use.message_has_tool_results(message_id)
|
||||
}
|
||||
@@ -838,6 +849,7 @@ impl Thread {
|
||||
&mut self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
request_kind: RequestKind,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let mut request = self.to_completion_request(request_kind, cx);
|
||||
@@ -864,7 +876,7 @@ impl Thread {
|
||||
};
|
||||
}
|
||||
|
||||
self.stream_completion(request, model, cx);
|
||||
self.stream_completion(request, model, window, cx);
|
||||
}
|
||||
|
||||
pub fn used_tools_since_last_user_message(&self) -> bool {
|
||||
@@ -1010,6 +1022,7 @@ impl Thread {
|
||||
&mut self,
|
||||
request: LanguageModelRequest,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let pending_completion_id = post_inc(&mut self.completion_count);
|
||||
@@ -1137,7 +1150,7 @@ impl Thread {
|
||||
match result.as_ref() {
|
||||
Ok(stop_reason) => match stop_reason {
|
||||
StopReason::ToolUse => {
|
||||
let tool_uses = thread.use_pending_tools(cx);
|
||||
let tool_uses = thread.use_pending_tools(window, cx);
|
||||
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
|
||||
}
|
||||
StopReason::EndTurn => {}
|
||||
@@ -1150,6 +1163,12 @@ impl Thread {
|
||||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::MaxMonthlySpendReached,
|
||||
));
|
||||
} else if let Some(error) =
|
||||
error.downcast_ref::<ModelRequestLimitReachedError>()
|
||||
{
|
||||
cx.emit(ThreadEvent::ShowError(
|
||||
ThreadError::ModelRequestLimitReached { plan: error.plan },
|
||||
));
|
||||
} else if let Some(known_error) =
|
||||
error.downcast_ref::<LanguageModelKnownError>()
|
||||
{
|
||||
@@ -1176,7 +1195,7 @@ impl Thread {
|
||||
}));
|
||||
}
|
||||
|
||||
thread.cancel_last_completion(cx);
|
||||
thread.cancel_last_completion(window, cx);
|
||||
}
|
||||
}
|
||||
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
|
||||
@@ -1342,7 +1361,11 @@ impl Thread {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
|
||||
pub fn use_pending_tools(
|
||||
&mut self,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Vec<PendingToolUse> {
|
||||
self.auto_capture_telemetry(cx);
|
||||
let request = self.to_completion_request(RequestKind::Chat, cx);
|
||||
let messages = Arc::new(request.messages);
|
||||
@@ -1374,6 +1397,7 @@ impl Thread {
|
||||
tool_use.input.clone(),
|
||||
&messages,
|
||||
tool,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
@@ -1390,9 +1414,10 @@ impl Thread {
|
||||
input: serde_json::Value,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
tool: Arc<dyn Tool>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
|
||||
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
|
||||
self.tool_use
|
||||
.run_pending_tool(tool_use_id, ui_text.into(), task);
|
||||
}
|
||||
@@ -1403,6 +1428,7 @@ impl Thread {
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
input: serde_json::Value,
|
||||
tool: Arc<dyn Tool>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) -> Task<()> {
|
||||
let tool_name: Arc<str> = tool.name().into();
|
||||
@@ -1415,10 +1441,17 @@ impl Thread {
|
||||
messages,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
};
|
||||
|
||||
// Store the card separately if it exists
|
||||
if let Some(card) = tool_result.card.clone() {
|
||||
self.tool_use
|
||||
.insert_tool_result_card(tool_use_id.clone(), card);
|
||||
}
|
||||
|
||||
cx.spawn({
|
||||
async move |thread: WeakEntity<Thread>, cx| {
|
||||
let output = tool_result.output.await;
|
||||
@@ -1431,7 +1464,7 @@ impl Thread {
|
||||
output,
|
||||
cx,
|
||||
);
|
||||
thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
|
||||
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -1443,6 +1476,7 @@ impl Thread {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
pending_tool_use: Option<PendingToolUse>,
|
||||
canceled: bool,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.all_tools_finished() {
|
||||
@@ -1450,7 +1484,7 @@ impl Thread {
|
||||
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
|
||||
self.attach_tool_results(cx);
|
||||
if !canceled {
|
||||
self.send_to_model(model, RequestKind::Chat, cx);
|
||||
self.send_to_model(model, RequestKind::Chat, window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1477,7 +1511,11 @@ impl Thread {
|
||||
/// Cancels the last pending completion, if there are any pending.
|
||||
///
|
||||
/// Returns whether a completion was canceled.
|
||||
pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
pub fn cancel_last_completion(
|
||||
&mut self,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
let canceled = if self.pending_completions.pop().is_some() {
|
||||
true
|
||||
} else {
|
||||
@@ -1488,6 +1526,7 @@ impl Thread {
|
||||
pending_tool_use.id.clone(),
|
||||
Some(pending_tool_use),
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
@@ -1801,14 +1840,14 @@ impl Thread {
|
||||
.update(cx, |action_log, cx| action_log.keep_all_edits(cx));
|
||||
}
|
||||
|
||||
pub fn reject_edits_in_range(
|
||||
pub fn reject_edits_in_ranges(
|
||||
&mut self,
|
||||
buffer: Entity<language::Buffer>,
|
||||
buffer_range: Range<language::Anchor>,
|
||||
buffer_ranges: Vec<Range<language::Anchor>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
self.action_log.update(cx, |action_log, cx| {
|
||||
action_log.reject_edits_in_range(buffer, buffer_range, cx)
|
||||
action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1911,6 +1950,7 @@ impl Thread {
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let err = Err(anyhow::anyhow!(
|
||||
@@ -1919,7 +1959,7 @@ impl Thread {
|
||||
|
||||
self.tool_use
|
||||
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
|
||||
self.tool_finished(tool_use_id.clone(), None, true, cx);
|
||||
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1929,6 +1969,8 @@ pub enum ThreadError {
|
||||
PaymentRequired,
|
||||
#[error("Max monthly spend reached")]
|
||||
MaxMonthlySpendReached,
|
||||
#[error("Model request limit reached")]
|
||||
ModelRequestLimitReached { plan: Plan },
|
||||
#[error("Message {header}: {message}")]
|
||||
Message {
|
||||
header: SharedString,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use assistant_tool::{Tool, ToolWorkingSet};
|
||||
use assistant_tool::{AnyToolCard, Tool, ToolUseStatus, ToolWorkingSet};
|
||||
use collections::HashMap;
|
||||
use futures::FutureExt as _;
|
||||
use futures::future::Shared;
|
||||
@@ -27,26 +27,7 @@ pub struct ToolUse {
|
||||
pub needs_confirmation: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolUseStatus {
|
||||
NeedsConfirmation,
|
||||
Pending,
|
||||
Running,
|
||||
Finished(SharedString),
|
||||
Error(SharedString),
|
||||
}
|
||||
|
||||
impl ToolUseStatus {
|
||||
pub fn text(&self) -> SharedString {
|
||||
match self {
|
||||
ToolUseStatus::NeedsConfirmation => "".into(),
|
||||
ToolUseStatus::Pending => "".into(),
|
||||
ToolUseStatus::Running => "".into(),
|
||||
ToolUseStatus::Finished(out) => out.clone(),
|
||||
ToolUseStatus::Error(out) => out.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
||||
|
||||
pub struct ToolUseState {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
@@ -54,10 +35,9 @@ pub struct ToolUseState {
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||
}
|
||||
|
||||
pub const USING_TOOL_MARKER: &str = "<using_tool>";
|
||||
|
||||
impl ToolUseState {
|
||||
pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
|
||||
Self {
|
||||
@@ -66,6 +46,7 @@ impl ToolUseState {
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
tool_result_cards: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,6 +238,18 @@ impl ToolUseState {
|
||||
self.tool_results.get(tool_use_id)
|
||||
}
|
||||
|
||||
pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> {
|
||||
self.tool_result_cards.get(tool_use_id)
|
||||
}
|
||||
|
||||
pub fn insert_tool_result_card(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
card: AnyToolCard,
|
||||
) {
|
||||
self.tool_result_cards.insert(tool_use_id, card);
|
||||
}
|
||||
|
||||
pub fn request_tool_use(
|
||||
&mut self,
|
||||
assistant_message_id: MessageId,
|
||||
|
||||
@@ -3,7 +3,7 @@ use buffer_diff::BufferDiff;
|
||||
use collections::BTreeMap;
|
||||
use futures::{StreamExt, channel::mpsc};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
|
||||
use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
|
||||
use language::{Anchor, Buffer, BufferEvent, DiskState, Point, ToPoint};
|
||||
use project::{Project, ProjectItem, lsp_store::OpenLspBufferHandle};
|
||||
use std::{cmp, ops::Range, sync::Arc};
|
||||
use text::{Edit, Patch, Rope};
|
||||
@@ -363,10 +363,10 @@ impl ActionLog {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reject_edits_in_range(
|
||||
pub fn reject_edits_in_ranges(
|
||||
&mut self,
|
||||
buffer: Entity<Buffer>,
|
||||
buffer_range: Range<impl language::ToPoint>,
|
||||
buffer_ranges: Vec<Range<impl language::ToPoint>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
|
||||
@@ -403,29 +403,15 @@ impl ActionLog {
|
||||
}
|
||||
TrackedBufferStatus::Modified => {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let buffer_range =
|
||||
buffer_range.start.to_point(buffer)..buffer_range.end.to_point(buffer);
|
||||
let mut buffer_row_ranges = buffer_ranges
|
||||
.into_iter()
|
||||
.map(|range| {
|
||||
range.start.to_point(buffer).row..range.end.to_point(buffer).row
|
||||
})
|
||||
.peekable();
|
||||
|
||||
let mut edits_to_revert = Vec::new();
|
||||
for edit in tracked_buffer.unreviewed_changes.edits() {
|
||||
if buffer_range.end.row < edit.new.start {
|
||||
break;
|
||||
} else if buffer_range.start.row > edit.new.end {
|
||||
continue;
|
||||
}
|
||||
|
||||
let old_range = tracked_buffer
|
||||
.base_text
|
||||
.point_to_offset(Point::new(edit.old.start, 0))
|
||||
..tracked_buffer.base_text.point_to_offset(cmp::min(
|
||||
Point::new(edit.old.end, 0),
|
||||
tracked_buffer.base_text.max_point(),
|
||||
));
|
||||
let old_text = tracked_buffer
|
||||
.base_text
|
||||
.chunks_in_range(old_range)
|
||||
.collect::<String>();
|
||||
|
||||
let new_range = tracked_buffer
|
||||
.snapshot
|
||||
.anchor_before(Point::new(edit.new.start, 0))
|
||||
@@ -433,7 +419,35 @@ impl ActionLog {
|
||||
Point::new(edit.new.end, 0),
|
||||
tracked_buffer.snapshot.max_point(),
|
||||
));
|
||||
edits_to_revert.push((new_range, old_text));
|
||||
let new_row_range = new_range.start.to_point(buffer).row
|
||||
..new_range.end.to_point(buffer).row;
|
||||
|
||||
let mut revert = false;
|
||||
while let Some(buffer_row_range) = buffer_row_ranges.peek() {
|
||||
if buffer_row_range.end < new_row_range.start {
|
||||
buffer_row_ranges.next();
|
||||
} else if buffer_row_range.start > new_row_range.end {
|
||||
break;
|
||||
} else {
|
||||
revert = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if revert {
|
||||
let old_range = tracked_buffer
|
||||
.base_text
|
||||
.point_to_offset(Point::new(edit.old.start, 0))
|
||||
..tracked_buffer.base_text.point_to_offset(cmp::min(
|
||||
Point::new(edit.old.end, 0),
|
||||
tracked_buffer.base_text.max_point(),
|
||||
));
|
||||
let old_text = tracked_buffer
|
||||
.base_text
|
||||
.chunks_in_range(old_range)
|
||||
.collect::<String>();
|
||||
edits_to_revert.push((new_range, old_text));
|
||||
}
|
||||
}
|
||||
|
||||
buffer.edit(edits_to_revert, None, cx);
|
||||
@@ -599,6 +613,7 @@ fn point_to_row_edit(edit: Edit<Point>, old_text: &Rope, new_text: &Rope) -> Edi
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
enum ChangeAuthor {
|
||||
User,
|
||||
Agent,
|
||||
@@ -1135,9 +1150,48 @@ mod tests {
|
||||
)]
|
||||
);
|
||||
|
||||
// If the rejected range doesn't overlap with any hunk, we ignore it.
|
||||
action_log
|
||||
.update(cx, |log, cx| {
|
||||
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(1, 0), cx)
|
||||
log.reject_edits_in_ranges(
|
||||
buffer.clone(),
|
||||
vec![Point::new(4, 0)..Point::new(4, 0)],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.text()),
|
||||
"abc\ndE\nXYZf\nghi\njkl\nmnO"
|
||||
);
|
||||
assert_eq!(
|
||||
unreviewed_hunks(&action_log, cx),
|
||||
vec![(
|
||||
buffer.clone(),
|
||||
vec![
|
||||
HunkStatus {
|
||||
range: Point::new(1, 0)..Point::new(3, 0),
|
||||
diff_status: DiffHunkStatusKind::Modified,
|
||||
old_text: "def\n".into(),
|
||||
},
|
||||
HunkStatus {
|
||||
range: Point::new(5, 0)..Point::new(5, 3),
|
||||
diff_status: DiffHunkStatusKind::Modified,
|
||||
old_text: "mno".into(),
|
||||
}
|
||||
],
|
||||
)]
|
||||
);
|
||||
|
||||
action_log
|
||||
.update(cx, |log, cx| {
|
||||
log.reject_edits_in_ranges(
|
||||
buffer.clone(),
|
||||
vec![Point::new(0, 0)..Point::new(1, 0)],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1160,7 +1214,11 @@ mod tests {
|
||||
|
||||
action_log
|
||||
.update(cx, |log, cx| {
|
||||
log.reject_edits_in_range(buffer.clone(), Point::new(4, 0)..Point::new(4, 0), cx)
|
||||
log.reject_edits_in_ranges(
|
||||
buffer.clone(),
|
||||
vec![Point::new(4, 0)..Point::new(4, 0)],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1172,6 +1230,82 @@ mod tests {
|
||||
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_reject_multiple_edits(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
|
||||
.unwrap();
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
|
||||
.unwrap()
|
||||
});
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(5, 2)..Point::new(5, 3), "O")], None, cx)
|
||||
.unwrap()
|
||||
});
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
});
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.text()),
|
||||
"abc\ndE\nXYZf\nghi\njkl\nmnO"
|
||||
);
|
||||
assert_eq!(
|
||||
unreviewed_hunks(&action_log, cx),
|
||||
vec![(
|
||||
buffer.clone(),
|
||||
vec![
|
||||
HunkStatus {
|
||||
range: Point::new(1, 0)..Point::new(3, 0),
|
||||
diff_status: DiffHunkStatusKind::Modified,
|
||||
old_text: "def\n".into(),
|
||||
},
|
||||
HunkStatus {
|
||||
range: Point::new(5, 0)..Point::new(5, 3),
|
||||
diff_status: DiffHunkStatusKind::Modified,
|
||||
old_text: "mno".into(),
|
||||
}
|
||||
],
|
||||
)]
|
||||
);
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
let range_1 = buffer.read(cx).anchor_before(Point::new(0, 0))
|
||||
..buffer.read(cx).anchor_before(Point::new(1, 0));
|
||||
let range_2 = buffer.read(cx).anchor_before(Point::new(5, 0))
|
||||
..buffer.read(cx).anchor_before(Point::new(5, 3));
|
||||
|
||||
log.reject_edits_in_ranges(buffer.clone(), vec![range_1, range_2], cx)
|
||||
.detach();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.text()),
|
||||
"abc\ndef\nghi\njkl\nmno"
|
||||
);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.text()),
|
||||
"abc\ndef\nghi\njkl\nmno"
|
||||
);
|
||||
assert_eq!(unreviewed_hunks(&action_log, cx), vec![]);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 10)]
|
||||
async fn test_reject_deleted_file(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
@@ -1215,7 +1349,11 @@ mod tests {
|
||||
|
||||
action_log
|
||||
.update(cx, |log, cx| {
|
||||
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 0), cx)
|
||||
log.reject_edits_in_ranges(
|
||||
buffer.clone(),
|
||||
vec![Point::new(0, 0)..Point::new(0, 0)],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1266,7 +1404,11 @@ mod tests {
|
||||
|
||||
action_log
|
||||
.update(cx, |log, cx| {
|
||||
log.reject_edits_in_range(buffer.clone(), Point::new(0, 0)..Point::new(0, 11), cx)
|
||||
log.reject_edits_in_ranges(
|
||||
buffer.clone(),
|
||||
vec![Point::new(0, 0)..Point::new(0, 11)],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1312,7 +1454,7 @@ mod tests {
|
||||
.update(cx, |log, cx| {
|
||||
let range = buffer.read(cx).random_byte_range(0, &mut rng);
|
||||
log::info!("rejecting edits in range {:?}", range);
|
||||
log.reject_edits_in_range(buffer.clone(), range, cx)
|
||||
log.reject_edits_in_ranges(buffer.clone(), vec![range], cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -9,6 +9,11 @@ use std::fmt::Formatter;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::AnyElement;
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::Context;
|
||||
use gpui::IntoElement;
|
||||
use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use icons::IconName;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
@@ -24,16 +29,87 @@ pub fn init(cx: &mut App) {
|
||||
ToolRegistry::default_global(cx);
|
||||
}
|
||||
|
||||
/// The result of running a tool
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolUseStatus {
|
||||
NeedsConfirmation,
|
||||
Pending,
|
||||
Running,
|
||||
Finished(SharedString),
|
||||
Error(SharedString),
|
||||
}
|
||||
|
||||
impl ToolUseStatus {
|
||||
pub fn text(&self) -> SharedString {
|
||||
match self {
|
||||
ToolUseStatus::NeedsConfirmation => "".into(),
|
||||
ToolUseStatus::Pending => "".into(),
|
||||
ToolUseStatus::Running => "".into(),
|
||||
ToolUseStatus::Finished(out) => out.clone(),
|
||||
ToolUseStatus::Error(out) => out.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The result of running a tool, containing both the asynchronous output
|
||||
/// and an optional card view that can be rendered immediately.
|
||||
pub struct ToolResult {
|
||||
/// The asynchronous task that will eventually resolve to the tool's output
|
||||
pub output: Task<Result<String>>,
|
||||
/// An optional view to present the output of the tool.
|
||||
pub card: Option<AnyToolCard>,
|
||||
}
|
||||
|
||||
pub trait ToolCard: 'static + Sized {
|
||||
fn render(
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AnyToolCard {
|
||||
entity: gpui::AnyEntity,
|
||||
render: fn(
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyElement,
|
||||
}
|
||||
|
||||
impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
|
||||
fn from(entity: Entity<T>) -> Self {
|
||||
fn downcast_render<T: ToolCard>(
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
let entity = entity.downcast::<T>().unwrap();
|
||||
entity.update(cx, |entity, cx| {
|
||||
entity.render(status, window, cx).into_any_element()
|
||||
})
|
||||
}
|
||||
|
||||
Self {
|
||||
entity: entity.into(),
|
||||
render: downcast_render::<T>,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnyToolCard {
|
||||
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
|
||||
(self.render)(self.entity.clone(), status, window, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Task<Result<String>>> for ToolResult {
|
||||
/// Convert from a task to a ToolResult
|
||||
/// Convert from a task to a ToolResult with no card
|
||||
fn from(output: Task<Result<String>>) -> Self {
|
||||
Self { output }
|
||||
Self { output, card: None }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,6 +156,7 @@ pub trait Tool: 'static + Send + Sync {
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult;
|
||||
}
|
||||
|
||||
@@ -14,8 +14,11 @@ path = "src/assistant_tools.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
@@ -23,6 +26,8 @@ http_client.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
open = { workspace = true }
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
schemars.workspace = true
|
||||
@@ -30,9 +35,11 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
worktree.workspace = true
|
||||
open = { workspace = true }
|
||||
web_search.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -22,14 +22,17 @@ mod schema;
|
||||
mod symbol_info_tool;
|
||||
mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod web_search_tool;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_tool::ToolRegistry;
|
||||
use copy_path_tool::CopyPathTool;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use gpui::App;
|
||||
use http_client::HttpClientWithUrl;
|
||||
use move_path_tool::MovePathTool;
|
||||
use web_search_tool::WebSearchTool;
|
||||
|
||||
use crate::batch_tool::BatchTool;
|
||||
use crate::code_action_tool::CodeActionTool;
|
||||
@@ -56,28 +59,39 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
assistant_tool::init(cx);
|
||||
|
||||
let registry = ToolRegistry::global(cx);
|
||||
registry.register_tool(TerminalTool);
|
||||
registry.register_tool(BatchTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(DeletePathTool);
|
||||
registry.register_tool(FindReplaceFileTool);
|
||||
registry.register_tool(SymbolInfoTool);
|
||||
registry.register_tool(CodeActionTool);
|
||||
registry.register_tool(MovePathTool);
|
||||
registry.register_tool(DiagnosticsTool);
|
||||
registry.register_tool(ListDirectoryTool);
|
||||
registry.register_tool(NowTool);
|
||||
registry.register_tool(OpenTool);
|
||||
registry.register_tool(CodeSymbolsTool);
|
||||
registry.register_tool(ContentsTool);
|
||||
registry.register_tool(CopyPathTool);
|
||||
registry.register_tool(CreateDirectoryTool);
|
||||
registry.register_tool(CreateFileTool);
|
||||
registry.register_tool(DeletePathTool);
|
||||
registry.register_tool(DiagnosticsTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
registry.register_tool(FindReplaceFileTool);
|
||||
registry.register_tool(ListDirectoryTool);
|
||||
registry.register_tool(MovePathTool);
|
||||
registry.register_tool(NowTool);
|
||||
registry.register_tool(OpenTool);
|
||||
registry.register_tool(PathSearchTool);
|
||||
registry.register_tool(ReadFileTool);
|
||||
registry.register_tool(RegexSearchTool);
|
||||
registry.register_tool(RenameTool);
|
||||
registry.register_tool(SymbolInfoTool);
|
||||
registry.register_tool(TerminalTool);
|
||||
registry.register_tool(ThinkingTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
|
||||
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
|
||||
use futures::future::join_all;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -218,6 +218,7 @@ impl Tool for BatchTool {
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<BatchToolInput>(input) {
|
||||
@@ -258,7 +259,16 @@ impl Tool for BatchTool {
|
||||
let action_log = action_log.clone();
|
||||
let messages = messages.clone();
|
||||
let tool_result = cx
|
||||
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
|
||||
.update(|cx| {
|
||||
tool.run(
|
||||
invocation.input,
|
||||
&messages,
|
||||
project,
|
||||
action_log,
|
||||
window.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
|
||||
|
||||
tasks.push(tool_result.output);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{self, Anchor, Buffer, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{self, LspAction, Project};
|
||||
@@ -140,6 +140,7 @@ impl Tool for CodeActionTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use collections::IndexMap;
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
|
||||
use language::{OutlineItem, ParseStatus, Point};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, Symbol};
|
||||
@@ -128,6 +128,7 @@ impl Tool for CodeSymbolsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -102,6 +102,7 @@ impl Tool for ContentsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ContentsToolInput>(input) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -76,6 +77,7 @@ impl Tool for CopyPathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -67,6 +68,7 @@ impl Tool for CreateDirectoryTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -72,6 +73,7 @@ impl Tool for CreateFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CreateFileToolInput>(input) {
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::{SinkExt, StreamExt, channel::mpsc};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, ProjectPath};
|
||||
use schemars::JsonSchema;
|
||||
@@ -62,6 +62,7 @@ impl Tool for DeletePathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -82,6 +82,7 @@ impl Tool for DiagnosticsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
match serde_json::from_value::<DiagnosticsToolInput>(input)
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
|
||||
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
@@ -145,6 +145,7 @@ impl Tool for FetchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<FetchToolInput>(input) {
|
||||
|
||||
@@ -1,13 +1,25 @@
|
||||
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{Editor, MultiBuffer, PathKey};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, IntoElement, Task, Window,
|
||||
};
|
||||
use language::{
|
||||
self, Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt as _, Rope,
|
||||
TextBuffer,
|
||||
};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use ui::IconName;
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use ui::{Tooltip, prelude::*};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::replace::replace_exact;
|
||||
|
||||
@@ -132,6 +144,274 @@ pub struct FindReplaceFileToolInput {
|
||||
pub replace: String,
|
||||
}
|
||||
|
||||
pub struct FindReplaceFileToolCard {
|
||||
path: PathBuf,
|
||||
description: String,
|
||||
editor: Entity<Editor>,
|
||||
multibuffer: Entity<MultiBuffer>,
|
||||
project: Entity<Project>,
|
||||
diff_task: Option<Task<Result<()>>>,
|
||||
}
|
||||
|
||||
impl FindReplaceFileToolCard {
|
||||
fn new(
|
||||
path: PathBuf,
|
||||
description: String,
|
||||
project: Entity<Project>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadOnly));
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor =
|
||||
Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor
|
||||
});
|
||||
|
||||
Self {
|
||||
path,
|
||||
description,
|
||||
project,
|
||||
editor,
|
||||
multibuffer,
|
||||
diff_task: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn set_diff(
|
||||
&mut self,
|
||||
path: Arc<Path>,
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
self.diff_task = Some(cx.spawn(async move |this, cx| {
|
||||
let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
|
||||
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.multibuffer.update(cx, |multibuffer, cx| {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let diff = buffer_diff.read(cx);
|
||||
let diff_hunk_ranges = diff
|
||||
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
|
||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
|
||||
.collect::<Vec<_>>();
|
||||
let _is_newly_added = multibuffer.set_excerpts_for_path(
|
||||
PathKey::for_buffer(&buffer, cx),
|
||||
buffer,
|
||||
diff_hunk_ranges,
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
multibuffer.add_diff(buffer_diff, cx);
|
||||
});
|
||||
cx.notify();
|
||||
})
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_buffer(
|
||||
mut text: String,
|
||||
path: Arc<Path>,
|
||||
language_registry: &Arc<language::LanguageRegistry>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Buffer>> {
|
||||
let line_ending = LineEnding::detect(&text);
|
||||
LineEnding::normalize(&mut text);
|
||||
let text = Rope::from(text);
|
||||
let language = cx
|
||||
.update(|_cx| language_registry.language_for_file_path(&path))?
|
||||
.await
|
||||
.ok();
|
||||
let buffer = cx.new(|cx| {
|
||||
let buffer = TextBuffer::new_normalized(
|
||||
0,
|
||||
cx.entity_id().as_non_zero_u64().into(),
|
||||
line_ending,
|
||||
text,
|
||||
);
|
||||
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
|
||||
buffer.set_language(language, cx);
|
||||
buffer
|
||||
})?;
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
async fn build_buffer_diff(
|
||||
mut old_text: String,
|
||||
buffer: &Entity<Buffer>,
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<BufferDiff>> {
|
||||
LineEnding::normalize(&mut old_text);
|
||||
|
||||
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
|
||||
|
||||
let base_buffer = cx
|
||||
.update(|cx| {
|
||||
Buffer::build_snapshot(
|
||||
old_text.clone().into(),
|
||||
buffer.language().cloned(),
|
||||
Some(language_registry.clone()),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let diff_snapshot = cx
|
||||
.update(|cx| {
|
||||
BufferDiffSnapshot::new_with_base_buffer(
|
||||
buffer.text.clone(),
|
||||
Some(old_text.into()),
|
||||
base_buffer,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut diff = BufferDiff::new(&buffer.text, cx);
|
||||
diff.set_snapshot(diff_snapshot, &buffer.text, cx);
|
||||
diff
|
||||
})
|
||||
}
|
||||
|
||||
impl ToolCard for FindReplaceFileToolCard {
|
||||
fn render(
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let header = h_flex()
|
||||
.id("tool-label-container")
|
||||
.gap_1p5()
|
||||
.max_w_full()
|
||||
.overflow_x_scroll()
|
||||
.child(
|
||||
Icon::new(IconName::Pencil)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Label::new("Edit ").size(LabelSize::Small))
|
||||
.child(
|
||||
div()
|
||||
.size(px(3.))
|
||||
.rounded_full()
|
||||
.bg(cx.theme().colors().text),
|
||||
)
|
||||
.child(Label::new(self.path.display().to_string()).size(LabelSize::Small))
|
||||
.into_any_element();
|
||||
|
||||
let header2 = h_flex()
|
||||
.id("code-block-header-label")
|
||||
.w_full()
|
||||
.max_w_full()
|
||||
.px_1()
|
||||
.gap_0p5()
|
||||
.cursor_pointer()
|
||||
.rounded_sm()
|
||||
.hover(|item| item.bg(cx.theme().colors().element_hover.opacity(0.5)))
|
||||
.tooltip(Tooltip::text("Jump to File"));
|
||||
// todo!
|
||||
// .child(
|
||||
// h_flex()
|
||||
// .gap_0p5()
|
||||
// .children(
|
||||
// file_icons::FileIcons::get_icon(&path_range.path, cx)
|
||||
// .map(Icon::from_path)
|
||||
// .map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
|
||||
// )
|
||||
// .child(content)
|
||||
// .child(
|
||||
// Icon::new(IconName::ArrowUpRight)
|
||||
// .size(IconSize::XSmall)
|
||||
// .color(Color::Ignored),
|
||||
// ),
|
||||
// )
|
||||
// .on_click({
|
||||
// let path_range = path_range.clone();
|
||||
// move |_, window, cx| {
|
||||
// workspace
|
||||
// .update(cx, {
|
||||
// |workspace, cx| {
|
||||
// if let Some(project_path) = workspace
|
||||
// .project()
|
||||
// .read(cx)
|
||||
// .find_project_path(&path_range.path, cx)
|
||||
// {
|
||||
// let target = path_range.range.as_ref().map(|range| {
|
||||
// Point::new(
|
||||
// // Line number is 1-based
|
||||
// range.start.line.saturating_sub(1),
|
||||
// range.start.col.unwrap_or(0),
|
||||
// )
|
||||
// });
|
||||
// let open_task =
|
||||
// workspace.open_path(project_path, None, true, window, cx);
|
||||
// window
|
||||
// .spawn(cx, async move |cx| {
|
||||
// let item = open_task.await?;
|
||||
// if let Some(target) = target {
|
||||
// if let Some(active_editor) =
|
||||
// item.downcast::<Editor>()
|
||||
// {
|
||||
// active_editor
|
||||
// .downgrade()
|
||||
// .update_in(cx, |editor, window, cx| {
|
||||
// editor.go_to_singleton_buffer_point(
|
||||
// target, window, cx,
|
||||
// );
|
||||
// })
|
||||
// .log_err();
|
||||
// }
|
||||
// }
|
||||
// anyhow::Ok(())
|
||||
// })
|
||||
// .detach_and_log_err(cx);
|
||||
// }
|
||||
// }
|
||||
// })
|
||||
// .ok();
|
||||
// }
|
||||
// })
|
||||
// .into_any_element();
|
||||
|
||||
let content = match status {
|
||||
ToolUseStatus::NeedsConfirmation | ToolUseStatus::Pending | ToolUseStatus::Running => {
|
||||
div()
|
||||
// .child(Label::new(&self.description).size(LabelSize::Small))
|
||||
.into_any_element()
|
||||
}
|
||||
ToolUseStatus::Finished(str) => {
|
||||
dbg!(&str);
|
||||
self.editor.clone().into_any_element()
|
||||
}
|
||||
ToolUseStatus::Error(error) => div()
|
||||
.child(
|
||||
Label::new(error.to_string())
|
||||
.color(Color::Error)
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.into_any_element(),
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.my_2()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.rounded_sm()
|
||||
.gap_1()
|
||||
.child(header)
|
||||
.child(content)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FindReplaceFileTool;
|
||||
|
||||
impl Tool for FindReplaceFileTool {
|
||||
@@ -168,14 +448,32 @@ impl Tool for FindReplaceFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<FindReplaceFileToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let card = window.and_then(|window| {
|
||||
window
|
||||
.update(cx, |_, window, cx| {
|
||||
cx.new(|cx| {
|
||||
FindReplaceFileToolCard::new(
|
||||
input.path.clone(),
|
||||
input.display_description.clone(),
|
||||
project.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
.ok()
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
let output = cx.spawn({
|
||||
let card = card.clone();
|
||||
async move |cx: &mut AsyncApp| {
|
||||
let project_path = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.find_project_path(&input.path, cx)
|
||||
@@ -183,7 +481,7 @@ impl Tool for FindReplaceFileTool {
|
||||
})??;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.update(cx, |project, cx| project.open_buffer(project_path.clone(), cx))?
|
||||
.await?;
|
||||
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
@@ -255,14 +553,29 @@ impl Tool for FindReplaceFileTool {
|
||||
project.save_buffer(buffer, cx)
|
||||
})?.await?;
|
||||
|
||||
let diff_str = cx.background_spawn(async move {
|
||||
let new_text = snapshot.text();
|
||||
language::unified_diff(&old_text, &new_text)
|
||||
let new_text = snapshot.text();
|
||||
|
||||
let diff_str = cx.background_spawn({
|
||||
// todo! probably don't need this
|
||||
let old_text = old_text.clone();
|
||||
let new_text = new_text.clone();
|
||||
async move {
|
||||
language::unified_diff(&old_text, &new_text)
|
||||
}
|
||||
}).await;
|
||||
|
||||
if let Some(card) = card {
|
||||
card.update(cx, |card, cx| {
|
||||
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
|
||||
}).log_err();
|
||||
}
|
||||
|
||||
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
|
||||
}});
|
||||
|
||||
}).into()
|
||||
ToolResult {
|
||||
output,
|
||||
card: card.map(|card| card.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -76,6 +76,7 @@ impl Tool for ListDirectoryTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -89,6 +89,7 @@ impl Tool for MovePathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<MovePathToolInput>(input) {
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use chrono::{Local, Utc};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -59,6 +59,7 @@ impl Tool for NowTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: NowToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -52,6 +52,7 @@ impl Tool for OpenTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: OpenToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -70,6 +70,7 @@ impl Tool for PathSearchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -87,6 +87,7 @@ impl Tool for ReadFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::StreamExt;
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::OffsetRangeExt;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{
|
||||
@@ -91,6 +91,7 @@ impl Tool for RegexSearchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
const CONTEXT_LINES: u32 = 2;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{self, Buffer, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -87,6 +87,7 @@ impl Tool for RenameTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<RenameToolInput>(input) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
|
||||
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -121,6 +121,7 @@ impl Tool for SymbolInfoTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::io::BufReader;
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -78,6 +78,7 @@ impl Tool for TerminalTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: TerminalToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -50,6 +50,7 @@ impl Tool for ThinkingTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
// This tool just "thinks out loud" and doesn't perform any actions.
|
||||
|
||||
214
crates/assistant_tools/src/web_search_tool.rs
Normal file
214
crates/assistant_tools/src/web_search_tool.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task,
|
||||
Window, pulsating_between,
|
||||
};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::{IconName, Tooltip, prelude::*};
|
||||
use web_search::WebSearchRegistry;
|
||||
use zed_llm_client::WebSearchResponse;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct WebSearchToolInput {
|
||||
/// The search term or question to query on the web.
|
||||
query: String,
|
||||
}
|
||||
|
||||
pub struct WebSearchTool;
|
||||
|
||||
impl Tool for WebSearchTool {
|
||||
fn name(&self) -> String {
|
||||
"web_search".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
"Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Globe
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<WebSearchToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, _input: &serde_json::Value) -> String {
|
||||
"Web Search".to_string()
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
|
||||
Ok(input) => input,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
|
||||
return Task::ready(Err(anyhow!("Web search is not available."))).into();
|
||||
};
|
||||
|
||||
let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
|
||||
let output = cx.background_spawn({
|
||||
let search_task = search_task.clone();
|
||||
async move {
|
||||
let response = search_task.await.map_err(|err| anyhow!(err))?;
|
||||
serde_json::to_string(&response).context("Failed to serialize search results")
|
||||
}
|
||||
});
|
||||
|
||||
ToolResult {
|
||||
output,
|
||||
card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct WebSearchToolCard {
|
||||
response: Option<Result<WebSearchResponse>>,
|
||||
_task: Task<()>,
|
||||
}
|
||||
|
||||
impl WebSearchToolCard {
|
||||
fn new(
|
||||
search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let _task = cx.spawn(async move |this, cx| {
|
||||
let response = search_task.await.map_err(|err| anyhow!(err));
|
||||
this.update(cx, |this, cx| {
|
||||
this.response = Some(response);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
|
||||
Self {
|
||||
response: None,
|
||||
_task,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCard for WebSearchToolCard {
|
||||
fn render(
|
||||
&mut self,
|
||||
_status: &ToolUseStatus,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let header = h_flex()
|
||||
.id("tool-label-container")
|
||||
.gap_1p5()
|
||||
.max_w_full()
|
||||
.overflow_x_scroll()
|
||||
.child(
|
||||
Icon::new(IconName::Globe)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(match self.response.as_ref() {
|
||||
Some(Ok(response)) => {
|
||||
let text: SharedString = if response.citations.len() == 1 {
|
||||
"1 result".into()
|
||||
} else {
|
||||
format!("{} results", response.citations.len()).into()
|
||||
};
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(Label::new("Searched the Web").size(LabelSize::Small))
|
||||
.child(
|
||||
div()
|
||||
.size(px(3.))
|
||||
.rounded_full()
|
||||
.bg(cx.theme().colors().text),
|
||||
)
|
||||
.child(Label::new(text).size(LabelSize::Small))
|
||||
.into_any_element()
|
||||
}
|
||||
Some(Err(error)) => div()
|
||||
.id("web-search-error")
|
||||
.child(Label::new("Web Search failed").size(LabelSize::Small))
|
||||
.tooltip(Tooltip::text(error.to_string()))
|
||||
.into_any_element(),
|
||||
|
||||
None => Label::new("Searching the Web…")
|
||||
.size(LabelSize::Small)
|
||||
.with_animation(
|
||||
"web-search-label",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.6, 1.)),
|
||||
|label, delta| label.alpha(delta),
|
||||
)
|
||||
.into_any_element(),
|
||||
})
|
||||
.into_any();
|
||||
|
||||
let content =
|
||||
self.response.as_ref().and_then(|response| match response {
|
||||
Ok(response) => {
|
||||
Some(
|
||||
v_flex()
|
||||
.ml_1p5()
|
||||
.pl_1p5()
|
||||
.border_l_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.gap_1()
|
||||
.children(response.citations.iter().enumerate().map(
|
||||
|(index, citation)| {
|
||||
let title = citation.title.clone();
|
||||
let url = citation.url.clone();
|
||||
|
||||
Button::new(("citation", index), title)
|
||||
.label_size(LabelSize::Small)
|
||||
.color(Color::Muted)
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_position(IconPosition::End)
|
||||
.truncate(true)
|
||||
.tooltip({
|
||||
let url = url.clone();
|
||||
move |window, cx| {
|
||||
Tooltip::with_meta(
|
||||
"Citation Link",
|
||||
None,
|
||||
url.clone(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
})
|
||||
.on_click({
|
||||
let url = url.clone();
|
||||
move |_, _, cx| cx.open_url(&url)
|
||||
})
|
||||
},
|
||||
))
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
Err(_) => None,
|
||||
});
|
||||
|
||||
v_flex().my_2().gap_1().child(header).children(content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
create table subscription_usages (
|
||||
id serial primary key,
|
||||
user_id integer not null,
|
||||
period_start_at timestamp without time zone not null,
|
||||
period_end_at timestamp without time zone not null,
|
||||
model_requests int not null default 0,
|
||||
edit_predictions int not null default 0
|
||||
);
|
||||
|
||||
create unique index uix_subscription_usages_on_user_id_start_at_end_at on subscription_usages (user_id, period_start_at, period_end_at);
|
||||
@@ -15,10 +15,12 @@ use stripe::{
|
||||
BillingPortalSession, CancellationDetailsReason, CreateBillingPortalSession,
|
||||
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use util::{ResultExt, maybe};
|
||||
|
||||
use crate::api::events::SnowflakeRow;
|
||||
use crate::db::billing_subscription::{
|
||||
@@ -52,6 +54,7 @@ pub fn router() -> Router {
|
||||
post(manage_billing_subscription),
|
||||
)
|
||||
.route("/billing/monthly_spend", get(get_monthly_spend))
|
||||
.route("/billing/usage", get(get_current_usage))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -159,6 +162,7 @@ struct BillingSubscriptionJson {
|
||||
id: BillingSubscriptionId,
|
||||
name: String,
|
||||
status: StripeSubscriptionStatus,
|
||||
trial_end_at: Option<String>,
|
||||
cancel_at: Option<String>,
|
||||
/// Whether this subscription can be canceled.
|
||||
is_cancelable: bool,
|
||||
@@ -188,9 +192,21 @@ async fn list_billing_subscriptions(
|
||||
id: subscription.id,
|
||||
name: match subscription.kind {
|
||||
Some(SubscriptionKind::ZedPro) => "Zed Pro".to_string(),
|
||||
Some(SubscriptionKind::ZedProTrial) => "Zed Pro (Trial)".to_string(),
|
||||
Some(SubscriptionKind::ZedFree) => "Zed Free".to_string(),
|
||||
None => "Zed LLM Usage".to_string(),
|
||||
},
|
||||
status: subscription.stripe_subscription_status,
|
||||
trial_end_at: if subscription.kind == Some(SubscriptionKind::ZedProTrial) {
|
||||
maybe!({
|
||||
let end_at = subscription.stripe_current_period_end?;
|
||||
let end_at = DateTime::from_timestamp(end_at, 0)?;
|
||||
|
||||
Some(end_at.to_rfc3339_opts(SecondsFormat::Millis, true))
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
cancel_at: subscription.stripe_cancel_at.map(|cancel_at| {
|
||||
cancel_at
|
||||
.and_utc()
|
||||
@@ -207,6 +223,7 @@ async fn list_billing_subscriptions(
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ProductCode {
|
||||
ZedPro,
|
||||
ZedProTrial,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -286,24 +303,36 @@ async fn create_billing_subscription(
|
||||
customer.id
|
||||
};
|
||||
|
||||
let success_url = format!(
|
||||
"{}/account?checkout_complete=1",
|
||||
app.config.zed_dot_dev_url()
|
||||
);
|
||||
|
||||
let checkout_session_url = match body.product {
|
||||
Some(ProductCode::ZedPro) => {
|
||||
let success_url = format!(
|
||||
"{}/account?checkout_complete=1",
|
||||
app.config.zed_dot_dev_url()
|
||||
);
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
|
||||
.checkout_with_price(
|
||||
app.config.zed_pro_price_id()?,
|
||||
customer_id,
|
||||
&user.github_login,
|
||||
&success_url,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
Some(ProductCode::ZedProTrial) => {
|
||||
stripe_billing
|
||||
.checkout_with_price(
|
||||
app.config.zed_pro_trial_price_id()?,
|
||||
customer_id,
|
||||
&user.github_login,
|
||||
&success_url,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
let default_model =
|
||||
llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
|
||||
let stripe_model = stripe_billing.register_model(default_model).await?;
|
||||
let success_url = format!(
|
||||
"{}/account?checkout_complete=1",
|
||||
app.config.zed_dot_dev_url()
|
||||
);
|
||||
stripe_billing
|
||||
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||
.await?
|
||||
@@ -322,6 +351,8 @@ enum ManageSubscriptionIntent {
|
||||
///
|
||||
/// This will open the Stripe billing portal without putting the user in a specific flow.
|
||||
ManageSubscription,
|
||||
/// The user intends to upgrade to Zed Pro.
|
||||
UpgradeToPro,
|
||||
/// The user intends to cancel their subscription.
|
||||
Cancel,
|
||||
/// The user intends to stop the cancellation of their subscription.
|
||||
@@ -373,11 +404,10 @@ async fn manage_billing_subscription(
|
||||
.get_billing_subscription_by_id(body.subscription_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("subscription not found"))?;
|
||||
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
|
||||
.context("failed to parse subscription ID")?;
|
||||
|
||||
if body.intent == ManageSubscriptionIntent::StopCancellation {
|
||||
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
|
||||
.context("failed to parse subscription ID")?;
|
||||
|
||||
let updated_stripe_subscription = Subscription::update(
|
||||
&stripe_client,
|
||||
&subscription_id,
|
||||
@@ -410,6 +440,47 @@ async fn manage_billing_subscription(
|
||||
|
||||
let flow = match body.intent {
|
||||
ManageSubscriptionIntent::ManageSubscription => None,
|
||||
ManageSubscriptionIntent::UpgradeToPro => {
|
||||
let zed_pro_price_id = app.config.zed_pro_price_id()?;
|
||||
let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id()?;
|
||||
let zed_free_price_id = app.config.zed_free_price_id()?;
|
||||
|
||||
let stripe_subscription =
|
||||
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
|
||||
|
||||
let subscription_item_to_update = stripe_subscription
|
||||
.items
|
||||
.data
|
||||
.iter()
|
||||
.find_map(|item| {
|
||||
let price = item.price.as_ref()?;
|
||||
|
||||
if price.id == zed_free_price_id || price.id == zed_pro_trial_price_id {
|
||||
Some(item.id.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| anyhow!("No subscription item to update"))?;
|
||||
|
||||
Some(CreateBillingPortalSessionFlowData {
|
||||
type_: CreateBillingPortalSessionFlowDataType::SubscriptionUpdateConfirm,
|
||||
subscription_update_confirm: Some(
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm {
|
||||
subscription: subscription.stripe_subscription_id,
|
||||
items: vec![
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems {
|
||||
id: subscription_item_to_update.to_string(),
|
||||
price: Some(zed_pro_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
},
|
||||
],
|
||||
discounts: None,
|
||||
},
|
||||
),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
ManageSubscriptionIntent::Cancel => Some(CreateBillingPortalSessionFlowData {
|
||||
type_: CreateBillingPortalSessionFlowDataType::SubscriptionCancel,
|
||||
after_completion: Some(CreateBillingPortalSessionFlowDataAfterCompletion {
|
||||
@@ -696,22 +767,25 @@ async fn handle_customer_subscription_event(
|
||||
|
||||
log::info!("handling Stripe {} event: {}", event.type_, event.id);
|
||||
|
||||
let subscription_kind =
|
||||
if let Some(zed_pro_price_id) = app.config.stripe_zed_pro_price_id.as_deref() {
|
||||
let has_zed_pro_price = subscription.items.data.iter().any(|item| {
|
||||
item.price
|
||||
.as_ref()
|
||||
.map_or(false, |price| price.id.as_str() == zed_pro_price_id)
|
||||
});
|
||||
let subscription_kind = maybe!({
|
||||
let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
|
||||
let zed_pro_trial_price_id = app.config.zed_pro_trial_price_id().ok()?;
|
||||
let zed_free_price_id = app.config.zed_free_price_id().ok()?;
|
||||
|
||||
if has_zed_pro_price {
|
||||
subscription.items.data.iter().find_map(|item| {
|
||||
let price = item.price.as_ref()?;
|
||||
|
||||
if price.id == zed_pro_price_id {
|
||||
Some(SubscriptionKind::ZedPro)
|
||||
} else if price.id == zed_pro_trial_price_id {
|
||||
Some(SubscriptionKind::ZedProTrial)
|
||||
} else if price.id == zed_free_price_id {
|
||||
Some(SubscriptionKind::ZedFree)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
})
|
||||
});
|
||||
|
||||
let billing_customer =
|
||||
find_or_create_billing_customer(app, stripe_client, subscription.customer)
|
||||
@@ -874,6 +948,93 @@ async fn get_monthly_spend(
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GetCurrentUsageParams {
|
||||
github_user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct UsageCounts {
|
||||
pub used: i32,
|
||||
pub limit: Option<i32>,
|
||||
pub remaining: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GetCurrentUsageResponse {
|
||||
pub model_requests: UsageCounts,
|
||||
pub edit_predictions: UsageCounts,
|
||||
}
|
||||
|
||||
async fn get_current_usage(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
Query(params): Query<GetCurrentUsageParams>,
|
||||
) -> Result<Json<GetCurrentUsageResponse>> {
|
||||
let user = app
|
||||
.db
|
||||
.get_user_by_github_user_id(params.github_user_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
return Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"LLM database not available".into(),
|
||||
));
|
||||
};
|
||||
|
||||
let empty_usage = GetCurrentUsageResponse {
|
||||
model_requests: UsageCounts {
|
||||
used: 0,
|
||||
limit: Some(0),
|
||||
remaining: Some(0),
|
||||
},
|
||||
edit_predictions: UsageCounts {
|
||||
used: 0,
|
||||
limit: Some(0),
|
||||
remaining: Some(0),
|
||||
},
|
||||
};
|
||||
|
||||
let Some(subscription) = app.db.get_active_billing_subscription(user.id).await? else {
|
||||
return Ok(Json(empty_usage));
|
||||
};
|
||||
|
||||
let subscription_period = maybe!({
|
||||
let period_start_at = subscription.current_period_start_at()?;
|
||||
let period_end_at = subscription.current_period_end_at()?;
|
||||
|
||||
Some((period_start_at, period_end_at))
|
||||
});
|
||||
|
||||
let Some((period_start_at, period_end_at)) = subscription_period else {
|
||||
return Ok(Json(empty_usage));
|
||||
};
|
||||
|
||||
let usage = llm_db
|
||||
.get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
|
||||
.await?;
|
||||
let Some(usage) = usage else {
|
||||
return Ok(Json(empty_usage));
|
||||
};
|
||||
|
||||
let model_requests_limit = Some(500);
|
||||
let edit_prediction_limit = Some(2000);
|
||||
|
||||
Ok(Json(GetCurrentUsageResponse {
|
||||
model_requests: UsageCounts {
|
||||
used: usage.model_requests,
|
||||
limit: model_requests_limit,
|
||||
remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
|
||||
},
|
||||
edit_predictions: UsageCounts {
|
||||
used: usage.edit_predictions,
|
||||
limit: edit_prediction_limit,
|
||||
remaining: edit_prediction_limit.map(|limit| (limit - usage.edit_predictions).max(0)),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
impl From<SubscriptionStatus> for StripeSubscriptionStatus {
|
||||
fn from(value: SubscriptionStatus) -> Self {
|
||||
match value {
|
||||
|
||||
@@ -62,11 +62,14 @@ impl Database {
|
||||
billing_subscription::Entity::update(billing_subscription::ActiveModel {
|
||||
id: ActiveValue::set(id),
|
||||
billing_customer_id: params.billing_customer_id.clone(),
|
||||
kind: params.kind.clone(),
|
||||
stripe_subscription_id: params.stripe_subscription_id.clone(),
|
||||
stripe_subscription_status: params.stripe_subscription_status.clone(),
|
||||
stripe_cancel_at: params.stripe_cancel_at.clone(),
|
||||
stripe_cancellation_reason: params.stripe_cancellation_reason.clone(),
|
||||
..Default::default()
|
||||
stripe_current_period_start: params.stripe_current_period_start.clone(),
|
||||
stripe_current_period_end: params.stripe_current_period_end.clone(),
|
||||
created_at: ActiveValue::not_set(),
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
@@ -105,6 +108,28 @@ impl Database {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_active_billing_subscription(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<billing_subscription::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(billing_subscription::Entity::find()
|
||||
.inner_join(billing_customer::Entity)
|
||||
.filter(billing_customer::Column::UserId.eq(user_id))
|
||||
.filter(
|
||||
Condition::all()
|
||||
.add(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.add(billing_subscription::Column::Kind.is_not_null()),
|
||||
)
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Returns all of the billing subscriptions for the user with the specified ID.
|
||||
///
|
||||
/// Note that this returns the subscriptions regardless of their status.
|
||||
@@ -142,6 +167,7 @@ impl Database {
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
)
|
||||
.filter(billing_subscription::Column::Kind.is_null())
|
||||
.order_by_asc(billing_subscription::Column::Id)
|
||||
.stream(&*tx)
|
||||
.await?;
|
||||
|
||||
@@ -19,6 +19,18 @@ pub struct Model {
|
||||
pub created_at: DateTime,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn current_period_start_at(&self) -> Option<DateTimeUtc> {
|
||||
let period_start = self.stripe_current_period_start?;
|
||||
chrono::DateTime::from_timestamp(period_start, 0)
|
||||
}
|
||||
|
||||
pub fn current_period_end_at(&self) -> Option<DateTimeUtc> {
|
||||
let period_end = self.stripe_current_period_end?;
|
||||
chrono::DateTime::from_timestamp(period_end, 0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
@@ -43,6 +55,10 @@ impl ActiveModelBehavior for ActiveModel {}
|
||||
pub enum SubscriptionKind {
|
||||
#[sea_orm(string_value = "zed_pro")]
|
||||
ZedPro,
|
||||
#[sea_orm(string_value = "zed_pro_trial")]
|
||||
ZedProTrial,
|
||||
#[sea_orm(string_value = "zed_free")]
|
||||
ZedFree,
|
||||
}
|
||||
|
||||
/// The status of a Stripe subscription.
|
||||
|
||||
@@ -183,6 +183,8 @@ pub struct Config {
|
||||
pub auto_join_channel_id: Option<ChannelId>,
|
||||
pub stripe_api_key: Option<String>,
|
||||
pub stripe_zed_pro_price_id: Option<String>,
|
||||
pub stripe_zed_pro_trial_price_id: Option<String>,
|
||||
pub stripe_zed_free_price_id: Option<String>,
|
||||
pub supermaven_admin_api_key: Option<Arc<str>>,
|
||||
pub user_backfiller_github_access_token: Option<Arc<str>>,
|
||||
}
|
||||
@@ -201,6 +203,29 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
|
||||
Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
|
||||
}
|
||||
|
||||
pub fn zed_pro_trial_price_id(&self) -> anyhow::Result<stripe::PriceId> {
|
||||
Self::parse_stripe_price_id(
|
||||
"Zed Pro Trial",
|
||||
self.stripe_zed_pro_trial_price_id.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
|
||||
Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
|
||||
}
|
||||
|
||||
fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
|
||||
use std::str::FromStr as _;
|
||||
|
||||
let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
|
||||
|
||||
Ok(stripe::PriceId::from_str(price_id)?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
Self {
|
||||
@@ -239,6 +264,8 @@ impl Config {
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_zed_pro_price_id: None,
|
||||
stripe_zed_pro_trial_price_id: None,
|
||||
stripe_zed_free_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
@@ -324,12 +351,9 @@ impl AppState {
|
||||
llm_db,
|
||||
livekit_client,
|
||||
blob_store_client: build_blob_store_client(&config).await.log_err(),
|
||||
stripe_billing: stripe_client.clone().map(|stripe_client| {
|
||||
Arc::new(StripeBilling::new(
|
||||
stripe_client,
|
||||
config.stripe_zed_pro_price_id.clone(),
|
||||
))
|
||||
}),
|
||||
stripe_billing: stripe_client
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
|
||||
@@ -2,4 +2,5 @@ use super::*;
|
||||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod subscription_usages;
|
||||
pub mod usages;
|
||||
|
||||
22
crates/collab/src/llm/db/queries/subscription_usages.rs
Normal file
22
crates/collab/src/llm/db/queries/subscription_usages.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use crate::db::UserId;
|
||||
|
||||
use super::*;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn get_subscription_usage_for_period(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
period_start_at: DateTimeUtc,
|
||||
period_end_at: DateTimeUtc,
|
||||
) -> Result<Option<subscription_usage::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
Ok(subscription_usage::Entity::find()
|
||||
.filter(subscription_usage::Column::UserId.eq(user_id))
|
||||
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
|
||||
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
|
||||
.one(&*tx)
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -2,5 +2,6 @@ pub mod billing_event;
|
||||
pub mod model;
|
||||
pub mod monthly_usage;
|
||||
pub mod provider;
|
||||
pub mod subscription_usage;
|
||||
pub mod usage;
|
||||
pub mod usage_measure;
|
||||
|
||||
20
crates/collab/src/llm/db/tables/subscription_usage.rs
Normal file
20
crates/collab/src/llm/db/tables/subscription_usage.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
use crate::db::UserId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
use time::PrimitiveDateTime;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "subscription_usages")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: i32,
|
||||
pub user_id: UserId,
|
||||
pub period_start_at: PrimitiveDateTime,
|
||||
pub period_end_at: PrimitiveDateTime,
|
||||
pub model_requests: i32,
|
||||
pub edit_predictions: i32,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::Cents;
|
||||
use crate::db::user;
|
||||
use crate::db::{billing_subscription, user};
|
||||
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
|
||||
use crate::{Config, db::billing_preference};
|
||||
use anyhow::{Result, anyhow};
|
||||
@@ -8,6 +8,7 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use util::maybe;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
@@ -29,6 +30,8 @@ pub struct LlmTokenClaims {
|
||||
pub max_monthly_spend_in_cents: u32,
|
||||
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
|
||||
pub plan: rpc::proto::Plan,
|
||||
#[serde(default)]
|
||||
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
|
||||
}
|
||||
|
||||
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
|
||||
@@ -39,8 +42,9 @@ impl LlmTokenClaims {
|
||||
is_staff: bool,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
feature_flags: &Vec<String>,
|
||||
has_llm_subscription: bool,
|
||||
has_legacy_llm_subscription: bool,
|
||||
plan: rpc::proto::Plan,
|
||||
subscription: Option<billing_subscription::Model>,
|
||||
system_id: Option<String>,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
@@ -69,7 +73,7 @@ impl LlmTokenClaims {
|
||||
has_predict_edits_feature_flag: feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == "predict-edits"),
|
||||
has_llm_subscription,
|
||||
has_llm_subscription: has_legacy_llm_subscription,
|
||||
max_monthly_spend_in_cents: billing_preferences
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
|
||||
preferences.max_monthly_llm_usage_spending_in_cents as u32
|
||||
@@ -78,6 +82,13 @@ impl LlmTokenClaims {
|
||||
.custom_llm_monthly_allowance_in_cents
|
||||
.map(|allowance| allowance as u32),
|
||||
plan,
|
||||
subscription_period: maybe!({
|
||||
let subscription = subscription?;
|
||||
let period_start_at = subscription.current_period_start_at()?;
|
||||
let period_end_at = subscription.current_period_end_at()?;
|
||||
|
||||
Some((period_start_at.naive_utc(), period_end_at.naive_utc()))
|
||||
}),
|
||||
};
|
||||
|
||||
Ok(jsonwebtoken::encode(
|
||||
|
||||
@@ -4135,7 +4135,8 @@ async fn get_llm_api_token(
|
||||
Err(anyhow!("terms of service not accepted"))?
|
||||
}
|
||||
|
||||
let has_llm_subscription = session.has_llm_subscription(&db).await?;
|
||||
let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
|
||||
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
@@ -4143,8 +4144,9 @@ async fn get_llm_api_token(
|
||||
session.is_staff(),
|
||||
billing_preferences,
|
||||
&flags,
|
||||
has_llm_subscription,
|
||||
has_legacy_llm_subscription,
|
||||
session.current_plan(&db).await?,
|
||||
billing_subscription,
|
||||
session.system_id.clone(),
|
||||
&session.app_state.config,
|
||||
)?;
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{Cents, Result, llm};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use anyhow::Context as _;
|
||||
use chrono::{Datelike, Utc};
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::PriceId;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct StripeBilling {
|
||||
state: RwLock<StripeBillingState>,
|
||||
client: Arc<stripe::Client>,
|
||||
zed_pro_price_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -32,11 +32,10 @@ struct StripeBillingPrice {
|
||||
}
|
||||
|
||||
impl StripeBilling {
|
||||
pub fn new(client: Arc<stripe::Client>, zed_pro_price_id: Option<String>) -> Self {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
client,
|
||||
state: RwLock::default(),
|
||||
zed_pro_price_id,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,23 +384,19 @@ impl StripeBilling {
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_pro(
|
||||
pub async fn checkout_with_price(
|
||||
&self,
|
||||
price_id: PriceId,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let zed_pro_price_id = self
|
||||
.zed_pro_price_id
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("Zed Pro price ID not set"))?;
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(zed_pro_price_id.clone()),
|
||||
price: Some(price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
|
||||
@@ -558,6 +558,8 @@ impl TestServer {
|
||||
seed_path: None,
|
||||
stripe_api_key: None,
|
||||
stripe_zed_pro_price_id: None,
|
||||
stripe_zed_pro_trial_price_id: None,
|
||||
stripe_zed_free_price_id: None,
|
||||
supermaven_admin_api_key: None,
|
||||
user_backfiller_github_access_token: None,
|
||||
kinesis_region: None,
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
|
||||
use gpui::{App, Entity, Task};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use icons::IconName;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -77,6 +77,7 @@ impl Tool for ContextServerTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
|
||||
|
||||
@@ -12,8 +12,9 @@ use dap::{
|
||||
};
|
||||
use futures::{SinkExt as _, channel::mpsc};
|
||||
use gpui::{
|
||||
Action, App, AsyncWindowContext, Context, Entity, EntityId, EventEmitter, FocusHandle,
|
||||
Focusable, Subscription, Task, WeakEntity, actions,
|
||||
Action, App, AsyncWindowContext, Context, DismissEvent, Entity, EntityId, EventEmitter,
|
||||
FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task, WeakEntity,
|
||||
actions, anchored, deferred,
|
||||
};
|
||||
|
||||
use project::{
|
||||
@@ -64,6 +65,7 @@ pub struct DebugPanel {
|
||||
project: WeakEntity<Project>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
focus_handle: FocusHandle,
|
||||
context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -126,6 +128,7 @@ impl DebugPanel {
|
||||
focus_handle: cx.focus_handle(),
|
||||
project: project.downgrade(),
|
||||
workspace: workspace.weak_handle(),
|
||||
context_menu: None,
|
||||
};
|
||||
|
||||
debug_panel
|
||||
@@ -438,7 +441,13 @@ impl DebugPanel {
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
session.update(cx, |this, cx| {
|
||||
if let Some(running) = this.mode().as_running() {
|
||||
running.update(cx, |this, cx| {
|
||||
this.serialize_layout(window, cx);
|
||||
});
|
||||
}
|
||||
});
|
||||
let session_id = session.update(cx, |this, cx| this.session_id(cx));
|
||||
let should_prompt = self
|
||||
.project
|
||||
@@ -567,6 +576,57 @@ impl DebugPanel {
|
||||
)
|
||||
}
|
||||
|
||||
fn deploy_context_menu(
|
||||
&mut self,
|
||||
position: Point<Pixels>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(running_state) = self
|
||||
.active_session
|
||||
.as_ref()
|
||||
.and_then(|session| session.read(cx).mode().as_running().cloned())
|
||||
{
|
||||
let pane_items_status = running_state.read(cx).pane_items_status(cx);
|
||||
let this = cx.weak_entity();
|
||||
|
||||
let context_menu = ContextMenu::build(window, cx, |mut menu, _window, _cx| {
|
||||
for (item_kind, is_visible) in pane_items_status.into_iter() {
|
||||
menu = menu.toggleable_entry(item_kind, is_visible, IconPosition::End, None, {
|
||||
let this = this.clone();
|
||||
move |window, cx| {
|
||||
this.update(cx, |this, cx| {
|
||||
if let Some(running_state) =
|
||||
this.active_session.as_ref().and_then(|session| {
|
||||
session.read(cx).mode().as_running().cloned()
|
||||
})
|
||||
{
|
||||
running_state.update(cx, |state, cx| {
|
||||
if is_visible {
|
||||
state.remove_pane_item(item_kind, window, cx);
|
||||
} else {
|
||||
state.add_pane_item(item_kind, position, window, cx);
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
menu
|
||||
});
|
||||
|
||||
window.focus(&context_menu.focus_handle(cx));
|
||||
let subscription = cx.subscribe(&context_menu, |this, _, _: &DismissEvent, cx| {
|
||||
this.context_menu.take();
|
||||
cx.notify();
|
||||
});
|
||||
self.context_menu = Some((context_menu, position, subscription));
|
||||
}
|
||||
}
|
||||
|
||||
fn top_controls_strip(&self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
|
||||
let active_session = self.active_session.clone();
|
||||
|
||||
@@ -891,11 +951,49 @@ impl Render for DebugPanel {
|
||||
let has_sessions = self.sessions.len() > 0;
|
||||
debug_assert_eq!(has_sessions, self.active_session.is_some());
|
||||
|
||||
if self
|
||||
.active_session
|
||||
.as_ref()
|
||||
.and_then(|session| session.read(cx).mode().as_running().cloned())
|
||||
.map(|state| state.read(cx).has_open_context_menu(cx))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
self.context_menu.take();
|
||||
}
|
||||
|
||||
v_flex()
|
||||
.size_full()
|
||||
.key_context("DebugPanel")
|
||||
.child(h_flex().children(self.top_controls_strip(window, cx)))
|
||||
.track_focus(&self.focus_handle(cx))
|
||||
.when(self.active_session.is_some(), |this| {
|
||||
this.on_mouse_down(
|
||||
MouseButton::Right,
|
||||
cx.listener(|this, event: &MouseDownEvent, window, cx| {
|
||||
if this
|
||||
.active_session
|
||||
.as_ref()
|
||||
.and_then(|session| {
|
||||
session.read(cx).mode().as_running().map(|state| {
|
||||
state.read(cx).has_pane_at_position(event.position)
|
||||
})
|
||||
})
|
||||
.unwrap_or(false)
|
||||
{
|
||||
this.deploy_context_menu(event.position, window, cx);
|
||||
}
|
||||
}),
|
||||
)
|
||||
.children(self.context_menu.as_ref().map(|(menu, position, _)| {
|
||||
deferred(
|
||||
anchored()
|
||||
.position(*position)
|
||||
.anchor(gpui::Corner::TopLeft)
|
||||
.child(menu.clone()),
|
||||
)
|
||||
.with_priority(1)
|
||||
}))
|
||||
})
|
||||
.map(|this| {
|
||||
if has_sessions {
|
||||
this.children(self.active_session.clone())
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use collections::HashMap;
|
||||
use dap::Capabilities;
|
||||
use db::kvp::KEY_VALUE_STORE;
|
||||
use gpui::{Axis, Context, Entity, EntityId, Focusable, Subscription, WeakEntity, Window};
|
||||
use project::Project;
|
||||
@@ -9,19 +10,43 @@ use workspace::{Member, Pane, PaneAxis, Workspace};
|
||||
|
||||
use crate::session::running::{
|
||||
self, RunningState, SubView, breakpoint_list::BreakpointList, console::Console,
|
||||
module_list::ModuleList, stack_frame_list::StackFrameList, variable_list::VariableList,
|
||||
loaded_source_list::LoadedSourceList, module_list::ModuleList,
|
||||
stack_frame_list::StackFrameList, variable_list::VariableList,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[derive(Clone, Hash, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub(crate) enum DebuggerPaneItem {
|
||||
Console,
|
||||
Variables,
|
||||
BreakpointList,
|
||||
Frames,
|
||||
Modules,
|
||||
LoadedSources,
|
||||
}
|
||||
|
||||
impl DebuggerPaneItem {
|
||||
pub(crate) fn all() -> &'static [DebuggerPaneItem] {
|
||||
static VARIANTS: &[DebuggerPaneItem] = &[
|
||||
DebuggerPaneItem::Console,
|
||||
DebuggerPaneItem::Variables,
|
||||
DebuggerPaneItem::BreakpointList,
|
||||
DebuggerPaneItem::Frames,
|
||||
DebuggerPaneItem::Modules,
|
||||
DebuggerPaneItem::LoadedSources,
|
||||
];
|
||||
VARIANTS
|
||||
}
|
||||
|
||||
pub(crate) fn is_supported(&self, capabilities: &Capabilities) -> bool {
|
||||
match self {
|
||||
DebuggerPaneItem::Modules => capabilities.supports_modules_request.unwrap_or_default(),
|
||||
DebuggerPaneItem::LoadedSources => capabilities
|
||||
.supports_loaded_sources_request
|
||||
.unwrap_or_default(),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_shared_string(self) -> SharedString {
|
||||
match self {
|
||||
DebuggerPaneItem::Console => SharedString::new_static("Console"),
|
||||
@@ -29,10 +54,17 @@ impl DebuggerPaneItem {
|
||||
DebuggerPaneItem::BreakpointList => SharedString::new_static("Breakpoints"),
|
||||
DebuggerPaneItem::Frames => SharedString::new_static("Frames"),
|
||||
DebuggerPaneItem::Modules => SharedString::new_static("Modules"),
|
||||
DebuggerPaneItem::LoadedSources => SharedString::new_static("Sources"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DebuggerPaneItem> for SharedString {
|
||||
fn from(item: DebuggerPaneItem) -> Self {
|
||||
item.to_shared_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(crate) struct SerializedAxis(pub Axis);
|
||||
|
||||
@@ -136,6 +168,7 @@ pub(crate) fn deserialize_pane_layout(
|
||||
module_list: &Entity<ModuleList>,
|
||||
console: &Entity<Console>,
|
||||
breakpoint_list: &Entity<BreakpointList>,
|
||||
loaded_sources: &Entity<LoadedSourceList>,
|
||||
subscriptions: &mut HashMap<EntityId, Subscription>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<RunningState>,
|
||||
@@ -157,6 +190,7 @@ pub(crate) fn deserialize_pane_layout(
|
||||
module_list,
|
||||
console,
|
||||
breakpoint_list,
|
||||
loaded_sources,
|
||||
subscriptions,
|
||||
window,
|
||||
cx,
|
||||
@@ -191,7 +225,7 @@ pub(crate) fn deserialize_pane_layout(
|
||||
.iter()
|
||||
.map(|child| match child {
|
||||
DebuggerPaneItem::Frames => Box::new(SubView::new(
|
||||
pane.focus_handle(cx),
|
||||
stack_frame_list.focus_handle(cx),
|
||||
stack_frame_list.clone().into(),
|
||||
DebuggerPaneItem::Frames,
|
||||
None,
|
||||
@@ -212,13 +246,19 @@ pub(crate) fn deserialize_pane_layout(
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Modules => Box::new(SubView::new(
|
||||
pane.focus_handle(cx),
|
||||
module_list.focus_handle(cx),
|
||||
module_list.clone().into(),
|
||||
DebuggerPaneItem::Modules,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
|
||||
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
|
||||
loaded_sources.focus_handle(cx),
|
||||
loaded_sources.clone().into(),
|
||||
DebuggerPaneItem::LoadedSources,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Console => Box::new(SubView::new(
|
||||
pane.focus_handle(cx),
|
||||
console.clone().into(),
|
||||
|
||||
@@ -11,12 +11,12 @@ use crate::persistence::{self, DebuggerPaneItem, SerializedPaneLayout};
|
||||
|
||||
use super::DebugPanelItemEvent;
|
||||
use breakpoint_list::BreakpointList;
|
||||
use collections::HashMap;
|
||||
use collections::{HashMap, IndexMap};
|
||||
use console::Console;
|
||||
use dap::{Capabilities, Thread, client::SessionId, debugger_settings::DebuggerSettings};
|
||||
use gpui::{
|
||||
Action as _, AnyView, AppContext, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
|
||||
NoAction, Subscription, Task, WeakEntity,
|
||||
NoAction, Pixels, Point, Subscription, Task, WeakEntity,
|
||||
};
|
||||
use loaded_source_list::LoadedSourceList;
|
||||
use module_list::ModuleList;
|
||||
@@ -49,8 +49,10 @@ pub struct RunningState {
|
||||
variable_list: Entity<variable_list::VariableList>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
stack_frame_list: Entity<stack_frame_list::StackFrameList>,
|
||||
_module_list: Entity<module_list::ModuleList>,
|
||||
loaded_sources_list: Entity<LoadedSourceList>,
|
||||
module_list: Entity<module_list::ModuleList>,
|
||||
_console: Entity<Console>,
|
||||
breakpoint_list: Entity<BreakpointList>,
|
||||
panes: PaneGroup,
|
||||
pane_close_subscriptions: HashMap<EntityId, Subscription>,
|
||||
_schedule_serialize: Option<Task<()>>,
|
||||
@@ -383,7 +385,6 @@ impl RunningState {
|
||||
|
||||
let module_list = cx.new(|cx| ModuleList::new(session.clone(), workspace.clone(), cx));
|
||||
|
||||
#[expect(unused)]
|
||||
let loaded_source_list = cx.new(|cx| LoadedSourceList::new(session.clone(), cx));
|
||||
|
||||
let console = cx.new(|cx| {
|
||||
@@ -396,7 +397,7 @@ impl RunningState {
|
||||
)
|
||||
});
|
||||
|
||||
let breakpoints = BreakpointList::new(session.clone(), workspace.clone(), &project, cx);
|
||||
let breakpoint_list = BreakpointList::new(session.clone(), workspace.clone(), &project, cx);
|
||||
|
||||
let _subscriptions = vec![
|
||||
cx.observe(&module_list, |_, _, cx| cx.notify()),
|
||||
@@ -421,6 +422,9 @@ impl RunningState {
|
||||
}
|
||||
cx.notify()
|
||||
}),
|
||||
cx.on_focus_out(&focus_handle, window, |this, _, window, cx| {
|
||||
this.serialize_layout(window, cx);
|
||||
}),
|
||||
];
|
||||
|
||||
let mut pane_close_subscriptions = HashMap::default();
|
||||
@@ -433,7 +437,8 @@ impl RunningState {
|
||||
&variable_list,
|
||||
&module_list,
|
||||
&console,
|
||||
&breakpoints,
|
||||
&breakpoint_list,
|
||||
&loaded_source_list,
|
||||
&mut pane_close_subscriptions,
|
||||
window,
|
||||
cx,
|
||||
@@ -449,7 +454,7 @@ impl RunningState {
|
||||
&variable_list,
|
||||
&module_list,
|
||||
&console,
|
||||
breakpoints,
|
||||
&breakpoint_list,
|
||||
&mut pane_close_subscriptions,
|
||||
window,
|
||||
cx,
|
||||
@@ -469,14 +474,140 @@ impl RunningState {
|
||||
stack_frame_list,
|
||||
session_id,
|
||||
panes,
|
||||
_module_list: module_list,
|
||||
module_list,
|
||||
_console: console,
|
||||
breakpoint_list,
|
||||
loaded_sources_list: loaded_source_list,
|
||||
pane_close_subscriptions,
|
||||
_schedule_serialize: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_layout(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
pub(crate) fn remove_pane_item(
|
||||
&mut self,
|
||||
item_kind: DebuggerPaneItem,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
debug_assert!(
|
||||
item_kind.is_supported(self.session.read(cx).capabilities()),
|
||||
"We should only allow removing supported item kinds"
|
||||
);
|
||||
|
||||
if let Some((pane, item_id)) = self.panes.panes().iter().find_map(|pane| {
|
||||
Some(pane).zip(
|
||||
pane.read(cx)
|
||||
.items()
|
||||
.find(|item| {
|
||||
item.act_as::<SubView>(cx)
|
||||
.is_some_and(|view| view.read(cx).kind == item_kind)
|
||||
})
|
||||
.map(|item| item.item_id()),
|
||||
)
|
||||
}) {
|
||||
pane.update(cx, |pane, cx| {
|
||||
pane.remove_item(item_id, false, true, window, cx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn has_pane_at_position(&self, position: Point<Pixels>) -> bool {
|
||||
self.panes.pane_at_pixel_position(position).is_some()
|
||||
}
|
||||
|
||||
pub(crate) fn add_pane_item(
|
||||
&mut self,
|
||||
item_kind: DebuggerPaneItem,
|
||||
position: Point<Pixels>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
debug_assert!(
|
||||
item_kind.is_supported(self.session.read(cx).capabilities()),
|
||||
"We should only allow adding supported item kinds"
|
||||
);
|
||||
|
||||
if let Some(pane) = self.panes.pane_at_pixel_position(position) {
|
||||
let sub_view = match item_kind {
|
||||
DebuggerPaneItem::Console => {
|
||||
let weak_console = self._console.clone().downgrade();
|
||||
|
||||
Box::new(SubView::new(
|
||||
pane.focus_handle(cx),
|
||||
self._console.clone().into(),
|
||||
item_kind,
|
||||
Some(Box::new(move |cx| {
|
||||
weak_console
|
||||
.read_with(cx, |console, cx| console.show_indicator(cx))
|
||||
.unwrap_or_default()
|
||||
})),
|
||||
cx,
|
||||
))
|
||||
}
|
||||
DebuggerPaneItem::Variables => Box::new(SubView::new(
|
||||
self.variable_list.focus_handle(cx),
|
||||
self.variable_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::BreakpointList => Box::new(SubView::new(
|
||||
self.breakpoint_list.focus_handle(cx),
|
||||
self.breakpoint_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Frames => Box::new(SubView::new(
|
||||
self.stack_frame_list.focus_handle(cx),
|
||||
self.stack_frame_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Modules => Box::new(SubView::new(
|
||||
self.module_list.focus_handle(cx),
|
||||
self.module_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
|
||||
self.loaded_sources_list.focus_handle(cx),
|
||||
self.loaded_sources_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
};
|
||||
|
||||
pane.update(cx, |pane, cx| {
|
||||
pane.add_item(sub_view, false, false, None, window, cx);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn pane_items_status(&self, cx: &App) -> IndexMap<DebuggerPaneItem, bool> {
|
||||
let caps = self.session.read(cx).capabilities();
|
||||
let mut pane_item_status = IndexMap::from_iter(
|
||||
DebuggerPaneItem::all()
|
||||
.iter()
|
||||
.filter(|kind| kind.is_supported(&caps))
|
||||
.map(|kind| (*kind, false)),
|
||||
);
|
||||
self.panes.panes().iter().for_each(|pane| {
|
||||
pane.read(cx)
|
||||
.items()
|
||||
.filter_map(|item| item.act_as::<SubView>(cx))
|
||||
.for_each(|view| {
|
||||
pane_item_status.insert(view.read(cx).kind, true);
|
||||
});
|
||||
});
|
||||
|
||||
pane_item_status
|
||||
}
|
||||
|
||||
pub(crate) fn serialize_layout(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self._schedule_serialize.is_none() {
|
||||
self._schedule_serialize = Some(cx.spawn_in(window, async move |this, cx| {
|
||||
cx.background_executor()
|
||||
@@ -530,6 +661,10 @@ impl RunningState {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn has_open_context_menu(&self, cx: &App) -> bool {
|
||||
self.variable_list.read(cx).has_open_context_menu()
|
||||
}
|
||||
|
||||
pub fn session(&self) -> &Entity<Session> {
|
||||
&self.session
|
||||
}
|
||||
@@ -554,7 +689,7 @@ impl RunningState {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn module_list(&self) -> &Entity<ModuleList> {
|
||||
&self._module_list
|
||||
&self.module_list
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -790,7 +925,7 @@ impl RunningState {
|
||||
variable_list: &Entity<VariableList>,
|
||||
module_list: &Entity<ModuleList>,
|
||||
console: &Entity<Console>,
|
||||
breakpoints: Entity<BreakpointList>,
|
||||
breakpoints: &Entity<BreakpointList>,
|
||||
subscriptions: &mut HashMap<EntityId, Subscription>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<'_, RunningState>,
|
||||
@@ -814,7 +949,7 @@ impl RunningState {
|
||||
this.add_item(
|
||||
Box::new(SubView::new(
|
||||
breakpoints.focus_handle(cx),
|
||||
breakpoints.into(),
|
||||
breakpoints.clone().into(),
|
||||
DebuggerPaneItem::BreakpointList,
|
||||
None,
|
||||
cx,
|
||||
|
||||
@@ -3,7 +3,7 @@ use project::debugger::session::{Session, SessionEvent};
|
||||
use ui::prelude::*;
|
||||
use util::maybe;
|
||||
|
||||
pub struct LoadedSourceList {
|
||||
pub(crate) struct LoadedSourceList {
|
||||
list: ListState,
|
||||
invalidate: bool,
|
||||
focus_handle: FocusHandle,
|
||||
|
||||
@@ -194,6 +194,10 @@ impl VariableList {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn has_open_context_menu(&self) -> bool {
|
||||
self.open_context_menu.is_some()
|
||||
}
|
||||
|
||||
fn build_entries(&mut self, cx: &mut Context<Self>) {
|
||||
let Some(stack_frame_id) = self.selected_stack_frame_id else {
|
||||
return;
|
||||
|
||||
@@ -209,6 +209,7 @@ impl ProjectDiagnosticsEditor {
|
||||
.detach();
|
||||
cx.observe_global_in::<IncludeWarnings>(window, |this, window, cx| {
|
||||
this.include_warnings = cx.global::<IncludeWarnings>().0;
|
||||
this.diagnostics.clear();
|
||||
this.update_all_excerpts(window, cx);
|
||||
})
|
||||
.detach();
|
||||
@@ -300,11 +301,8 @@ impl ProjectDiagnosticsEditor {
|
||||
}
|
||||
}
|
||||
|
||||
fn toggle_warnings(&mut self, _: &ToggleWarnings, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.include_warnings = !self.include_warnings;
|
||||
cx.set_global(IncludeWarnings(self.include_warnings));
|
||||
self.update_all_excerpts(window, cx);
|
||||
cx.notify();
|
||||
fn toggle_warnings(&mut self, _: &ToggleWarnings, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.set_global(IncludeWarnings(!self.include_warnings));
|
||||
}
|
||||
|
||||
fn focus_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -482,7 +480,10 @@ impl ProjectDiagnosticsEditor {
|
||||
editor.change_selections(Some(Autoscroll::fit()), window, cx, |s| {
|
||||
s.select_anchor_ranges([range_to_select]);
|
||||
})
|
||||
})
|
||||
});
|
||||
if this.focus_handle.is_focused(window) {
|
||||
this.editor.read(cx).focus_handle(cx).focus(window);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4737,8 +4737,8 @@ impl Editor {
|
||||
let lookahead = replace_range
|
||||
.end
|
||||
.saturating_sub(newest_anchor.end.text_anchor.to_offset(buffer));
|
||||
let prefix = &old_text[..old_text.len() - lookahead];
|
||||
let suffix = &old_text[lookbehind..];
|
||||
let prefix = &old_text[..old_text.len().saturating_sub(lookahead)];
|
||||
let suffix = &old_text[lookbehind.min(old_text.len())..];
|
||||
|
||||
let selections = self.selections.all::<usize>(cx);
|
||||
let mut edits = Vec::new();
|
||||
@@ -4753,7 +4753,7 @@ impl Editor {
|
||||
|
||||
// if prefix is present, don't duplicate it
|
||||
if snapshot.contains_str_at(range.start.saturating_sub(lookbehind), prefix) {
|
||||
text = &new_text[lookbehind..];
|
||||
text = &new_text[lookbehind.min(new_text.len())..];
|
||||
|
||||
// if suffix is also present, mimic the newest cursor and replace it
|
||||
if selection.id != newest_anchor.id
|
||||
@@ -12519,6 +12519,45 @@ impl Editor {
|
||||
.iter()
|
||||
.map(|selection| {
|
||||
let old_range = selection.start..selection.end;
|
||||
|
||||
if let Some((node, _)) = buffer.syntax_ancestor(old_range.clone()) {
|
||||
// manually select word at selection
|
||||
if ["string_content", "inline"].contains(&node.kind()) {
|
||||
let word_range = {
|
||||
let display_point = buffer
|
||||
.offset_to_point(old_range.start)
|
||||
.to_display_point(&display_map);
|
||||
let Range { start, end } =
|
||||
movement::surrounding_word(&display_map, display_point);
|
||||
start.to_point(&display_map).to_offset(&buffer)
|
||||
..end.to_point(&display_map).to_offset(&buffer)
|
||||
};
|
||||
// ignore if word is already selected
|
||||
if !word_range.is_empty() && old_range != word_range {
|
||||
let last_word_range = {
|
||||
let display_point = buffer
|
||||
.offset_to_point(old_range.end)
|
||||
.to_display_point(&display_map);
|
||||
let Range { start, end } =
|
||||
movement::surrounding_word(&display_map, display_point);
|
||||
start.to_point(&display_map).to_offset(&buffer)
|
||||
..end.to_point(&display_map).to_offset(&buffer)
|
||||
};
|
||||
// only select word if start and end point belongs to same word
|
||||
if word_range == last_word_range {
|
||||
selected_larger_node = true;
|
||||
return Selection {
|
||||
id: selection.id,
|
||||
start: word_range.start,
|
||||
end: word_range.end,
|
||||
goal: SelectionGoal::None,
|
||||
reversed: selection.reversed,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut new_range = old_range.clone();
|
||||
let mut new_node = None;
|
||||
while let Some((node, containing_range)) = buffer.syntax_ancestor(new_range.clone())
|
||||
@@ -13723,8 +13762,6 @@ impl Editor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<Result<Navigated>>> {
|
||||
self.hide_mouse_cursor(&HideMouseCursorOrigin::TypingAction);
|
||||
|
||||
let selection = self.selections.newest::<usize>(cx);
|
||||
let multi_buffer = self.buffer.read(cx);
|
||||
let head = selection.head();
|
||||
|
||||
@@ -6309,7 +6309,187 @@ async fn test_select_larger_smaller_syntax_node(cx: &mut TestAppContext) {
|
||||
use mod1::mod2::«{mod3, mod4}ˇ»;
|
||||
|
||||
fn fn_1«ˇ(param1: bool, param2: &str)» {
|
||||
«ˇlet var1 = "text";»
|
||||
let var1 = "«ˇtext»";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_select_larger_smaller_syntax_node_for_string(cx: &mut TestAppContext) {
|
||||
init_test(cx, |_| {});
|
||||
|
||||
let language = Arc::new(Language::new(
|
||||
LanguageConfig::default(),
|
||||
Some(tree_sitter_rust::LANGUAGE.into()),
|
||||
));
|
||||
|
||||
let text = r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "hello world";
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language, cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
let (editor, cx) = cx.add_window_view(|window, cx| build_editor(buffer, window, cx));
|
||||
|
||||
editor
|
||||
.condition::<crate::EditorEvent>(cx, |editor, cx| !editor.buffer.read(cx).is_parsing(cx))
|
||||
.await;
|
||||
|
||||
// Test 1: Cursor on a letter of a string word
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.change_selections(None, window, cx, |s| {
|
||||
s.select_display_ranges([
|
||||
DisplayPoint::new(DisplayRow(3), 17)..DisplayPoint::new(DisplayRow(3), 17)
|
||||
]);
|
||||
});
|
||||
});
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "hˇello world";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx);
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "«ˇhello» world";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Test 2: Partial selection within a word
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.change_selections(None, window, cx, |s| {
|
||||
s.select_display_ranges([
|
||||
DisplayPoint::new(DisplayRow(3), 17)..DisplayPoint::new(DisplayRow(3), 19)
|
||||
]);
|
||||
});
|
||||
});
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "h«elˇ»lo world";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx);
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "«ˇhello» world";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Test 3: Complete word already selected
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.change_selections(None, window, cx, |s| {
|
||||
s.select_display_ranges([
|
||||
DisplayPoint::new(DisplayRow(3), 16)..DisplayPoint::new(DisplayRow(3), 21)
|
||||
]);
|
||||
});
|
||||
});
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "«helloˇ» world";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx);
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "«hello worldˇ»";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Test 4: Selection spanning across words
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.change_selections(None, window, cx, |s| {
|
||||
s.select_display_ranges([
|
||||
DisplayPoint::new(DisplayRow(3), 19)..DisplayPoint::new(DisplayRow(3), 24)
|
||||
]);
|
||||
});
|
||||
});
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "hel«lo woˇ»rld";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx);
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
let var1 = "«ˇhello world»";
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// Test 5: Expansion beyond string
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx);
|
||||
editor.select_larger_syntax_node(&SelectLargerSyntaxNode, window, cx);
|
||||
assert_text_with_selections(
|
||||
editor,
|
||||
indoc! {r#"
|
||||
use mod1::mod2::{mod3, mod4};
|
||||
|
||||
fn fn_1(param1: bool, param2: &str) {
|
||||
«ˇlet var1 = "hello world";»
|
||||
}
|
||||
"#},
|
||||
cx,
|
||||
|
||||
@@ -989,6 +989,16 @@ fn fetch_and_update_hints(
|
||||
}
|
||||
|
||||
let buffer = editor.buffer().read(cx).buffer(query.buffer_id)?;
|
||||
if !editor.registered_buffers.contains_key(&query.buffer_id) {
|
||||
if let Some(project) = editor.project.as_ref() {
|
||||
project.update(cx, |project, cx| {
|
||||
editor.registered_buffers.insert(
|
||||
query.buffer_id,
|
||||
project.register_buffer_with_language_servers(&buffer, cx),
|
||||
);
|
||||
})
|
||||
}
|
||||
}
|
||||
editor
|
||||
.semantics_provider
|
||||
.as_ref()?
|
||||
|
||||
@@ -16,6 +16,7 @@ client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
dap.workspace = true
|
||||
dirs = "5.0"
|
||||
env_logger.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -37,9 +38,11 @@ reqwest_client.workspace = true
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
shellexpand.workspace = true
|
||||
telemetry.workspace = true
|
||||
toml.workspace = true
|
||||
unindent.workspace = true
|
||||
util.workspace = true
|
||||
uuid = { version = "1.6", features = ["v4"] }
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[[bin]]
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
mod example;
|
||||
mod ids;
|
||||
|
||||
use client::{Client, ProxySettings, UserStore};
|
||||
pub(crate) use example::*;
|
||||
use telemetry;
|
||||
|
||||
use ::fs::RealFs;
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::Parser;
|
||||
use extension::ExtensionHostProxy;
|
||||
use futures::future;
|
||||
use futures::stream::StreamExt;
|
||||
use gpui::http_client::{Uri, read_proxy_from_env};
|
||||
use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
|
||||
use gpui_tokio::Tokio;
|
||||
@@ -39,9 +42,18 @@ struct Args {
|
||||
/// Model to use (default: "claude-3-7-sonnet-latest")
|
||||
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
|
||||
model: String,
|
||||
/// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
|
||||
#[arg(long, value_delimiter = ',')]
|
||||
languages: Option<Vec<String>>,
|
||||
/// How many times to run each example. Note that this is currently not very efficient as N
|
||||
/// worktrees will be created for the examples.
|
||||
#[arg(long, default_value = "1")]
|
||||
repetitions: u32,
|
||||
/// How many times to run the judge on each example run.
|
||||
#[arg(long, default_value = "3")]
|
||||
judge_repetitions: u32,
|
||||
/// Maximum number of examples to run concurrently.
|
||||
#[arg(long, default_value = "10")]
|
||||
concurrency: usize,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
@@ -74,6 +86,15 @@ fn main() {
|
||||
app.run(move |cx| {
|
||||
let app_state = init(cx);
|
||||
|
||||
let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
|
||||
let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
|
||||
let session_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
app_state
|
||||
.client
|
||||
.telemetry()
|
||||
.start(system_id, installation_id, session_id, cx);
|
||||
|
||||
let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
@@ -129,12 +150,20 @@ fn main() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let name_len = example.name.len();
|
||||
if name_len > max_name_width {
|
||||
max_name_width = example.name.len();
|
||||
}
|
||||
// TODO: This creates a worktree per repetition. Ideally these examples should
|
||||
// either be run sequentially on the same worktree, or reuse worktrees when there
|
||||
// are more examples to run than the concurrency limit.
|
||||
for repetition_number in 0..args.repetitions {
|
||||
let mut example = example.clone();
|
||||
example.set_repetition_number(repetition_number);
|
||||
|
||||
examples.push(example);
|
||||
let name_len = example.name.len();
|
||||
if name_len > max_name_width {
|
||||
max_name_width = example.name.len();
|
||||
}
|
||||
|
||||
examples.push(example);
|
||||
}
|
||||
}
|
||||
|
||||
println!("Skipped examples: {}\n", skipped.join(", "));
|
||||
@@ -203,18 +232,26 @@ fn main() {
|
||||
example.setup().await?;
|
||||
}
|
||||
|
||||
let judge_repetitions = args.judge_repetitions;
|
||||
let concurrency = args.concurrency;
|
||||
|
||||
let tasks = examples
|
||||
.into_iter()
|
||||
.map(|example| {
|
||||
let app_state = app_state.clone();
|
||||
let model = model.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
(run_example(&example, model, app_state, cx).await, example)
|
||||
let result =
|
||||
run_example(&example, model, app_state, judge_repetitions, cx).await;
|
||||
(result, example)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
|
||||
let results = futures::stream::iter(tasks)
|
||||
.buffer_unordered(concurrency)
|
||||
.collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
|
||||
.await;
|
||||
|
||||
println!("\n\n");
|
||||
println!("========================================");
|
||||
@@ -229,16 +266,25 @@ fn main() {
|
||||
Err(err) => {
|
||||
println!("💥 {}{:?}", example.log_prefix, err);
|
||||
}
|
||||
Ok(judge_output) => {
|
||||
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
|
||||
Ok(judge_results) => {
|
||||
for judge_result in judge_results {
|
||||
match judge_result {
|
||||
Ok(judge_output) => {
|
||||
const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
|
||||
|
||||
println!(
|
||||
"{} {}{}",
|
||||
SCORES[judge_output.score.min(5) as usize],
|
||||
example.log_prefix,
|
||||
judge_output.score,
|
||||
);
|
||||
judge_scores.push(judge_output.score);
|
||||
println!(
|
||||
"{} {}{}",
|
||||
SCORES[judge_output.score.min(5) as usize],
|
||||
example.log_prefix,
|
||||
judge_output.score,
|
||||
);
|
||||
judge_scores.push(judge_output.score);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("💥 {}{:?}", example.log_prefix, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
println!(
|
||||
@@ -256,6 +302,11 @@ fn main() {
|
||||
/ (score_count as f32);
|
||||
println!("\nAverage score: {average_score}");
|
||||
|
||||
std::thread::sleep(std::time::Duration::from_secs(2));
|
||||
|
||||
// Flush telemetry events before exiting
|
||||
app_state.client.telemetry().flush_events();
|
||||
|
||||
cx.update(|cx| cx.quit())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
@@ -266,12 +317,52 @@ async fn run_example(
|
||||
example: &Example,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
app_state: Arc<AgentAppState>,
|
||||
judge_repetitions: u32,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
cx.update(|cx| example.run(model.clone(), app_state, cx))?
|
||||
) -> Result<Vec<Result<JudgeOutput>>> {
|
||||
let run_output = cx
|
||||
.update(|cx| example.run(model.clone(), app_state.clone(), cx))?
|
||||
.await?;
|
||||
let diff = example.repository_diff().await?;
|
||||
example.judge(model, diff, cx).await
|
||||
|
||||
// Run judge for each repetition
|
||||
let mut results = Vec::new();
|
||||
for round in 0..judge_repetitions {
|
||||
let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
|
||||
|
||||
// Log telemetry for this judge result
|
||||
if let Ok(judge_output) = &judge_result {
|
||||
let cohort_id = example
|
||||
.output_file_path
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|name| name.to_string_lossy().to_string())
|
||||
.unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
|
||||
|
||||
telemetry::event!(
|
||||
"Agent Eval Completed",
|
||||
cohort_id = cohort_id,
|
||||
example_name = example.name.clone(),
|
||||
round = round,
|
||||
score = judge_output.score,
|
||||
analysis = judge_output.analysis,
|
||||
tool_use_counts = run_output.tool_use_counts,
|
||||
response_count = run_output.response_count,
|
||||
token_usage = run_output.token_usage,
|
||||
model = model.telemetry_id(),
|
||||
model_provider = model.provider_id().to_string(),
|
||||
repository_url = example.base.url.clone(),
|
||||
repository_revision = example.base.revision.clone(),
|
||||
diagnostics_summary = run_output.diagnostics
|
||||
);
|
||||
}
|
||||
|
||||
results.push(judge_result);
|
||||
}
|
||||
|
||||
app_state.client.telemetry().flush_events();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn list_all_examples() -> Result<Vec<PathBuf>> {
|
||||
|
||||
@@ -58,6 +58,8 @@ pub struct Example {
|
||||
pub criteria: String,
|
||||
/// Markdown output file to append to
|
||||
pub output_file: Option<Arc<Mutex<File>>>,
|
||||
/// Path to the output run directory.
|
||||
pub run_dir: PathBuf,
|
||||
/// Path to markdown output file
|
||||
pub output_file_path: PathBuf,
|
||||
/// Prefix used for logging that identifies this example
|
||||
@@ -92,23 +94,27 @@ impl Example {
|
||||
let base_path = dir_path.join("base.toml");
|
||||
let prompt_path = dir_path.join("prompt.md");
|
||||
let criteria_path = dir_path.join("criteria.md");
|
||||
|
||||
let output_file_path = run_dir.join(format!(
|
||||
"{}.md",
|
||||
dir_path.file_name().unwrap().to_str().unwrap()
|
||||
));
|
||||
let output_file_path = run_dir.join(format!("{}.md", name));
|
||||
|
||||
Ok(Example {
|
||||
name: name.clone(),
|
||||
base: toml::from_str(&fs::read_to_string(&base_path)?)?,
|
||||
prompt: fs::read_to_string(prompt_path.clone())?,
|
||||
criteria: fs::read_to_string(criteria_path.clone())?,
|
||||
run_dir: run_dir.to_path_buf(),
|
||||
output_file: None,
|
||||
output_file_path,
|
||||
log_prefix: name,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_repetition_number(&mut self, repetition_number: u32) {
|
||||
if repetition_number > 0 {
|
||||
self.name = format!("{}-{}", self.name, repetition_number);
|
||||
self.output_file_path = self.run_dir.join(format!("{}.md", self.name));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
|
||||
self.log_prefix = format!(
|
||||
"{}{:<width$}\x1b[0m | ",
|
||||
@@ -134,13 +140,21 @@ impl Example {
|
||||
pub async fn setup(&mut self) -> Result<()> {
|
||||
let repo_path = repo_path_for_url(&self.base.url);
|
||||
|
||||
println!("{}Fetching", self.log_prefix);
|
||||
let revision_exists = run_git(&repo_path, &["rev-parse", "--verify", &self.base.revision])
|
||||
.await
|
||||
.is_ok();
|
||||
|
||||
run_git(
|
||||
&repo_path,
|
||||
&["fetch", "--depth", "1", "origin", &self.base.revision],
|
||||
)
|
||||
.await?;
|
||||
if !revision_exists {
|
||||
println!(
|
||||
"{}Fetching revision {}",
|
||||
self.log_prefix, &self.base.revision
|
||||
);
|
||||
run_git(
|
||||
&repo_path,
|
||||
&["fetch", "--depth", "1", "origin", &self.base.revision],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let worktree_path = self.worktree_path();
|
||||
|
||||
@@ -372,18 +386,26 @@ impl Example {
|
||||
pending_tool_use,
|
||||
..
|
||||
} => {
|
||||
if let Some(tool_use) = pending_tool_use {
|
||||
let message = format!("TOOL FINISHED: {}", tool_use.name);
|
||||
println!("{}{message}", log_prefix);
|
||||
writeln!(&mut output_file, "\n{}", message).log_err();
|
||||
}
|
||||
thread.update(cx, |thread, _cx| {
|
||||
if let Some(tool_result) = thread.tool_result(&tool_use_id) {
|
||||
writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
|
||||
let mut tool_use_counts = tool_use_counts.lock().unwrap();
|
||||
*tool_use_counts
|
||||
.entry(tool_result.tool_name.clone())
|
||||
.or_insert(0) += 1;
|
||||
if let Some(tool_use) = pending_tool_use {
|
||||
if let Some(tool_result) = thread.tool_result(&tool_use_id) {
|
||||
let message = if tool_result.is_error {
|
||||
format!("TOOL FAILED: {}", tool_use.name)
|
||||
} else {
|
||||
format!("TOOL FINISHED: {}", tool_use.name)
|
||||
};
|
||||
println!("{log_prefix}{message}");
|
||||
writeln!(&mut output_file, "\n{}", message).log_err();
|
||||
writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err();
|
||||
let mut tool_use_counts = tool_use_counts.lock().unwrap();
|
||||
*tool_use_counts
|
||||
.entry(tool_result.tool_name.clone())
|
||||
.or_insert(0) += 1;
|
||||
} else {
|
||||
let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
|
||||
println!("{log_prefix}{message}");
|
||||
writeln!(&mut output_file, "\n{}", message).log_err();
|
||||
}
|
||||
}
|
||||
})?;
|
||||
}
|
||||
@@ -425,6 +447,10 @@ impl Example {
|
||||
println!("{}Getting repository diff", this.log_prefix);
|
||||
let repository_diff = this.repository_diff().await?;
|
||||
|
||||
let repository_diff_path = this.run_dir.join(format!("{}.diff", this.name));
|
||||
let mut repository_diff_output_file = File::create(&repository_diff_path)?;
|
||||
writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();
|
||||
|
||||
println!("{}Getting diagnostics", this.log_prefix);
|
||||
let diagnostics = cx
|
||||
.update(move |cx| {
|
||||
@@ -456,6 +482,7 @@ impl Example {
|
||||
&self,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
repository_diff: String,
|
||||
judge_repetitions: u32,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<JudgeOutput> {
|
||||
let judge_prompt = include_str!("judge_prompt.hbs");
|
||||
@@ -483,14 +510,14 @@ impl Example {
|
||||
|
||||
let response = send_language_model_request(model, request, cx).await?;
|
||||
|
||||
let output_file_ref = self.output_file();
|
||||
let mut output_file = output_file_ref.lock().unwrap();
|
||||
let judge_file_path = self.run_dir.join(format!(
|
||||
"{}_judge_{}.md",
|
||||
self.name, // This is the eval_name
|
||||
judge_repetitions
|
||||
));
|
||||
|
||||
writeln!(&mut output_file, "\n\n").log_err();
|
||||
writeln!(&mut output_file, "========================================").log_err();
|
||||
writeln!(&mut output_file, " JUDGE OUTPUT ").log_err();
|
||||
writeln!(&mut output_file, "========================================").log_err();
|
||||
writeln!(&mut output_file, "\n{}", &response).log_err();
|
||||
let mut judge_output_file = File::create(&judge_file_path)?;
|
||||
writeln!(&mut judge_output_file, "{}", &response).log_err();
|
||||
|
||||
parse_judge_output(&response)
|
||||
}
|
||||
|
||||
28
crates/eval/src/ids.rs
Normal file
28
crates/eval/src/ids.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use anyhow::Result;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn get_or_create_id(path: &Path) -> Result<String> {
|
||||
if let Ok(id) = fs::read_to_string(path) {
|
||||
let trimmed = id.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Ok(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
let new_id = Uuid::new_v4().to_string();
|
||||
fs::write(path, &new_id)?;
|
||||
Ok(new_id)
|
||||
}
|
||||
|
||||
pub fn eval_system_id_path() -> PathBuf {
|
||||
dirs::data_local_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join("zed-eval-system-id")
|
||||
}
|
||||
|
||||
pub fn eval_installation_id_path() -> PathBuf {
|
||||
dirs::data_local_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join("zed-eval-installation-id")
|
||||
}
|
||||
@@ -84,6 +84,11 @@ impl FeatureFlag for ZedPro {
|
||||
const NAME: &'static str = "zed-pro";
|
||||
}
|
||||
|
||||
pub struct ZedProWebSearchTool {}
|
||||
impl FeatureFlag for ZedProWebSearchTool {
|
||||
const NAME: &'static str = "zed-pro-web-search-tool";
|
||||
}
|
||||
|
||||
pub struct NotebookFeatureFlag;
|
||||
|
||||
impl FeatureFlag for NotebookFeatureFlag {
|
||||
|
||||
@@ -2,6 +2,7 @@ use gpui::{App, ClipboardItem, PromptLevel, actions};
|
||||
use system_specs::SystemSpecs;
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
use zed_actions::feedback::FileBugReport;
|
||||
|
||||
pub mod feedback_modal;
|
||||
|
||||
@@ -12,7 +13,6 @@ actions!(
|
||||
[
|
||||
CopySystemSpecsIntoClipboard,
|
||||
EmailZed,
|
||||
FileBugReport,
|
||||
OpenZedRepo,
|
||||
RequestFeature,
|
||||
]
|
||||
@@ -27,7 +27,7 @@ fn file_bug_report_url(specs: &SystemSpecs) -> String {
|
||||
concat!(
|
||||
"https://github.com/zed-industries/zed/issues/new",
|
||||
"?",
|
||||
"template=1_bug_report.yml",
|
||||
"template=10_bug_report.yml",
|
||||
"&",
|
||||
"environment={}"
|
||||
),
|
||||
|
||||
@@ -1333,13 +1333,23 @@ impl FakeFs {
|
||||
let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
|
||||
anyhow::bail!("pointed-to git dir {path:?} not found")
|
||||
};
|
||||
let FakeFsEntry::Dir { git_repo_state, .. } = &mut *git_dir_entry.lock() else {
|
||||
let FakeFsEntry::Dir {
|
||||
git_repo_state,
|
||||
entries,
|
||||
..
|
||||
} = &mut *git_dir_entry.lock()
|
||||
else {
|
||||
anyhow::bail!("gitfile points to a non-directory")
|
||||
};
|
||||
let common_dir = canonical_path
|
||||
.ancestors()
|
||||
.find(|ancestor| ancestor.ends_with(".git"))
|
||||
.ok_or_else(|| anyhow!("repository dir not contained in any .git"))?;
|
||||
let common_dir = if let Some(child) = entries.get("commondir") {
|
||||
Path::new(
|
||||
std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
|
||||
.context("commondir content")?,
|
||||
)
|
||||
.to_owned()
|
||||
} else {
|
||||
canonical_path.clone()
|
||||
};
|
||||
let repo_state = git_repo_state.get_or_insert_with(|| {
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(
|
||||
state.git_event_tx.clone(),
|
||||
@@ -1347,7 +1357,7 @@ impl FakeFs {
|
||||
});
|
||||
let mut repo_state = repo_state.lock();
|
||||
|
||||
let result = f(&mut repo_state, &canonical_path, common_dir);
|
||||
let result = f(&mut repo_state, &canonical_path, &common_dir);
|
||||
|
||||
if emit_git_event {
|
||||
state.emit_event([(canonical_path, None)]);
|
||||
|
||||
@@ -1013,7 +1013,6 @@ impl GitRepository for RealGitRepository {
|
||||
let mut command = new_smol_command("git");
|
||||
command
|
||||
.envs(env.iter())
|
||||
.env("GIT_HTTP_USER_AGENT", "Zed")
|
||||
.current_dir(&working_directory)
|
||||
.args(["push"])
|
||||
.args(options.map(|option| match option {
|
||||
@@ -1045,7 +1044,6 @@ impl GitRepository for RealGitRepository {
|
||||
let mut command = new_smol_command("git");
|
||||
command
|
||||
.envs(env.iter())
|
||||
.env("GIT_HTTP_USER_AGENT", "Zed")
|
||||
.current_dir(&working_directory?)
|
||||
.args(["pull"])
|
||||
.arg(remote_name)
|
||||
@@ -1070,7 +1068,6 @@ impl GitRepository for RealGitRepository {
|
||||
let mut command = new_smol_command("git");
|
||||
command
|
||||
.envs(env.iter())
|
||||
.env("GIT_HTTP_USER_AGENT", "Zed")
|
||||
.current_dir(&working_directory?)
|
||||
.args(["fetch", "--all"])
|
||||
.stdout(smol::process::Stdio::piped())
|
||||
|
||||
@@ -599,33 +599,11 @@ impl GitPanel {
|
||||
}
|
||||
|
||||
pub fn entry_by_path(&self, path: &RepoPath) -> Option<usize> {
|
||||
fn binary_search<F>(mut low: usize, mut high: usize, is_target: F) -> Option<usize>
|
||||
where
|
||||
F: Fn(usize) -> std::cmp::Ordering,
|
||||
{
|
||||
while low < high {
|
||||
let mid = low + (high - low) / 2;
|
||||
match is_target(mid) {
|
||||
std::cmp::Ordering::Equal => return Some(mid),
|
||||
std::cmp::Ordering::Less => low = mid + 1,
|
||||
std::cmp::Ordering::Greater => high = mid,
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
if self.conflicted_count > 0 {
|
||||
let conflicted_start = 1;
|
||||
if let Some(ix) = binary_search(
|
||||
conflicted_start,
|
||||
conflicted_start + self.conflicted_count,
|
||||
|ix| {
|
||||
self.entries[ix]
|
||||
.status_entry()
|
||||
.unwrap()
|
||||
.repo_path
|
||||
.cmp(&path)
|
||||
},
|
||||
) {
|
||||
if let Ok(ix) = self.entries[conflicted_start..conflicted_start + self.conflicted_count]
|
||||
.binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path))
|
||||
{
|
||||
return Some(ix);
|
||||
}
|
||||
}
|
||||
@@ -635,14 +613,8 @@ impl GitPanel {
|
||||
} else {
|
||||
0
|
||||
} + 1;
|
||||
if let Some(ix) =
|
||||
binary_search(tracked_start, tracked_start + self.tracked_count, |ix| {
|
||||
self.entries[ix]
|
||||
.status_entry()
|
||||
.unwrap()
|
||||
.repo_path
|
||||
.cmp(&path)
|
||||
})
|
||||
if let Ok(ix) = self.entries[tracked_start..tracked_start + self.tracked_count]
|
||||
.binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path))
|
||||
{
|
||||
return Some(ix);
|
||||
}
|
||||
@@ -657,14 +629,8 @@ impl GitPanel {
|
||||
} else {
|
||||
0
|
||||
} + 1;
|
||||
if let Some(ix) =
|
||||
binary_search(untracked_start, untracked_start + self.new_count, |ix| {
|
||||
self.entries[ix]
|
||||
.status_entry()
|
||||
.unwrap()
|
||||
.repo_path
|
||||
.cmp(&path)
|
||||
})
|
||||
if let Ok(ix) = self.entries[untracked_start..untracked_start + self.new_count]
|
||||
.binary_search_by(|entry| entry.status_entry().unwrap().repo_path.cmp(&path))
|
||||
{
|
||||
return Some(ix);
|
||||
}
|
||||
@@ -3611,6 +3577,15 @@ impl GitPanel {
|
||||
items
|
||||
}
|
||||
})
|
||||
.when(
|
||||
!self.horizontal_scrollbar.show_track
|
||||
&& self.horizontal_scrollbar.show_scrollbar,
|
||||
|this| {
|
||||
// when not showing the horizontal scrollbar track, make sure we don't
|
||||
// obscure the last entry
|
||||
this.pb(scroll_track_size)
|
||||
},
|
||||
)
|
||||
.size_full()
|
||||
.flex_grow()
|
||||
.with_sizing_behavior(ListSizingBehavior::Auto)
|
||||
|
||||
@@ -589,11 +589,6 @@ impl<V> Entity<V> {
|
||||
use postage::prelude::{Sink as _, Stream as _};
|
||||
|
||||
let (tx, mut rx) = postage::mpsc::channel(1024);
|
||||
let timeout_duration = if cfg!(target_os = "macos") {
|
||||
Duration::from_millis(100)
|
||||
} else {
|
||||
Duration::from_secs(1)
|
||||
};
|
||||
|
||||
let mut cx = cx.app.borrow_mut();
|
||||
let subscriptions = (
|
||||
@@ -615,7 +610,7 @@ impl<V> Entity<V> {
|
||||
let handle = self.downgrade();
|
||||
|
||||
async move {
|
||||
crate::util::timeout(timeout_duration, async move {
|
||||
crate::util::timeout(Duration::from_secs(1), async move {
|
||||
loop {
|
||||
{
|
||||
let cx = cx.borrow();
|
||||
|
||||
@@ -27,6 +27,8 @@ use objc::{
|
||||
};
|
||||
use std::{cell::RefCell, ffi::c_void, mem, ptr, rc::Rc};
|
||||
|
||||
use super::NSStringExt;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MacScreenCaptureSource {
|
||||
sc_display: id,
|
||||
@@ -184,7 +186,10 @@ pub(crate) fn get_sources() -> oneshot::Receiver<Result<Vec<Box<dyn ScreenCaptur
|
||||
Ok(result)
|
||||
} else {
|
||||
let msg: id = msg_send![error, localizedDescription];
|
||||
Err(anyhow!("Failed to register: {:?}", msg))
|
||||
Err(anyhow!(
|
||||
"Screen share failed: {:?}",
|
||||
NSStringExt::to_str(&msg)
|
||||
))
|
||||
};
|
||||
tx.send(result).ok();
|
||||
});
|
||||
|
||||
@@ -142,6 +142,24 @@ impl fmt::Display for MaxMonthlySpendReachedError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub struct ModelRequestLimitReachedError {
|
||||
pub plan: Plan,
|
||||
}
|
||||
|
||||
impl fmt::Display for ModelRequestLimitReachedError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let message = match self.plan {
|
||||
Plan::Free => "Model request limit reached. Upgrade to Zed Pro for more requests.",
|
||||
Plan::ZedPro => {
|
||||
"Model request limit reached. Upgrade to usage-based billing for more requests."
|
||||
}
|
||||
};
|
||||
|
||||
write!(f, "{message}")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -16,18 +16,21 @@ use language_model::{
|
||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
||||
LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID,
|
||||
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
|
||||
MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
|
||||
};
|
||||
use proto::Plan;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use serde_json::value::RawValue;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use smol::Timer;
|
||||
use smol::io::{AsyncReadExt, BufReader};
|
||||
use std::str::FromStr as _;
|
||||
use std::{
|
||||
sync::{Arc, LazyLock},
|
||||
time::Duration,
|
||||
@@ -35,6 +38,7 @@ use std::{
|
||||
use strum::IntoEnumIterator;
|
||||
use thiserror::Error;
|
||||
use ui::{TintColor, prelude::*};
|
||||
use zed_llm_client::{CURRENT_PLAN_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME};
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
|
||||
@@ -551,6 +555,32 @@ impl CloudLanguageModel {
|
||||
.is_some()
|
||||
{
|
||||
return Err(anyhow!(MaxMonthlySpendReachedError));
|
||||
} else if status == StatusCode::FORBIDDEN
|
||||
&& response
|
||||
.headers()
|
||||
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
if let Some("model_requests") = response
|
||||
.headers()
|
||||
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
|
||||
.and_then(|resource| resource.to_str().ok())
|
||||
{
|
||||
if let Some(plan) = response
|
||||
.headers()
|
||||
.get(CURRENT_PLAN_HEADER_NAME)
|
||||
.and_then(|plan| plan.to_str().ok())
|
||||
.and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
|
||||
{
|
||||
let plan = match plan {
|
||||
zed_llm_client::Plan::Free => Plan::Free,
|
||||
zed_llm_client::Plan::ZedPro => Plan::ZedPro,
|
||||
};
|
||||
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
|
||||
}
|
||||
}
|
||||
|
||||
return Err(anyhow!("Forbidden"));
|
||||
} else if status.as_u16() >= 500 && status.as_u16() < 600 {
|
||||
// If we encounter an error in the 500 range, retry after a delay.
|
||||
// We've seen at least these in the wild from API providers:
|
||||
|
||||
@@ -2,25 +2,17 @@
|
||||
(_ "{" "}" @end) @indent
|
||||
(_ "(" ")" @end) @indent
|
||||
|
||||
[
|
||||
(if_statement)
|
||||
(for_statement)
|
||||
(while_statement)
|
||||
(with_statement)
|
||||
(function_definition)
|
||||
(class_definition)
|
||||
(match_statement)
|
||||
(try_statement)
|
||||
] @indent
|
||||
(try_statement
|
||||
body: (_) @start
|
||||
[(except_clause) (finally_clause)] @end
|
||||
) @indent
|
||||
|
||||
[
|
||||
(else_clause)
|
||||
(elif_clause)
|
||||
(except_clause)
|
||||
(finally_clause)
|
||||
] @outdent
|
||||
(if_statement
|
||||
consequence: (_) @start
|
||||
alternative: (_) @end
|
||||
) @indent
|
||||
|
||||
[
|
||||
(block)
|
||||
(case_clause)
|
||||
] @indent
|
||||
(_
|
||||
alternative: (elif_clause) @start
|
||||
alternative: (_) @end
|
||||
) @indent
|
||||
|
||||
@@ -61,11 +61,14 @@ impl Anchor {
|
||||
return Ordering::Equal;
|
||||
}
|
||||
|
||||
let excerpt_id_cmp = self.excerpt_id.cmp(&other.excerpt_id, snapshot);
|
||||
let self_excerpt_id = snapshot.latest_excerpt_id(self.excerpt_id);
|
||||
let other_excerpt_id = snapshot.latest_excerpt_id(other.excerpt_id);
|
||||
|
||||
let excerpt_id_cmp = self_excerpt_id.cmp(&other_excerpt_id, snapshot);
|
||||
if excerpt_id_cmp.is_ne() {
|
||||
return excerpt_id_cmp;
|
||||
}
|
||||
if self.excerpt_id == ExcerptId::min() || self.excerpt_id == ExcerptId::max() {
|
||||
if self_excerpt_id == ExcerptId::min() || self_excerpt_id == ExcerptId::max() {
|
||||
return Ordering::Equal;
|
||||
}
|
||||
if let Some(excerpt) = snapshot.excerpt(self.excerpt_id) {
|
||||
|
||||
@@ -746,19 +746,20 @@ fn test_expand_excerpts(cx: &mut App) {
|
||||
drop(snapshot);
|
||||
|
||||
multibuffer.update(cx, |multibuffer, cx| {
|
||||
let line_zero = multibuffer.snapshot(cx).anchor_before(Point::new(0, 0));
|
||||
multibuffer.expand_excerpts(
|
||||
multibuffer.excerpt_ids(),
|
||||
1,
|
||||
ExpandExcerptDirection::UpAndDown,
|
||||
cx,
|
||||
)
|
||||
);
|
||||
let snapshot = multibuffer.snapshot(cx);
|
||||
let line_two = snapshot.anchor_before(Point::new(2, 0));
|
||||
assert_eq!(line_two.cmp(&line_zero, &snapshot), cmp::Ordering::Greater);
|
||||
});
|
||||
|
||||
let snapshot = multibuffer.read(cx).snapshot(cx);
|
||||
|
||||
// Expanding context lines causes the line containing 'fff' to appear in two different excerpts.
|
||||
// We don't attempt to merge them, because removing the excerpt could create inconsistency with other layers
|
||||
// that are tracking excerpt ids.
|
||||
assert_eq!(
|
||||
snapshot.text(),
|
||||
concat!(
|
||||
|
||||
@@ -21,7 +21,10 @@ use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
future::{Shared, join_all},
|
||||
};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, SharedString,
|
||||
Task,
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language::{BinaryStatus, LanguageRegistry, LanguageToolchainStore};
|
||||
use lsp::LanguageServerName;
|
||||
@@ -90,6 +93,17 @@ impl LocalDapStore {
|
||||
fn next_session_id(&self) -> SessionId {
|
||||
SessionId(self.next_session_id.fetch_add(1, SeqCst))
|
||||
}
|
||||
pub(crate) fn locate_binary(
|
||||
&self,
|
||||
mut definition: DebugTaskDefinition,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Task<DebugTaskDefinition> {
|
||||
let locator_store = self.locator_store.clone();
|
||||
executor.spawn(async move {
|
||||
let _ = locator_store.resolve_debug_config(&mut definition).await;
|
||||
definition
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RemoteDapStore {
|
||||
@@ -335,7 +349,7 @@ impl DapStore {
|
||||
pub fn new_session(
|
||||
&mut self,
|
||||
binary: DebugAdapterBinary,
|
||||
mut config: DebugTaskDefinition,
|
||||
config: DebugTaskDefinition,
|
||||
parent_session: Option<Entity<Session>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> (SessionId, Task<Result<Entity<Session>>>) {
|
||||
@@ -352,22 +366,10 @@ impl DapStore {
|
||||
}
|
||||
|
||||
let (initialized_tx, initialized_rx) = oneshot::channel();
|
||||
let locator_store = local_store.locator_store.clone();
|
||||
|
||||
let start_debugging_tx = local_store.start_debugging_tx.clone();
|
||||
|
||||
let task = cx.spawn(async move |this, cx| {
|
||||
if config.locator.is_some() {
|
||||
config = cx
|
||||
.background_spawn(async move {
|
||||
locator_store
|
||||
.resolve_debug_config(&mut config)
|
||||
.await
|
||||
.map(|_| config)
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
let start_client_task = this.update(cx, |this, cx| {
|
||||
Session::local(
|
||||
this.breakpoint_store.clone(),
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
use super::DapLocator;
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{Value, json};
|
||||
use serde_json::Value;
|
||||
use smol::{
|
||||
io::AsyncReadExt,
|
||||
process::{Command, Stdio},
|
||||
};
|
||||
use task::DebugTaskDefinition;
|
||||
use util::maybe;
|
||||
|
||||
pub(super) struct CargoLocator;
|
||||
|
||||
@@ -109,43 +108,13 @@ impl DapLocator for CargoLocator {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let Some(executable) = executable.or_else(|| executables.first().cloned()) else {
|
||||
return Err(anyhow!("Couldn't get executable in cargo locator"));
|
||||
};
|
||||
|
||||
launch_config.program = executable;
|
||||
|
||||
if debug_config.adapter == "LLDB" && debug_config.initialize_args.is_none() {
|
||||
// Find Rust pretty-printers in current toolchain's sysroot
|
||||
let cwd = launch_config.cwd.clone();
|
||||
debug_config.initialize_args = maybe!(async move {
|
||||
let cwd = cwd?;
|
||||
|
||||
let output = Command::new("rustc")
|
||||
.arg("--print")
|
||||
.arg("sysroot")
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.await
|
||||
.ok()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let sysroot_path = String::from_utf8(output.stdout).ok()?;
|
||||
let sysroot_path = sysroot_path.trim_end();
|
||||
let first_command = format!(
|
||||
r#"command script import "{sysroot_path}/lib/rustlib/etc/lldb_lookup.py"#
|
||||
);
|
||||
let second_command =
|
||||
format!(r#"command source -s 0 '{sysroot_path}/lib/rustlib/etc/lldb_commands"#);
|
||||
|
||||
Some(json!({"initCommands": [first_command, second_command]}))
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
launch_config.args.clear();
|
||||
if let Some(test_name) = test_name {
|
||||
launch_config.args.push(test_name);
|
||||
|
||||
@@ -1482,6 +1482,18 @@ impl Project {
|
||||
.update(cx, |dap_store, cx| dap_store.delegate(&worktree, cx))
|
||||
})?;
|
||||
|
||||
let task = this.update(cx, |project, cx| {
|
||||
project.dap_store.read(cx).as_local().and_then(|local| {
|
||||
config.locator.is_some().then(|| {
|
||||
local.locate_binary(config.clone(), cx.background_executor().clone())
|
||||
})
|
||||
})
|
||||
})?;
|
||||
let config = if let Some(task) = task {
|
||||
task.await
|
||||
} else {
|
||||
config
|
||||
};
|
||||
let binary = adapter
|
||||
.get_binary(&delegate, &config, user_installed_path, cx)
|
||||
.await?;
|
||||
|
||||
@@ -8273,17 +8273,34 @@ async fn test_git_worktrees_and_submodules(cx: &mut gpui::TestAppContext) {
|
||||
json!({
|
||||
".git": {
|
||||
"worktrees": {
|
||||
"some-worktree": {}
|
||||
"some-worktree": {
|
||||
"commondir": "../..\n"
|
||||
}
|
||||
},
|
||||
"modules": {
|
||||
"subdir": {
|
||||
"some-submodule": {
|
||||
// For is_git_dir
|
||||
"HEAD": "",
|
||||
"config": "",
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"src": {
|
||||
"a.txt": "A",
|
||||
},
|
||||
"some-worktree": {
|
||||
".git": "gitdir: ../.git/worktrees/some-worktree",
|
||||
".git": "gitdir: ../.git/worktrees/some-worktree\n",
|
||||
"src": {
|
||||
"b.txt": "B",
|
||||
}
|
||||
},
|
||||
"subdir": {
|
||||
"some-submodule": {
|
||||
".git": "gitdir: ../../.git/modules/subdir/some-submodule\n",
|
||||
"c.txt": "C",
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
@@ -8315,9 +8332,11 @@ async fn test_git_worktrees_and_submodules(cx: &mut gpui::TestAppContext) {
|
||||
[
|
||||
Path::new(path!("/project")).into(),
|
||||
Path::new(path!("/project/some-worktree")).into(),
|
||||
Path::new(path!("/project/subdir/some-submodule")).into(),
|
||||
]
|
||||
);
|
||||
|
||||
// Generate a git-related event for the worktree and check that it's refreshed.
|
||||
fs.with_git_state(
|
||||
path!("/project/some-worktree/.git").as_ref(),
|
||||
true,
|
||||
@@ -8359,6 +8378,45 @@ async fn test_git_worktrees_and_submodules(cx: &mut gpui::TestAppContext) {
|
||||
StatusCode::Modified.worktree(),
|
||||
);
|
||||
});
|
||||
|
||||
// The same for the submodule.
|
||||
fs.with_git_state(
|
||||
path!("/project/subdir/some-submodule/.git").as_ref(),
|
||||
true,
|
||||
|state| {
|
||||
state.head_contents.insert("c.txt".into(), "c".to_owned());
|
||||
state.index_contents.insert("c.txt".into(), "c".to_owned());
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/project/subdir/some-submodule/c.txt"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let (submodule_repo, barrier) = project.update(cx, |project, cx| {
|
||||
let (repo, _) = project
|
||||
.git_store()
|
||||
.read(cx)
|
||||
.repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx)
|
||||
.unwrap();
|
||||
pretty_assertions::assert_eq!(
|
||||
repo.read(cx).work_directory_abs_path,
|
||||
Path::new(path!("/project/subdir/some-submodule")).into(),
|
||||
);
|
||||
let barrier = repo.update(cx, |repo, _| repo.barrier());
|
||||
(repo.clone(), barrier)
|
||||
});
|
||||
barrier.await.unwrap();
|
||||
submodule_repo.update(cx, |repo, _| {
|
||||
pretty_assertions::assert_eq!(
|
||||
repo.status_for_path(&"c.txt".into()).unwrap().status,
|
||||
StatusCode::Modified.worktree(),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
||||
@@ -160,7 +160,11 @@ impl Render for Tooltip {
|
||||
}),
|
||||
)
|
||||
.when_some(self.meta.clone(), |this, meta| {
|
||||
this.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted))
|
||||
this.child(
|
||||
div()
|
||||
.max_w_72()
|
||||
.child(Label::new(meta).size(LabelSize::Small).color(Color::Muted)),
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
20
crates/web_search/Cargo.toml
Normal file
20
crates/web_search/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "web_search"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/web_search.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
serde.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
1
crates/web_search/LICENSE-GPL
Symbolic link
1
crates/web_search/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
64
crates/web_search/src/web_search.rs
Normal file
64
crates/web_search/src/web_search.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AppContext as _, Context, Entity, Global, SharedString, Task};
|
||||
use std::sync::Arc;
|
||||
use zed_llm_client::WebSearchResponse;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
let registry = cx.new(|_cx| WebSearchRegistry::default());
|
||||
cx.set_global(GlobalWebSearchRegistry(registry));
|
||||
}
|
||||
|
||||
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
|
||||
pub struct WebSearchProviderId(pub SharedString);
|
||||
|
||||
pub trait WebSearchProvider {
|
||||
fn id(&self) -> WebSearchProviderId;
|
||||
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>>;
|
||||
}
|
||||
|
||||
struct GlobalWebSearchRegistry(Entity<WebSearchRegistry>);
|
||||
|
||||
impl Global for GlobalWebSearchRegistry {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct WebSearchRegistry {
|
||||
providers: HashMap<WebSearchProviderId, Arc<dyn WebSearchProvider>>,
|
||||
active_provider: Option<Arc<dyn WebSearchProvider>>,
|
||||
}
|
||||
|
||||
impl WebSearchRegistry {
|
||||
pub fn global(cx: &App) -> Entity<Self> {
|
||||
cx.global::<GlobalWebSearchRegistry>().0.clone()
|
||||
}
|
||||
|
||||
pub fn read_global(cx: &App) -> &Self {
|
||||
cx.global::<GlobalWebSearchRegistry>().0.read(cx)
|
||||
}
|
||||
|
||||
pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn WebSearchProvider>> {
|
||||
self.providers.values()
|
||||
}
|
||||
|
||||
pub fn active_provider(&self) -> Option<Arc<dyn WebSearchProvider>> {
|
||||
self.active_provider.clone()
|
||||
}
|
||||
|
||||
pub fn set_active_provider(&mut self, provider: Arc<dyn WebSearchProvider>) {
|
||||
self.active_provider = Some(provider.clone());
|
||||
self.providers.insert(provider.id(), provider);
|
||||
}
|
||||
|
||||
pub fn register_provider<T: WebSearchProvider + 'static>(
|
||||
&mut self,
|
||||
provider: T,
|
||||
_cx: &mut Context<Self>,
|
||||
) {
|
||||
let id = provider.id();
|
||||
let provider = Arc::new(provider);
|
||||
self.providers.insert(id.clone(), provider.clone());
|
||||
if self.active_provider.is_none() {
|
||||
self.active_provider = Some(provider);
|
||||
}
|
||||
}
|
||||
}
|
||||
26
crates/web_search_providers/Cargo.toml
Normal file
26
crates/web_search_providers/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "web_search_providers"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/web_search_providers.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
language_model.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
web_search.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
1
crates/web_search_providers/LICENSE-GPL
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-GPL
|
||||
103
crates/web_search_providers/src/cloud.rs
Normal file
103
crates/web_search_providers/src/cloud.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use client::Client;
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
|
||||
use http_client::{HttpClient, Method};
|
||||
use language_model::{LlmApiToken, RefreshLlmTokenListener};
|
||||
use web_search::{WebSearchProvider, WebSearchProviderId};
|
||||
use zed_llm_client::{WebSearchBody, WebSearchResponse};
|
||||
|
||||
pub struct CloudWebSearchProvider {
|
||||
state: Entity<State>,
|
||||
}
|
||||
|
||||
impl CloudWebSearchProvider {
|
||||
pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State::new(client, cx));
|
||||
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
_llm_token_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
||||
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
|
||||
|
||||
Self {
|
||||
client,
|
||||
llm_api_token: LlmApiToken::default(),
|
||||
_llm_token_subscription: cx.subscribe(
|
||||
&refresh_llm_token_listener,
|
||||
|this, _, _event, cx| {
|
||||
let client = this.client.clone();
|
||||
let llm_api_token = this.llm_api_token.clone();
|
||||
cx.spawn(async move |_this, _cx| {
|
||||
llm_api_token.refresh(&client).await?;
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSearchProvider for CloudWebSearchProvider {
|
||||
fn id(&self) -> WebSearchProviderId {
|
||||
WebSearchProviderId("zed.dev".into())
|
||||
}
|
||||
|
||||
fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
|
||||
let state = self.state.read(cx);
|
||||
let client = state.client.clone();
|
||||
let llm_api_token = state.llm_api_token.clone();
|
||||
let body = WebSearchBody { query };
|
||||
cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_web_search(
|
||||
client: Arc<Client>,
|
||||
llm_api_token: LlmApiToken,
|
||||
body: WebSearchBody,
|
||||
) -> Result<WebSearchResponse> {
|
||||
let http_client = &client.http_client();
|
||||
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
||||
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||
let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
|
||||
request_builder.uri(web_search_url)
|
||||
} else {
|
||||
request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
|
||||
};
|
||||
let request = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&body)?.into())?;
|
||||
let mut response = http_client
|
||||
.send(request)
|
||||
.await
|
||||
.context("failed to send web search request")?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Ok(serde_json::from_str(&body)?);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
return Err(anyhow!(
|
||||
"error performing web search.\nStatus: {:?}\nBody: {body}",
|
||||
response.status(),
|
||||
));
|
||||
}
|
||||
}
|
||||
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
35
crates/web_search_providers/src/web_search_providers.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
mod cloud;
|
||||
|
||||
use client::Client;
|
||||
use feature_flags::{FeatureFlagAppExt, ZedProWebSearchTool};
|
||||
use gpui::{App, Context};
|
||||
use std::sync::Arc;
|
||||
use web_search::WebSearchRegistry;
|
||||
|
||||
pub fn init(client: Arc<Client>, cx: &mut App) {
|
||||
let registry = WebSearchRegistry::global(cx);
|
||||
registry.update(cx, |registry, cx| {
|
||||
register_web_search_providers(registry, client, cx);
|
||||
});
|
||||
}
|
||||
|
||||
fn register_web_search_providers(
|
||||
_registry: &mut WebSearchRegistry,
|
||||
client: Arc<Client>,
|
||||
cx: &mut Context<WebSearchRegistry>,
|
||||
) {
|
||||
cx.observe_flag::<ZedProWebSearchTool, _>({
|
||||
let client = client.clone();
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
WebSearchRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.register_provider(
|
||||
cloud::CloudWebSearchProvider::new(client.clone(), cx),
|
||||
cx,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
@@ -106,6 +106,7 @@ use uuid::Uuid;
|
||||
pub use workspace_settings::{
|
||||
AutosaveSetting, BottomDockLayout, RestoreOnStartupBehavior, TabBarSettings, WorkspaceSettings,
|
||||
};
|
||||
use zed_actions::feedback::FileBugReport;
|
||||
|
||||
use crate::notifications::NotificationId;
|
||||
use crate::persistence::{
|
||||
@@ -5395,8 +5396,6 @@ enum ActivateInDirectionTarget {
|
||||
}
|
||||
|
||||
fn notify_if_database_failed(workspace: WindowHandle<Workspace>, cx: &mut AsyncApp) {
|
||||
const REPORT_ISSUE_URL: &str = "https://github.com/zed-industries/zed/issues/new?assignees=&labels=admin+read%2Ctriage%2Cbug&projects=&template=1_bug_report.yml";
|
||||
|
||||
workspace
|
||||
.update(cx, |workspace, _, cx| {
|
||||
if (*db::ALL_FILE_DB_FAILED).load(std::sync::atomic::Ordering::Acquire) {
|
||||
@@ -5410,7 +5409,9 @@ fn notify_if_database_failed(workspace: WindowHandle<Workspace>, cx: &mut AsyncA
|
||||
MessageNotification::new("Failed to load the database file.", cx)
|
||||
.primary_message("File an Issue")
|
||||
.primary_icon(IconName::Plus)
|
||||
.primary_on_click(|_window, cx| cx.open_url(REPORT_ISSUE_URL))
|
||||
.primary_on_click(|window, cx| {
|
||||
window.dispatch_action(Box::new(FileBugReport), cx)
|
||||
})
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
@@ -3120,32 +3120,12 @@ impl BackgroundScannerState {
|
||||
.as_path()
|
||||
.into();
|
||||
|
||||
let mut common_dir_abs_path = dot_git_abs_path.clone();
|
||||
let mut repository_dir_abs_path = dot_git_abs_path.clone();
|
||||
// Parse .git if it's a "gitfile" pointing to a repository directory elsewhere.
|
||||
if let Some(dot_git_contents) = smol::block_on(fs.load(&dot_git_abs_path)).ok() {
|
||||
if let Some(path) = dot_git_contents.strip_prefix("gitdir:") {
|
||||
let path = path.trim();
|
||||
let path = dot_git_abs_path
|
||||
.parent()
|
||||
.unwrap_or(Path::new(""))
|
||||
.join(path);
|
||||
if let Some(path) = smol::block_on(fs.canonicalize(&path)).log_err() {
|
||||
repository_dir_abs_path = Path::new(&path).into();
|
||||
common_dir_abs_path = repository_dir_abs_path.clone();
|
||||
if let Some(ancestor_dot_git) = path
|
||||
.ancestors()
|
||||
.skip(1)
|
||||
.find(|ancestor| smol::block_on(is_git_dir(ancestor, fs)))
|
||||
{
|
||||
common_dir_abs_path = ancestor_dot_git.into();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::error!("failed to parse contents of .git file: {dot_git_contents:?}");
|
||||
}
|
||||
};
|
||||
let (repository_dir_abs_path, common_dir_abs_path) =
|
||||
discover_git_paths(&dot_git_abs_path, fs);
|
||||
watcher.add(&common_dir_abs_path).log_err();
|
||||
if !repository_dir_abs_path.starts_with(&common_dir_abs_path) {
|
||||
watcher.add(&repository_dir_abs_path).log_err();
|
||||
}
|
||||
|
||||
let work_directory_id = work_dir_entry.id;
|
||||
|
||||
@@ -5508,3 +5488,40 @@ impl CreatedEntry {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_gitfile(content: &str) -> anyhow::Result<&Path> {
|
||||
let path = content
|
||||
.strip_prefix("gitdir:")
|
||||
.ok_or_else(|| anyhow!("failed to parse gitfile content {content:?}"))?;
|
||||
Ok(Path::new(path.trim()))
|
||||
}
|
||||
|
||||
fn discover_git_paths(dot_git_abs_path: &Arc<Path>, fs: &dyn Fs) -> (Arc<Path>, Arc<Path>) {
|
||||
let mut repository_dir_abs_path = dot_git_abs_path.clone();
|
||||
let mut common_dir_abs_path = dot_git_abs_path.clone();
|
||||
|
||||
if let Some(path) = smol::block_on(fs.load(&dot_git_abs_path))
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|contents| parse_gitfile(contents).log_err())
|
||||
{
|
||||
let path = dot_git_abs_path
|
||||
.parent()
|
||||
.unwrap_or(Path::new(""))
|
||||
.join(path);
|
||||
if let Some(path) = smol::block_on(fs.canonicalize(&path)).log_err() {
|
||||
repository_dir_abs_path = Path::new(&path).into();
|
||||
common_dir_abs_path = repository_dir_abs_path.clone();
|
||||
if let Some(commondir_contents) = smol::block_on(fs.load(&path.join("commondir"))).ok()
|
||||
{
|
||||
if let Some(commondir_path) =
|
||||
smol::block_on(fs.canonicalize(&path.join(commondir_contents.trim()))).log_err()
|
||||
{
|
||||
common_dir_abs_path = commondir_path.as_path().into();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(repository_dir_abs_path, common_dir_abs_path)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description = "The fast, collaborative code editor."
|
||||
edition.workspace = true
|
||||
name = "zed"
|
||||
version = "0.183.0"
|
||||
version = "0.184.0"
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
authors = ["Zed Team <hi@zed.dev>"]
|
||||
@@ -133,6 +133,8 @@ util.workspace = true
|
||||
uuid.workspace = true
|
||||
vim.workspace = true
|
||||
vim_mode_setting.workspace = true
|
||||
web_search.workspace = true
|
||||
web_search_providers.workspace = true
|
||||
welcome.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
|
||||
@@ -490,6 +490,8 @@ fn main() {
|
||||
app_state.fs.clone(),
|
||||
cx,
|
||||
);
|
||||
web_search::init(cx);
|
||||
web_search_providers::init(app_state.client.clone(), cx);
|
||||
snippet_provider::init(cx);
|
||||
inline_completion_registry::init(
|
||||
app_state.client.clone(),
|
||||
|
||||
@@ -4258,6 +4258,8 @@ mod tests {
|
||||
app_state.fs.clone(),
|
||||
cx,
|
||||
);
|
||||
web_search::init(cx);
|
||||
web_search_providers::init(app_state.client.clone(), cx);
|
||||
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), false, cx);
|
||||
assistant::init(
|
||||
app_state.fs.clone(),
|
||||
|
||||
@@ -150,7 +150,7 @@ pub mod command_palette {
|
||||
pub mod feedback {
|
||||
use gpui::actions;
|
||||
|
||||
actions!(feedback, [GiveFeedback]);
|
||||
actions!(feedback, [FileBugReport, GiveFeedback]);
|
||||
}
|
||||
|
||||
pub mod theme_selector {
|
||||
|
||||
@@ -1311,10 +1311,10 @@ To interpret all `.c` files as C++, files called `MyLockFile` as TOML and files
|
||||
"include_warnings": true,
|
||||
"inline": {
|
||||
"enabled": false
|
||||
}
|
||||
},
|
||||
"update_with_cursor": false,
|
||||
"primary_only": false,
|
||||
"use_rendered": false,
|
||||
"use_rendered": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -291,14 +291,15 @@ To run tests in your Ruby project, you can set up custom tasks in your local `.z
|
||||
```json
|
||||
[
|
||||
{
|
||||
"label": "test $ZED_RELATIVE_FILE:$ZED_ROW",
|
||||
"command": "bin/rails",
|
||||
"args": ["test", "\"$ZED_RELATIVE_FILE:$ZED_ROW\""],
|
||||
"label": "test $ZED_RELATIVE_FILE -n /$ZED_SYMBOL/",
|
||||
"command": "bin/rails test $ZED_RELATIVE_FILE -n /$ZED_SYMBOL/",
|
||||
"tags": ["ruby-test"]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Note: We can't use `args` here because of the way quotes are handled.
|
||||
|
||||
### Minitest
|
||||
|
||||
Plain minitest does not support running tests by line number, only by name, so we need to use `$ZED_SYMBOL` instead:
|
||||
|
||||
@@ -154,3 +154,33 @@ If you are seeing "too many open files" then first try `sysctl fs.inotify`.
|
||||
- You should see that `max_user_watches` is 8000 or higher (you can change the limit with `sudo sysctl fs.inotify.max_user_watches=64000`). Zed needs one watch per directory in all your open projects + one per git repository + a handful more for settings, themes, keymaps, extensions.
|
||||
|
||||
It is also possible that you are running out of file descriptors. You can check the limits with `ulimit` and update them by editing `/etc/security/limits.conf`.
|
||||
|
||||
### No sound or wrong output device
|
||||
|
||||
If you're not hearing any sound in Zed or the audio is routed to the wrong device, it could be due to a mismatch between audio systems. Zed relies on ALSA, while your system may be using PipeWire or PulseAudio. To resolve this, you need to configure ALSA to route audio through PipeWire/PulseAudio.
|
||||
|
||||
If your system uses PipeWire:
|
||||
|
||||
1. **Install the PipeWire ALSA plugin**
|
||||
|
||||
On Debian-based systems, run:
|
||||
|
||||
```bash
|
||||
sudo apt install pipewire-alsa
|
||||
```
|
||||
|
||||
2. **Configure ALSA to use PipeWire**
|
||||
|
||||
Add the following configuration to your ALSA settings file. You can use either `~/.asoundrc` (user-level) or `/etc/asound.conf` (system-wide):
|
||||
|
||||
```bash
|
||||
pcm.!default {
|
||||
type pipewire
|
||||
}
|
||||
|
||||
ctl.!default {
|
||||
type pipewire
|
||||
}
|
||||
```
|
||||
|
||||
3. **Restart your system**
|
||||
|
||||
@@ -11,4 +11,4 @@ The naming convention of these databases takes on the form of `0-<zed_channel>`:
|
||||
- Stable: `0-stable`
|
||||
- Preview: `0-preview`
|
||||
|
||||
**If you encounter workspace persistence issues in Zed, deleting the database and restarting Zed often resolves the problem, as the database may have been corrupted at some point.** If your issue continues after restarting Zed and regenerating a new database, please [file an issue](https://github.com/zed-industries/zed/issues/new?template=1_bug_report.yml).
|
||||
**If you encounter workspace persistence issues in Zed, deleting the database and restarting Zed often resolves the problem, as the database may have been corrupted at some point.** If your issue continues after restarting Zed and regenerating a new database, please [file an issue](https://github.com/zed-industries/zed/issues/new?template=10_bug_report.yml).
|
||||
|
||||
Reference in New Issue
Block a user