Compare commits

...

32 Commits

Author SHA1 Message Date
Zed Bot
47edd4a675 Bump to 0.182.3 for @maxdeviant 2025-04-11 00:18:30 +00:00
Antonio Scandurra
d487d464a1 Actually run the eval and fix a hang when retrieving outline (#28547)
Release Notes:

- Fixed a regression that caused the agent to hang sometimes.

---------

Co-authored-by: Thomas Mickley-Doyle <tmickleydoyle@gmail.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
Co-authored-by: Michael Sloan <mgsloan@gmail.com>
2025-04-10 18:11:46 -06:00
Bennet Bo Fenner
d1dcbbac45 markdown: Track code block metadata in parser (#28543)
This allows us to not scan the codeblock content for newlines on every
frame in `active_thread`

Release Notes:

- N/A
2025-04-10 17:46:19 -06:00
Danilo Leal
3430d431cf Change zed.dev's default model to Claude 3.7 Sonnet (#28541)
From Claude 3.5 Sonnet to **Claude 3.7 Sonnet**.

Release Notes:

- Change the default model of Zed's hosted LLM service to Claude 3.7
Sonnet.
2025-04-10 17:46:12 -06:00
Marshall Bowers
16395bd78d language_models: Fix non-streaming Copilot Chat models (#28537)
This PR fixes usage of non-streaming Copilot Chat models.

Closes https://github.com/zed-industries/zed/issues/28528.

Release Notes:

- Fixed an issue with using non-streaming Copilot Chat models (e.g., o1,
o3-mini).
2025-04-10 17:46:01 -06:00
Marko Kungla
4261201eb9 Add --user-data-dir CLI flag and propose renaming support_dir to data_dir (#26886)
This PR introduces support for a `--user-data-dir` CLI flag to override
Zed's data directory and proposes renaming `support_dir` to `data_dir`
for better cross-platform clarity. It builds on the discussion in #25349
about custom data directories, aiming to provide a flexible
cross-platform solution.

### Changes

The PR is split into two commits:
1. **[feat(cli): add --user-data-dir to override data
directory](28e8889105)**
2. **[refactor(paths): rename support_dir to data_dir for cross-platform
clarity](affd2fc606)**


### Context
Inspired by the need for custom data directories discussed in #25349,
this PR provides an immediate implementation in the first commit, while
the second commit suggests a naming improvement for broader appeal.
@mikayla-maki, I’d appreciate your feedback, especially on the rename
proposal, given your involvement in the original discussion!

### Testing
- `cargo build `
- `./target/debug/zed --user-data-dir ~/custom-data-dir`

Release Notes:
- Added --user-data-dir CLI flag

---------

Signed-off-by: Marko Kungla <marko.kungla@gmail.com>
2025-04-10 17:32:12 -04:00
gcp-cherry-pick-bot[bot]
ff12554c07 Do not query for LSP tasks buffers that do not belong to the position given (cherry-pick #28536) (#28538)
Cherry-picked Do not query for LSP tasks buffers that do not belong to
the position given (#28536)

Follow-up of https://github.com/zed-industries/zed/pull/28359

Release Notes:

- Fixed a panic when LSP tasks are queried in certain multi buffer
excerpts

Co-authored-by: Kirill Bulatov <kirill@zed.dev>
2025-04-10 15:03:45 -06:00
Zed Bot
aae71f512e Bump to 0.182.2 for @ConradIrwin 2025-04-10 20:42:53 +00:00
gcp-cherry-pick-bot[bot]
793da80471 Bump rustls (cherry-pick #28531) (#28535)
Cherry-picked Bump rustls (#28531)

Closes #26699

Release Notes:

- Fixed a panic when enabling or disabling a VPN on macOS

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
2025-04-10 14:38:31 -06:00
Nate Butler
4e82a27503 Add progress bar component (#28518)
- Adds the progress bar component

Release Notes:

- N/A
2025-04-10 14:18:12 -04:00
Thomas Mickley-Doyle
69335b3c73 agent: Add selected tool names to agent panel telemetry (#28247)
Release Notes:

- N/A
2025-04-10 14:16:00 -04:00
gcp-cherry-pick-bot[bot]
0b13333ca3 Fix merge conflicts jumping (cherry-pick #28508) (#28512)
Cherry-picked Fix merge conflicts jumping (#28508)

This regressed in #27568, oops.

Release Notes:

- Fixed a bug causing conflicted files in the git panel to jump to the
"Tracked" section as soon as they were staged.

Co-authored-by: Cole Miller <cole@zed.dev>
2025-04-10 10:51:58 -04:00
gcp-cherry-pick-bot[bot]
f6ef8662d4 Downgrade environment-related logging (cherry-pick #28509) (#28514)
Cherry-picked Downgrade environment-related logging (#28509)

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Co-authored-by: Cole Miller <cole@zed.dev>
2025-04-10 10:51:37 -04:00
Joseph T. Lyons
d5cc01abe7 zed 0.182.1 2025-04-10 09:21:04 -04:00
Agus Zubiaga
08fecda7ee agent: Use current shell (#28470)
Release Notes:

- agent: Replace `bash` tool with `terminal` tool which uses the current
shell

---------

Co-authored-by: Bennet <bennet@zed.dev>
Co-authored-by: Antonio <antonio@zed.dev>
2025-04-10 09:19:42 -04:00
Antonio Scandurra
538b88c260 Lay the groundwork for a Rust-based eval (#28488)
Also, we moved the logic for driving the agentic loop into `Thread` so
that we don't have to re-implement it.

Release Notes:

- N/A

---------

Co-authored-by: Nathan Sobo <nathan@zed.dev>
2025-04-10 09:19:42 -04:00
Bennet Bo Fenner
701e033fea agent: Optimize render_markdown_block function (#28487)
Co-Authored-by: Agus <agus@zed.dev>

Closes #ISSUE

Release Notes:

- N/A

Co-authored-by: Agus <agus@zed.dev>
2025-04-10 09:19:42 -04:00
Antonio Scandurra
9ed28c5e31 Revert "Add reminder message about system prompt" (#28482)
This breaks the agentic loop.
2025-04-10 09:19:42 -04:00
Danilo Leal
23b7a98152 agent: Fix toolbar spacing (#28485)
Release Notes:

- N/A
2025-04-10 09:19:42 -04:00
Danilo Leal
fe075ac273 agent: Add button to open thread as markdown (#28481)
<img
src="https://github.com/user-attachments/assets/92ca8f64-a949-4cc1-a657-3978a2c65839"
width="600"/>

Release Notes:

- agent: The action to open the current active thread in Markdown is now
exposed in the UI.
2025-04-09 23:12:03 -04:00
5brian
2fb6ec2d5a agent: Prevent sending whitespace only messages (#28409)
Prevent this from happening when sending a prompt with only spaces and
newlines:


![image](https://github.com/user-attachments/assets/b275f4c5-c013-4695-8fb4-e3ad75d41750)

Release Notes:

- agent: Prevent from sending messages containing only white space.
2025-04-09 23:12:03 -04:00
Danilo Leal
cef835c901 agent: Collapse code blocks in the active thread (#28467)
Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
2025-04-09 22:47:58 -04:00
Richard Feldman
0623a6a400 Add code action tool and rename tool (#28453)
Having a separate rename tool seems to make the agent more likely to use
it compared to having it be part of the code actions tool.

Release Notes:

- Added code action tool and rename tool.
2025-04-09 22:47:51 -04:00
Michael Sloan
a7722c9bc7 Fix directory context paths (#28459)
Release Notes:

- N/A
2025-04-09 22:47:30 -04:00
Bennet Bo Fenner
782b35aeb5 agent: Fuzzy match on paths and symbols when typing @ (#28357)
Release Notes:

- agent: Improve fuzzy matching when using @-mentions
2025-04-09 22:47:10 -04:00
Thomas Mickley-Doyle
612e30ea9c agent: Add reactions at the response level (#27958)
Release Notes:

- Added the user reaction (👍 or 👎) to each agent response.
- 👎 will trigger a comment box linked to the response

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Agus Zubiaga <hi@aguz.me>
2025-04-09 22:46:52 -04:00
Michael Sloan
8794d545c3 Reapply "Use Project instead of Workspace in ContextStore (#28402)" (#28441)
Motivation for this change is to use `ContextStore` in headless
assistant, which requires it to not depend on UI entities like
`Workspace`.

This reapplies a change that was revert was in #28428, and fixes the panic.

Release Notes:

- N/A
2025-04-09 22:46:45 -04:00
Richard Feldman
ec1ae32b8e Make regex search tool optionally case-sensitive (#28427)
Release Notes:

- The agent panel's regex search tool is now optionally case-sensitive.
2025-04-09 22:45:56 -04:00
Richard Feldman
17a271d4e2 Revert to fix panic in inline assistant (#28428)
This reverts commit f12a554f86, which
introduced a panic in inline assistant (cc @mgsloan) - I'm not sure what
the motivation was for that change, but I figure we can revert to fix
the inline assistant now and deal with that later. 😄

Panic was:

> Thread "main" panicked with "cannot read workspace::Workspace while it
is already being updated" at
/Users/rtfeldman/code/zed/crates/gpui/src/app/entity_map.rs:139:32


Release Notes:

- N/A
2025-04-09 11:25:13 -04:00
Agus Zubiaga
3d9dbfe902 Fix bash tool output (#28391) 2025-04-09 10:38:26 -04:00
Richard Feldman
b9aa296d4a Add reminder message about system prompt (#28344)
Trying out sending the model a reminder message about code blocks in the
system prompt. If this seems to work well, we can include more specific
reminder messages, e.g. tool-specific ones.

Release Notes:

- N/A
2025-04-09 10:38:19 -04:00
Joseph T. Lyons
305f113741 v0.182.x preview 2025-04-09 09:10:17 -04:00
96 changed files with 3771 additions and 2700 deletions

116
Cargo.lock generated
View File

@@ -52,7 +52,6 @@ dependencies = [
name = "agent"
version = "0.1.0"
dependencies = [
"agent_rules",
"anyhow",
"assistant_context_editor",
"assistant_settings",
@@ -65,6 +64,7 @@ dependencies = [
"clock",
"collections",
"command_palette_hooks",
"component",
"context_server",
"convert_case 0.8.0",
"db",
@@ -85,6 +85,7 @@ dependencies = [
"language",
"language_model",
"language_model_selector",
"linkme",
"log",
"lsp",
"markdown",
@@ -114,6 +115,7 @@ dependencies = [
"terminal_view",
"text",
"theme",
"thiserror 2.0.12",
"time",
"time_format",
"ui",
@@ -125,57 +127,6 @@ dependencies = [
"zed_actions",
]
[[package]]
name = "agent_eval"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_tool",
"assistant_tools",
"clap",
"client",
"collections",
"context_server",
"dap",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"language",
"language_model",
"language_models",
"node_runtime",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"serde_json",
"serde_json_lenient",
"settings",
"smol",
"tempfile",
"util",
"walkdir",
"workspace-hack",
]
[[package]]
name = "agent_rules"
version = "0.1.0"
dependencies = [
"anyhow",
"fs",
"gpui",
"indoc",
"prompt_store",
"util",
"workspace-hack",
"worktree",
]
[[package]]
name = "ahash"
version = "0.7.8"
@@ -1639,7 +1590,7 @@ dependencies = [
"hyper-util",
"pin-project-lite",
"rustls 0.21.12",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
@@ -4900,6 +4851,37 @@ dependencies = [
"num-traits",
]
[[package]]
name = "eval"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_settings",
"assistant_tool",
"assistant_tools",
"client",
"context_server",
"dap",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"language",
"language_model",
"language_models",
"node_runtime",
"project",
"prompt_store",
"release_channel",
"reqwest_client",
"serde",
"settings",
"toml 0.8.20",
"workspace-hack",
]
[[package]]
name = "evals"
version = "0.1.0"
@@ -6635,7 +6617,7 @@ dependencies = [
name = "http_client_tls"
version = "0.1.0"
dependencies = [
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-platform-verifier",
"workspace-hack",
]
@@ -6734,7 +6716,7 @@ dependencies = [
"http 1.3.1",
"hyper 1.6.0",
"hyper-util",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
@@ -11262,7 +11244,7 @@ dependencies = [
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.1",
"rustls 0.23.25",
"rustls 0.23.26",
"socket2",
"thiserror 2.0.12",
"tokio",
@@ -11281,7 +11263,7 @@ dependencies = [
"rand 0.9.0",
"ring",
"rustc-hash 2.1.1",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pki-types",
"slab",
"thiserror 2.0.12",
@@ -11889,7 +11871,7 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-pemfile 2.2.0",
"rustls-pki-types",
@@ -12280,9 +12262,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.25"
version = "0.23.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c"
checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0"
dependencies = [
"aws-lc-rs",
"log",
@@ -12356,7 +12338,7 @@ dependencies = [
"jni",
"log",
"once_cell",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-native-certs 0.8.1",
"rustls-platform-verifier-android",
"rustls-webpki 0.103.1",
@@ -13454,7 +13436,7 @@ dependencies = [
"once_cell",
"percent-encoding",
"rust_decimal",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pemfile 2.2.0",
"serde",
"serde_json",
@@ -14749,7 +14731,7 @@ version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
dependencies = [
"rustls 0.23.25",
"rustls 0.23.26",
"tokio",
]
@@ -14809,7 +14791,7 @@ checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.2",
@@ -15368,7 +15350,7 @@ dependencies = [
"httparse",
"log",
"rand 0.9.0",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-pki-types",
"sha1",
"thiserror 2.0.12",
@@ -17716,7 +17698,7 @@ dependencies = [
"rust_decimal",
"rustix 0.38.44",
"rustix 1.0.5",
"rustls 0.23.25",
"rustls 0.23.26",
"rustls-webpki 0.103.1",
"scopeguard",
"sea-orm",
@@ -18105,7 +18087,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.182.0"
version = "0.182.3"
dependencies = [
"activity_indicator",
"agent",

View File

@@ -3,13 +3,11 @@ resolver = "2"
members = [
"crates/activity_indicator",
"crates/agent",
"crates/agent_rules",
"crates/anthropic",
"crates/askpass",
"crates/assets",
"crates/assistant",
"crates/assistant_context_editor",
"crates/agent_eval",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -47,6 +45,7 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/editor",
"crates/eval",
"crates/evals",
"crates/extension",
"crates/extension_api",
@@ -210,14 +209,12 @@ edition = "2024"
activity_indicator = { path = "crates/activity_indicator" }
agent = { path = "crates/agent" }
agent_rules = { path = "crates/agent_rules" }
ai = { path = "crates/ai" }
anthropic = { path = "crates/anthropic" }
askpass = { path = "crates/askpass" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_context_editor = { path = "crates/assistant_context_editor" }
assistant_eval = { path = "crates/agent_eval" }
assistant_settings = { path = "crates/assistant_settings" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
@@ -509,7 +506,7 @@ runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804
rustc-demangle = "0.1.23"
rust-embed = { version = "8.4", features = ["include-exclude"] }
rustc-hash = "2.1.0"
rustls = { version = "0.23.22" }
rustls = { version = "0.23.26" }
rustls-platform-verifier = "0.5.0"
scap = { git = "https://github.com/zed-industries/scap", rev = "08f0a01417505cc0990b9931a37e5120db92e0d0", default-features = false }
schemars = { version = "0.8", features = ["impl_json_schema", "indexmap2"] }

View File

@@ -163,3 +163,8 @@ There are rules that apply to these root directories:
{{/if}}
{{/each}}
{{/if}}
<user_environment>
Operating System: {{os}} ({{arch}})
Shell: {{shell}}
</user_environment>

View File

@@ -624,14 +624,14 @@
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-3-5-sonnet-latest"
"model": "claude-3-7-sonnet-latest"
},
// The model to use when applying edits from the assistant.
"editor_model": {
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-3-5-sonnet-latest"
"model": "claude-3-7-sonnet-latest"
},
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
"always_allow_tool_actions": false,
@@ -656,8 +656,9 @@
"name": "Write",
"enable_all_context_servers": true,
"tools": {
"bash": true,
"terminal": true,
"batch_tool": true,
"code_actions": true,
"code_symbols": true,
"copy_path": false,
"create_file": true,
@@ -671,6 +672,7 @@
"path_search": true,
"read_file": true,
"regex_search": true,
"rename": true,
"symbol_info": true,
"thinking": true
}

View File

@@ -19,7 +19,6 @@ test-support = [
]
[dependencies]
agent_rules.workspace = true
anyhow.workspace = true
assistant_context_editor.workspace = true
assistant_settings.workspace = true
@@ -32,6 +31,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
context_server.workspace = true
convert_case.workspace = true
db.workspace = true
@@ -51,6 +51,7 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
linkme.workspace = true
log.workspace = true
lsp.workspace = true
markdown.workspace = true
@@ -79,15 +80,16 @@ terminal.workspace = true
terminal_view.workspace = true
text.workspace = true
theme.workspace = true
thiserror.workspace = true
time.workspace = true
time_format.workspace = true
ui.workspace = true
ui_input.workspace = true
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
buffer_diff = { workspace = true, features = ["test-support"] }

View File

@@ -1,13 +1,13 @@
use crate::AssistantPanel;
use crate::context::{AssistantContext, ContextId};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, RequestKind, Thread, ThreadError,
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::ThreadStore;
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse, ToolUseStatus};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
use collections::{HashMap, HashSet};
@@ -21,13 +21,12 @@ use gpui::{
linear_color_stop, linear_gradient, list, percentage, pulsating_between,
};
use language::{Buffer, LanguageRegistry};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::parser::CodeBlockKind;
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown, without_fences};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role, StopReason};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
use project::ProjectItem as _;
use rope::Point;
use settings::{Settings as _, update_settings_file};
use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
@@ -57,14 +56,17 @@ pub struct ActiveThread {
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
expanded_thinking_segments: HashMap<(MessageId, usize), bool>,
expanded_code_blocks: HashMap<(MessageId, usize), bool>,
last_error: Option<ThreadError>,
notifications: Vec<WindowHandle<AgentNotification>>,
copied_code_block_ids: HashSet<(MessageId, usize)>,
_subscriptions: Vec<Subscription>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
feedback_message_editor: Option<Entity<Editor>>,
open_feedback_editors: HashMap<MessageId, Entity<Editor>>,
}
const MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK: usize = 5;
struct RenderedMessage {
language_registry: Arc<LanguageRegistry>,
segments: Vec<RenderedMessageSegment>,
@@ -294,10 +296,10 @@ fn render_markdown_code_block(
ix: usize,
kind: &CodeBlockKind,
parsed_markdown: &ParsedMarkdown,
codeblock_range: Range<usize>,
metadata: CodeBlockMetadata,
active_thread: Entity<ActiveThread>,
workspace: WeakEntity<Workspace>,
_window: &mut Window,
_window: &Window,
cx: &App,
) -> Div {
let label = match kind {
@@ -377,16 +379,20 @@ fn render_markdown_code_block(
.rounded_sm()
.hover(|item| item.bg(cx.theme().colors().element_hover.opacity(0.5)))
.tooltip(Tooltip::text("Jump to File"))
.children(
file_icons::FileIcons::get_icon(&path_range.path, cx)
.map(Icon::from_path)
.map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
)
.child(content)
.child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
h_flex()
.gap_0p5()
.children(
file_icons::FileIcons::get_icon(&path_range.path, cx)
.map(Icon::from_path)
.map(|icon| icon.color(Color::Muted).size(IconSize::XSmall)),
)
.child(content)
.child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
),
)
.on_click({
let path_range = path_range.clone();
@@ -444,17 +450,24 @@ fn render_markdown_code_block(
}),
};
let codeblock_was_copied = active_thread
.read(cx)
.copied_code_block_ids
.contains(&(message_id, ix));
let is_expanded = active_thread
.read(cx)
.expanded_code_blocks
.get(&(message_id, ix))
.copied()
.unwrap_or(false);
let codeblock_header_bg = cx
.theme()
.colors()
.element_background
.blend(cx.theme().colors().editor_foreground.opacity(0.01));
let codeblock_was_copied = active_thread
.read(cx)
.copied_code_block_ids
.contains(&(message_id, ix));
let codeblock_header = h_flex()
.group("codeblock_header")
.p_1()
@@ -466,57 +479,108 @@ fn render_markdown_code_block(
.rounded_t_md()
.children(label)
.child(
div().visible_on_hover("codeblock_header").child(
IconButton::new(
("copy-markdown-code", ix),
if codeblock_was_copied {
IconName::Check
} else {
IconName::Copy
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Code"))
.on_click({
let active_thread = active_thread.clone();
let parsed_markdown = parsed_markdown.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.copied_code_block_ids.insert((message_id, ix));
h_flex()
.gap_1()
.child(
div().visible_on_hover("codeblock_header").child(
IconButton::new(
("copy-markdown-code", ix),
if codeblock_was_copied {
IconName::Check
} else {
IconName::Copy
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Copy Code"))
.on_click({
let active_thread = active_thread.clone();
let parsed_markdown = parsed_markdown.clone();
let code_block_range = metadata.content_range.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
this.copied_code_block_ids.insert((message_id, ix));
let code =
without_fences(&parsed_markdown.source()[codeblock_range.clone()])
.to_string();
let code = parsed_markdown.source()[code_block_range.clone()]
.to_string();
cx.write_to_clipboard(ClipboardItem::new_string(code));
cx.write_to_clipboard(ClipboardItem::new_string(code.clone()));
cx.spawn(async move |this, cx| {
cx.background_executor()
.timer(Duration::from_secs(2))
.await;
cx.spawn(async move |this, cx| {
cx.background_executor().timer(Duration::from_secs(2)).await;
cx.update(|cx| {
this.update(cx, |this, cx| {
this.copied_code_block_ids.remove(&(message_id, ix));
cx.notify();
cx.update(|cx| {
this.update(cx, |this, cx| {
this.copied_code_block_ids
.remove(&(message_id, ix));
cx.notify();
})
})
.ok();
})
})
.ok();
})
.detach();
});
}
}),
),
.detach();
});
}
}),
),
)
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|header| {
header.child(
IconButton::new(
("expand-collapse-code", ix),
if is_expanded {
IconName::ChevronUp
} else {
IconName::ChevronDown
},
)
.icon_color(Color::Muted)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text(if is_expanded {
"Collapse Code"
} else {
"Expand Code"
}))
.on_click({
let active_thread = active_thread.clone();
move |_event, _window, cx| {
active_thread.update(cx, |this, cx| {
let is_expanded = this
.expanded_code_blocks
.entry((message_id, ix))
.or_insert(false);
*is_expanded = !*is_expanded;
cx.notify();
});
}
}),
)
},
),
);
v_flex()
.mb_2()
.relative()
.my_2()
.overflow_hidden()
.rounded_lg()
.border_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
.child(codeblock_header)
.when(
metadata.line_count > MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK,
|this| {
if is_expanded {
this.h_full()
} else {
this.max_h_40()
}
},
)
}
fn open_markdown_link(
@@ -604,6 +668,7 @@ impl ActiveThread {
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe_in(&thread, window, Self::handle_thread_event),
cx.subscribe(&thread_store, Self::handle_rules_loading_error),
];
let list_state = ListState::new(0, ListAlignment::Bottom, px(2048.), {
@@ -626,6 +691,7 @@ impl ActiveThread {
rendered_tool_uses: HashMap::default(),
expanded_tool_uses: HashMap::default(),
expanded_thinking_segments: HashMap::default(),
expanded_code_blocks: HashMap::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state),
show_scrollbar: false,
@@ -636,7 +702,7 @@ impl ActiveThread {
notifications: Vec::new(),
_subscriptions: subscriptions,
notification_subscriptions: HashMap::default(),
feedback_message_editor: None,
open_feedback_editors: HashMap::default(),
};
for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
@@ -768,10 +834,9 @@ impl ActiveThread {
| ThreadEvent::SummaryChanged => {
self.save_thread(cx);
}
ThreadEvent::DoneStreaming => {
let thread = self.thread.read(cx);
if !thread.is_generating() {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn | StopReason::MaxTokens) => {
let thread = self.thread.read(cx);
self.show_notification(
if thread.used_tools_since_last_user_message() {
"Finished running tools"
@@ -783,7 +848,8 @@ impl ActiveThread {
cx,
);
}
}
_ => {}
},
ThreadEvent::ToolConfirmationNeeded => {
self.show_notification("Waiting for tool confirmation", IconName::Info, window, cx);
}
@@ -828,11 +894,7 @@ impl ActiveThread {
self.save_thread(cx);
cx.notify();
}
ThreadEvent::UsePendingTools => {
let tool_uses = self
.thread
.update(cx, |thread, cx| thread.use_pending_tools(cx));
ThreadEvent::UsePendingTools { tool_uses } => {
for tool_use in tool_uses {
self.render_tool_use_markdown(
tool_use.id.clone(),
@@ -844,11 +906,8 @@ impl ActiveThread {
}
}
ThreadEvent::ToolFinished {
pending_tool_use,
canceled,
..
pending_tool_use, ..
} => {
let canceled = *canceled;
if let Some(tool_use) = pending_tool_use {
self.render_tool_use_markdown(
tool_use.id.clone(),
@@ -862,23 +921,24 @@ impl ActiveThread {
cx,
);
}
if self.thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
if !canceled {
thread.send_to_model(model, RequestKind::Chat, cx);
}
});
}
}
}
ThreadEvent::CheckpointChanged => cx.notify(),
}
}
fn handle_rules_loading_error(
&mut self,
_thread_store: Entity<ThreadStore>,
error: &RulesLoadingError,
cx: &mut Context<Self>,
) {
self.last_error = Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: error.message.clone(),
});
cx.notify();
}
fn show_notification(
&mut self,
caption: impl Into<SharedString>,
@@ -939,7 +999,7 @@ impl ActiveThread {
|this, _, event, window, cx| match event {
AgentNotificationEvent::Accepted => {
let handle = window.window_handle();
cx.activate(true); // Switch back to the Zed application
cx.activate(true);
let workspace_handle = this.workspace.clone();
@@ -1111,34 +1171,37 @@ impl ActiveThread {
fn handle_feedback_click(
&mut self,
message_id: MessageId,
feedback: ThreadFeedback,
window: &mut Window,
cx: &mut Context<Self>,
) {
let report = self.thread.update(cx, |thread, cx| {
thread.report_message_feedback(message_id, feedback, cx)
});
cx.spawn(async move |this, cx| {
report.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
match feedback {
ThreadFeedback::Positive => {
let report = self
.thread
.update(cx, |thread, cx| thread.report_feedback(feedback, cx));
let this = cx.entity().downgrade();
cx.spawn(async move |_, cx| {
report.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
self.open_feedback_editors.remove(&message_id);
}
ThreadFeedback::Negative => {
self.handle_show_feedback_comments(window, cx);
self.handle_show_feedback_comments(message_id, window, cx);
}
}
}
fn handle_show_feedback_comments(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if self.feedback_message_editor.is_some() {
return;
}
fn handle_show_feedback_comments(
&mut self,
message_id: MessageId,
window: &mut Window,
cx: &mut Context<Self>,
) {
let buffer = cx.new(|cx| {
let empty_string = String::new();
MultiBuffer::singleton(cx.new(|cx| Buffer::local(empty_string, cx)), cx)
@@ -1160,34 +1223,47 @@ impl ActiveThread {
});
editor.read(cx).focus_handle(cx).focus(window);
self.feedback_message_editor = Some(editor);
self.open_feedback_editors.insert(message_id, editor);
cx.notify();
}
fn submit_feedback_message(&mut self, cx: &mut Context<Self>) {
let Some(editor) = self.feedback_message_editor.clone() else {
fn submit_feedback_message(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
let Some(editor) = self.open_feedback_editors.get(&message_id) else {
return;
};
let report_task = self.thread.update(cx, |thread, cx| {
thread.report_feedback(ThreadFeedback::Negative, cx)
thread.report_message_feedback(message_id, ThreadFeedback::Negative, cx)
});
let comments = editor.read(cx).text(cx);
if !comments.is_empty() {
let thread_id = self.thread.read(cx).id().clone();
let comments_value = String::from(comments.as_str());
telemetry::event!("Assistant Thread Feedback Comments", thread_id, comments);
let message_content = self
.thread
.read(cx)
.message(message_id)
.map(|msg| msg.to_string())
.unwrap_or_default();
telemetry::event!(
"Assistant Thread Feedback Comments",
thread_id,
message_id = message_id.0,
message_content,
comments = comments_value
);
self.open_feedback_editors.remove(&message_id);
cx.spawn(async move |this, cx| {
report_task.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
}
self.feedback_message_editor = None;
let this = cx.entity().downgrade();
cx.spawn(async move |_, cx| {
report_task.await?;
this.update(cx, |_this, cx| cx.notify())
})
.detach_and_log_err(cx);
}
fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
@@ -1214,7 +1290,18 @@ impl ActiveThread {
let is_first_message = ix == 0;
let is_last_message = ix == self.messages.len() - 1;
let show_feedback = is_last_message && message.role != Role::User;
let show_feedback = (!is_generating && is_last_message && message.role != Role::User)
|| self.messages.get(ix + 1).map_or(false, |next_id| {
self.thread
.read(cx)
.message(*next_id)
.map_or(false, |next_message| {
next_message.role == Role::User
&& thread.tool_uses_for_message(*next_id, cx).is_empty()
&& thread.tool_results_for_message(*next_id).is_empty()
})
});
let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);
@@ -1287,8 +1374,17 @@ impl ActiveThread {
let editor_bg_color = colors.editor_background;
let bg_user_message_header = editor_bg_color.blend(active_color.opacity(0.25));
let feedback_container = h_flex().pt_2().pb_4().px_4().gap_1().justify_between();
let feedback_items = match self.thread.read(cx).feedback() {
let open_as_markdown = IconButton::new("open-as-markdown", IconName::FileCode)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.tooltip(Tooltip::text("Open Thread as Markdown"))
.on_click(|_event, window, cx| {
window.dispatch_action(Box::new(OpenActiveThreadAsMarkdown), cx)
});
let feedback_container = h_flex().py_2().px_4().gap_1().justify_between();
let feedback_items = match self.thread.read(cx).message_feedback(message_id) {
Some(feedback) => feedback_container
.child(
Label::new(match feedback {
@@ -1302,18 +1398,20 @@ impl ActiveThread {
)
.child(
h_flex()
.pr_1()
.gap_1()
.child(
IconButton::new("feedback-thumbs-up", IconName::ThumbsUp)
IconButton::new(("feedback-thumbs-up", ix), IconName::ThumbsUp)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(match feedback {
ThreadFeedback::Positive => Color::Accent,
ThreadFeedback::Negative => Color::Ignored,
})
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Helpful Response"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Positive,
window,
cx,
@@ -1321,22 +1419,24 @@ impl ActiveThread {
})),
)
.child(
IconButton::new("feedback-thumbs-down", IconName::ThumbsDown)
IconButton::new(("feedback-thumbs-down", ix), IconName::ThumbsDown)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(match feedback {
ThreadFeedback::Positive => Color::Ignored,
ThreadFeedback::Negative => Color::Accent,
})
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Not Helpful"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Negative,
window,
cx,
);
})),
),
)
.child(open_as_markdown),
)
.into_any_element(),
None => feedback_container
@@ -1349,15 +1449,17 @@ impl ActiveThread {
)
.child(
h_flex()
.pr_1()
.gap_1()
.child(
IconButton::new("feedback-thumbs-up", IconName::ThumbsUp)
IconButton::new(("feedback-thumbs-up", ix), IconName::ThumbsUp)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Helpful Response"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Positive,
window,
cx,
@@ -1365,19 +1467,21 @@ impl ActiveThread {
})),
)
.child(
IconButton::new("feedback-thumbs-down", IconName::ThumbsDown)
IconButton::new(("feedback-thumbs-down", ix), IconName::ThumbsDown)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
.shape(ui::IconButtonShape::Square)
.tooltip(Tooltip::text("Not Helpful"))
.on_click(cx.listener(move |this, _, window, cx| {
this.handle_feedback_click(
message_id,
ThreadFeedback::Negative,
window,
cx,
);
})),
),
)
.child(open_as_markdown),
)
.into_any_element(),
};
@@ -1669,31 +1773,31 @@ impl ActiveThread {
.child(generating_label.unwrap()),
)
})
.when(show_feedback && !is_generating, |parent| {
.when(show_feedback, move |parent| {
parent.child(feedback_items).when_some(
self.feedback_message_editor.clone(),
|parent, feedback_editor| {
self.open_feedback_editors.get(&message_id),
move |parent, feedback_editor| {
let focus_handle = feedback_editor.focus_handle(cx);
parent.child(
v_flex()
.key_context("AgentFeedbackMessageEditor")
.on_action(cx.listener(|this, _: &menu::Cancel, _, cx| {
this.feedback_message_editor = None;
.on_action(cx.listener(move |this, _: &menu::Cancel, _, cx| {
this.open_feedback_editors.remove(&message_id);
cx.notify();
}))
.on_action(cx.listener(|this, _: &menu::Confirm, _, cx| {
this.submit_feedback_message(cx);
.on_action(cx.listener(move |this, _: &menu::Confirm, _, cx| {
this.submit_feedback_message(message_id, cx);
cx.notify();
}))
.on_action(cx.listener(Self::confirm_editing_message))
.my_3()
.mb_2()
.mx_4()
.p_2()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().editor_background)
.child(feedback_editor)
.child(feedback_editor.clone())
.child(
h_flex()
.gap_1()
@@ -1710,10 +1814,13 @@ impl ActiveThread {
)
.map(|kb| kb.size(rems_from_px(10.))),
)
.on_click(cx.listener(|this, _, _, cx| {
this.feedback_message_editor = None;
cx.notify();
})),
.on_click(cx.listener(
move |this, _, _window, cx| {
this.open_feedback_editors
.remove(&message_id);
cx.notify();
},
)),
)
.child(
Button::new(
@@ -1732,9 +1839,9 @@ impl ActiveThread {
.map(|kb| kb.size(rems_from_px(10.))),
)
.on_click(
cx.listener(|this, _, _, cx| {
this.submit_feedback_message(cx);
cx.notify();
cx.listener(move |this, _, _window, cx| {
this.submit_feedback_message(message_id, cx);
cx.notify()
}),
),
),
@@ -1799,13 +1906,13 @@ impl ActiveThread {
render: Arc::new({
let workspace = workspace.clone();
let active_thread = cx.entity();
move |id, kind, parsed_markdown, range, window, cx| {
move |kind, parsed_markdown, range, metadata, window, cx| {
render_markdown_code_block(
message_id,
id,
range.start,
kind,
parsed_markdown,
range,
metadata,
active_thread.clone(),
workspace.clone(),
window,
@@ -1813,6 +1920,47 @@ impl ActiveThread {
)
}
}),
transform: Some(Arc::new({
let active_thread = cx.entity();
move |el, range, metadata, _, cx| {
let is_expanded = active_thread
.read(cx)
.expanded_code_blocks
.get(&(message_id, range.start))
.copied()
.unwrap_or(false);
if is_expanded
|| metadata.line_count
<= MAX_UNCOLLAPSED_LINES_IN_CODE_BLOCK
{
return el;
}
el.child(
div()
.absolute()
.bottom_0()
.left_0()
.w_full()
.h_1_4()
.rounded_b_lg()
.bg(gpui::linear_gradient(
0.,
gpui::linear_color_stop(
cx.theme().colors().editor_background,
0.,
),
gpui::linear_color_stop(
cx.theme()
.colors()
.editor_background
.opacity(0.),
1.,
),
)),
)
}
})),
})
.on_url_click({
let workspace = self.workspace.clone();
@@ -2567,12 +2715,13 @@ impl ActiveThread {
}
fn render_rules_item(&self, cx: &Context<Self>) -> AnyElement {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
let project_context = self.thread.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
return div().into_any();
};
let rules_files = system_prompt_context
let rules_files = project_context
.worktrees
.iter()
.filter_map(|worktree| worktree.rules_file.as_ref())
@@ -2662,12 +2811,13 @@ impl ActiveThread {
}
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
let project_context = self.thread.read(cx).project_context();
let project_context = project_context.borrow();
let Some(project_context) = project_context.as_ref() else {
return;
};
let abs_paths = system_prompt_context
let abs_paths = project_context
.worktrees
.iter()
.flat_map(|worktree| worktree.rules_file.as_ref())
@@ -2804,10 +2954,10 @@ pub(crate) fn open_context(
}
}
AssistantContext::Directory(directory_context) => {
let path = directory_context.project_path.clone();
let project_path = directory_context.project_path(cx);
workspace.update(cx, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
if let Some(entry) = project.entry_for_path(&path, cx) {
if let Some(entry) = project.entry_for_path(&project_path, cx) {
cx.emit(project::Event::RevealInProjectPanel(entry.id));
}
})

View File

@@ -921,15 +921,16 @@ mod tests {
})
.unwrap();
let thread_store = cx.update(|cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread_store = cx
.update(|cx| {
ThreadStore::load(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await;
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

View File

@@ -194,10 +194,12 @@ impl AssistantPanel {
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
let tools = Arc::new(ToolWorkingSet::default());
let thread_store = workspace.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::new(project, tools.clone(), prompt_builder.clone(), cx)
})??;
let thread_store = workspace
.update(cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
})?
.await;
let slash_commands = Arc::new(SlashCommandWorkingSet::default());
let context_store = workspace
@@ -863,7 +865,11 @@ impl AssistantPanel {
.truncate()
.into_any_element()
} else {
change_title_editor.clone().into_any_element()
div()
.ml_2()
.w_full()
.child(change_title_editor.clone())
.into_any_element()
}
}
ActiveView::PromptEditor => {
@@ -1624,7 +1630,21 @@ impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
cx: &mut Context<PromptLibrary>,
) {
InlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&prompt_editor, self.workspace.clone(), None, window, cx)
let Some(project) = self
.workspace
.upgrade()
.map(|workspace| workspace.read(cx).project().downgrade())
else {
return;
};
assistant.assist(
&prompt_editor,
self.workspace.clone(),
project,
None,
window,
cx,
)
})
}

View File

@@ -1,9 +1,9 @@
use std::{ops::Range, sync::Arc};
use std::{ops::Range, path::Path, sync::Arc};
use gpui::{App, Entity, SharedString};
use language::{Buffer, File};
use language_model::LanguageModelRequestMessage;
use project::ProjectPath;
use project::{ProjectPath, Worktree};
use serde::{Deserialize, Serialize};
use text::{Anchor, BufferId};
use ui::IconName;
@@ -69,10 +69,21 @@ pub struct FileContext {
#[derive(Debug, Clone)]
pub struct DirectoryContext {
pub id: ContextId,
pub project_path: ProjectPath,
pub worktree: Entity<Worktree>,
pub path: Arc<Path>,
/// Buffers of the files within the directory.
pub context_buffers: Vec<ContextBuffer>,
}
impl DirectoryContext {
pub fn project_path(&self, cx: &App) -> ProjectPath {
ProjectPath {
worktree_id: self.worktree.read(cx).id(),
path: self.path.clone(),
}
}
}
#[derive(Debug, Clone)]
pub struct SymbolContext {
pub id: ContextId,
@@ -86,12 +97,11 @@ pub struct FetchedUrlContext {
pub text: SharedString,
}
// TODO: Model<Thread> holds onto the thread even if the thread is deleted. Can either handle this
// explicitly or have a WeakModel<Thread> and remove during snapshot.
#[derive(Debug, Clone)]
pub struct ThreadContext {
pub id: ContextId,
// TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
// a WeakEntity and handle removal from the UI when it has dropped.
pub thread: Entity<Thread>,
pub text: SharedString,
}
@@ -105,12 +115,11 @@ impl ThreadContext {
}
}
// TODO: Model<Buffer> holds onto the buffer even if the file is deleted and closed. Should remove
// the context from the message editor in this case.
#[derive(Clone)]
pub struct ContextBuffer {
pub id: BufferId,
// TODO: Entity<Buffer> holds onto the thread even if the thread is deleted. Should probably be
// a WeakEntity and handle removal from the UI when it has dropped.
pub buffer: Entity<Buffer>,
pub file: Arc<dyn File>,
pub version: clock::Global,

View File

@@ -289,12 +289,14 @@ impl ContextPicker {
path_prefix,
} => {
let context_store = self.context_store.clone();
let worktree_id = project_path.worktree_id;
let path = project_path.path.clone();
ContextMenuItem::custom_entry(
move |_window, cx| {
render_file_context_entry(
ElementId::NamedInteger("ctx-recent".into(), ix),
worktree_id,
&path,
&path_prefix,
false,
@@ -466,7 +468,7 @@ fn recent_context_picker_entries(
recent.extend(
workspace
.recent_navigation_history_iter(cx)
.filter(|(path, _)| !current_files.contains(&path.path.to_path_buf()))
.filter(|(path, _)| !current_files.contains(path))
.take(4)
.filter_map(|(project_path, _)| {
project

View File

@@ -18,16 +18,133 @@ use text::{Anchor, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::AssistantContext;
use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
use super::thread_context_picker::ThreadContextEntry;
use super::file_context_picker::FileMatch;
use super::symbol_context_picker::SymbolMatch;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerMode, MentionLink, recent_context_picker_entries, supported_context_picker_modes,
ContextPickerMode, MentionLink, RecentEntry, recent_context_picker_entries,
supported_context_picker_modes,
};
pub(crate) enum Match {
Symbol(SymbolMatch),
File(FileMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Mode(ContextPickerMode),
}
fn search(
mode: Option<ContextPickerMode>,
query: String,
cancellation_flag: Arc<AtomicBool>,
recent_entries: Vec<RecentEntry>,
thread_store: Option<WeakEntity<ThreadStore>>,
workspace: Entity<Workspace>,
cx: &mut App,
) -> Task<Vec<Match>> {
match mode {
Some(ContextPickerMode::File) => {
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_files_task
.await
.into_iter()
.map(Match::File)
.collect()
})
}
Some(ContextPickerMode::Symbol) => {
let search_symbols_task =
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_symbols_task
.await
.into_iter()
.map(Match::Symbol)
.collect()
})
}
Some(ContextPickerMode::Thread) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_threads_task =
search_threads(query.clone(), cancellation_flag.clone(), thread_store, cx);
cx.background_spawn(async move {
search_threads_task
.await
.into_iter()
.map(Match::Thread)
.collect()
})
} else {
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Fetch) => {
if !query.is_empty() {
Task::ready(vec![Match::Fetch(query.into())])
} else {
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
.into_iter()
.map(|entry| match entry {
super::RecentEntry::File {
project_path,
path_prefix,
} => Match::File(FileMatch {
mat: fuzzy::PathMatch {
score: 1.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix,
is_dir: false,
distance_to_relative_ancestor: 0,
},
is_recent: true,
}),
super::RecentEntry::Thread(thread_context_entry) => {
Match::Thread(ThreadMatch {
thread: thread_context_entry,
is_recent: true,
})
}
})
.collect::<Vec<_>>();
matches.extend(
supported_context_picker_modes(&thread_store)
.into_iter()
.map(Match::Mode),
);
Task::ready(matches)
} else {
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
cx.background_spawn(async move {
search_files_task
.await
.into_iter()
.map(Match::File)
.collect()
})
}
}
}
}
pub struct ContextPickerCompletionProvider {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
@@ -50,97 +167,20 @@ impl ContextPickerCompletionProvider {
}
}
fn default_completions(
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
context_store: Entity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
editor: Entity<Editor>,
workspace: Entity<Workspace>,
cx: &App,
) -> Vec<Completion> {
let mut completions = Vec::new();
completions.extend(
recent_context_picker_entries(
context_store.clone(),
thread_store.clone(),
workspace.clone(),
cx,
)
.iter()
.filter_map(|entry| match entry {
super::RecentEntry::File {
project_path,
path_prefix,
} => Some(Self::completion_for_path(
project_path.clone(),
path_prefix,
true,
false,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
cx,
)),
super::RecentEntry::Thread(thread_context_entry) => {
let thread_store = thread_store
.as_ref()
.and_then(|thread_store| thread_store.upgrade())?;
Some(Self::completion_for_thread(
thread_context_entry.clone(),
excerpt_id,
source_range.clone(),
true,
editor.clone(),
context_store.clone(),
thread_store,
))
}
}),
);
completions.extend(
supported_context_picker_modes(&thread_store)
.iter()
.map(|mode| {
Completion {
replace_range: source_range.clone(),
new_text: format!("@{} ", mode.mention_prefix()),
label: CodeLabel::plain(mode.label().to_string(), None),
icon_path: Some(mode.icon().path().into()),
documentation: None,
source: project::CompletionSource::Custom,
insert_text_mode: None,
// This ensures that when a user accepts this completion, the
// completion menu will still be shown after "@category " is
// inserted
confirm: Some(Arc::new(|_, _, _| true)),
}
}),
);
completions
}
fn build_code_label_for_full_path(
file_name: &str,
directory: Option<&str>,
cx: &App,
) -> CodeLabel {
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
let mut label = CodeLabel::default();
label.push_str(&file_name, None);
label.push_str(" ", None);
if let Some(directory) = directory {
label.push_str(&directory, comment_id);
fn completion_for_mode(source_range: Range<Anchor>, mode: ContextPickerMode) -> Completion {
Completion {
replace_range: source_range.clone(),
new_text: format!("@{} ", mode.mention_prefix()),
label: CodeLabel::plain(mode.label().to_string(), None),
icon_path: Some(mode.icon().path().into()),
documentation: None,
source: project::CompletionSource::Custom,
insert_text_mode: None,
// This ensures that when a user accepts this completion, the
// completion menu will still be shown after "@category " is
// inserted
confirm: Some(Arc::new(|_, _, _| true)),
}
label.filter_range = 0..label.text().len();
label
}
fn completion_for_thread(
@@ -261,11 +301,8 @@ impl ContextPickerCompletionProvider {
path_prefix,
);
let label = Self::build_code_label_for_full_path(
&file_name,
directory.as_ref().map(|s| s.as_ref()),
cx,
);
let label =
build_code_label_for_full_path(&file_name, directory.as_ref().map(|s| s.as_ref()), cx);
let full_path = if let Some(directory) = directory {
format!("{}{}", directory, file_name)
} else {
@@ -382,6 +419,22 @@ impl ContextPickerCompletionProvider {
}
}
fn build_code_label_for_full_path(file_name: &str, directory: Option<&str>, cx: &App) -> CodeLabel {
let comment_id = cx.theme().syntax().highlight_id("comment").map(HighlightId);
let mut label = CodeLabel::default();
label.push_str(&file_name, None);
label.push_str(" ", None);
if let Some(directory) = directory {
label.push_str(&directory, comment_id);
}
label.filter_range = 0..label.text().len();
label
}
impl CompletionProvider for ContextPickerCompletionProvider {
fn completions(
&self,
@@ -404,10 +457,9 @@ impl CompletionProvider for ContextPickerCompletionProvider {
return Task::ready(Ok(None));
};
let Some(workspace) = self.workspace.upgrade() else {
return Task::ready(Ok(None));
};
let Some(context_store) = self.context_store.upgrade() else {
let Some((workspace, context_store)) =
self.workspace.upgrade().zip(self.context_store.upgrade())
else {
return Task::ready(Ok(None));
};
@@ -419,154 +471,89 @@ impl CompletionProvider for ContextPickerCompletionProvider {
let editor = self.editor.clone();
let http_client = workspace.read(cx).client().http_client().clone();
let MentionCompletion { mode, argument, .. } = state;
let query = argument.unwrap_or_else(|| "".to_string());
let recent_entries = recent_context_picker_entries(
context_store.clone(),
thread_store.clone(),
workspace.clone(),
cx,
);
let search_task = search(
mode,
query,
Arc::<AtomicBool>::default(),
recent_entries,
thread_store.clone(),
workspace.clone(),
cx,
);
cx.spawn(async move |_, cx| {
let mut completions = Vec::new();
let matches = search_task.await;
let Some(editor) = editor.upgrade() else {
return Ok(None);
};
let MentionCompletion { mode, argument, .. } = state;
let query = argument.unwrap_or_else(|| "".to_string());
match mode {
Some(ContextPickerMode::File) => {
let path_matches = cx
.update(|cx| {
super::file_context_picker::search_paths(
query,
Arc::<AtomicBool>::default(),
&workspace,
cx,
)
})?
.await;
if let Some(editor) = editor.upgrade() {
completions.reserve(path_matches.len());
cx.update(|cx| {
completions.extend(path_matches.iter().map(|mat| {
Self::completion_for_path(
ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
},
&mat.path_prefix,
false,
mat.is_dir,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
cx,
)
}));
})?;
}
}
Some(ContextPickerMode::Symbol) => {
if let Some(editor) = editor.upgrade() {
let symbol_matches = cx
.update(|cx| {
super::symbol_context_picker::search_symbols(
query,
Arc::new(AtomicBool::default()),
&workspace,
cx,
)
})?
.await?;
cx.update(|cx| {
completions.extend(symbol_matches.into_iter().filter_map(
|(_, symbol)| {
Self::completion_for_symbol(
symbol,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
workspace.clone(),
cx,
)
Ok(Some(cx.update(|cx| {
matches
.into_iter()
.filter_map(|mat| match mat {
Match::File(FileMatch { mat, is_recent }) => {
Some(Self::completion_for_path(
ProjectPath {
worktree_id: WorktreeId::from_usize(mat.worktree_id),
path: mat.path.clone(),
},
));
})?;
}
}
Some(ContextPickerMode::Fetch) => {
if let Some(editor) = editor.upgrade() {
if !query.is_empty() {
completions.push(Self::completion_for_fetch(
source_range.clone(),
query.into(),
&mat.path_prefix,
is_recent,
mat.is_dir,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
http_client.clone(),
));
}
context_store.update(cx, |store, _| {
let urls = store.context().iter().filter_map(|context| {
if let AssistantContext::FetchedUrl(context) = context {
Some(context.url.clone())
} else {
None
}
});
for url in urls {
completions.push(Self::completion_for_fetch(
source_range.clone(),
url,
excerpt_id,
editor.clone(),
context_store.clone(),
http_client.clone(),
));
}
})?;
}
}
Some(ContextPickerMode::Thread) => {
if let Some((thread_store, editor)) = thread_store
.and_then(|thread_store| thread_store.upgrade())
.zip(editor.upgrade())
{
let threads = cx
.update(|cx| {
super::thread_context_picker::search_threads(
query,
thread_store.clone(),
cx,
)
})?
.await;
for thread in threads {
completions.push(Self::completion_for_thread(
thread.clone(),
excerpt_id,
source_range.clone(),
false,
editor.clone(),
context_store.clone(),
thread_store.clone(),
));
}
}
}
None => {
cx.update(|cx| {
if let Some(editor) = editor.upgrade() {
completions.extend(Self::default_completions(
excerpt_id,
source_range.clone(),
context_store.clone(),
thread_store.clone(),
editor,
workspace.clone(),
cx,
));
))
}
})?;
}
}
Ok(Some(completions))
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
symbol,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
workspace.clone(),
cx,
),
Match::Thread(ThreadMatch {
thread, is_recent, ..
}) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_thread(
thread,
excerpt_id,
source_range.clone(),
is_recent,
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,
excerpt_id,
editor.clone(),
context_store.clone(),
http_client.clone(),
)),
Match::Mode(mode) => {
Some(Self::completion_for_mode(source_range.clone(), mode))
}
})
.collect()
})?))
})
}
@@ -676,7 +663,12 @@ impl MentionCompletion {
let mut end = last_mention_start + 1;
if let Some(mode_text) = parts.next() {
end += mode_text.len();
mode = ContextPickerMode::try_from(mode_text).ok();
if let Some(parsed_mode) = ContextPickerMode::try_from(mode_text).ok() {
mode = Some(parsed_mode);
} else {
argument = Some(mode_text.to_string());
}
match rest_of_line[mode_text.len()..].find(|c: char| !c.is_whitespace()) {
Some(whitespace_count) => {
if let Some(argument_text) = parts.next() {
@@ -702,13 +694,13 @@ impl MentionCompletion {
#[cfg(test)]
mod tests {
use super::*;
use gpui::{Focusable, TestAppContext, VisualTestContext};
use gpui::{EventEmitter, FocusHandle, Focusable, TestAppContext, VisualTestContext};
use project::{Project, ProjectPath};
use serde_json::json;
use settings::SettingsStore;
use std::{ops::Deref, path::PathBuf};
use std::ops::Deref;
use util::{path, separator};
use workspace::AppState;
use workspace::{AppState, Item};
#[test]
fn test_mention_completion_parse() {
@@ -768,9 +760,42 @@ mod tests {
})
);
assert_eq!(
MentionCompletion::try_parse("Lorem @main", 0),
Some(MentionCompletion {
source_range: 6..11,
mode: None,
argument: Some("main".to_string()),
})
);
assert_eq!(MentionCompletion::try_parse("test@", 0), None);
}
struct AtMentionEditor(Entity<Editor>);
impl Item for AtMentionEditor {
type Event = ();
fn include_in_nav_history() -> bool {
false
}
}
impl EventEmitter<()> for AtMentionEditor {}
impl Focusable for AtMentionEditor {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.0.read(cx).focus_handle(cx).clone()
}
}
impl Render for AtMentionEditor {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.0.clone().into_any_element()
}
}
#[gpui::test]
async fn test_context_completion_provider(cx: &mut TestAppContext) {
init_test(cx);
@@ -846,25 +871,27 @@ mod tests {
.unwrap();
}
let item = workspace
.update_in(&mut cx, |workspace, window, cx| {
workspace.open_path(
ProjectPath {
worktree_id,
path: PathBuf::from("editor").into(),
},
let editor = workspace.update_in(&mut cx, |workspace, window, cx| {
let editor = cx.new(|cx| {
Editor::new(
editor::EditorMode::Full,
multi_buffer::MultiBuffer::build_simple("", cx),
None,
true,
window,
cx,
)
})
.await
.expect("Could not open test file");
let editor = cx.update(|_, cx| {
item.act_as::<Editor>(cx)
.expect("Opened test file wasn't an editor")
});
workspace.active_pane().update(cx, |pane, cx| {
pane.add_item(
Box::new(cx.new(|_| AtMentionEditor(editor.clone()))),
true,
true,
None,
window,
cx,
);
});
editor
});
let context_store = cx.new(|_| ContextStore::new(project.downgrade(), None));
@@ -895,10 +922,10 @@ mod tests {
assert_eq!(
current_completion_labels(editor),
&[
"editor dir/",
"seven.txt dir/b/",
"six.txt dir/b/",
"five.txt dir/b/",
"four.txt dir/a/",
"Files & Directories",
"Symbols",
"Fetch"
@@ -993,14 +1020,14 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@seven.txt](@file:dir/b/seven.txt)"
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71)
Point::new(0, 44)..Point::new(0, 79)
]
);
});
@@ -1010,14 +1037,14 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)\n@"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@seven.txt](@file:dir/b/seven.txt)\n@"
);
assert!(editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71)
Point::new(0, 44)..Point::new(0, 79)
]
);
});
@@ -1031,15 +1058,15 @@ mod tests {
editor.update(&mut cx, |editor, cx| {
assert_eq!(
editor.text(cx),
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@editor](@file:dir/editor)\n[@seven.txt](@file:dir/b/seven.txt)"
"Lorem [@one.txt](@file:dir/a/one.txt) Ipsum [@seven.txt](@file:dir/b/seven.txt)\n[@six.txt](@file:dir/b/six.txt)"
);
assert!(!editor.has_visible_completions_menu());
assert_eq!(
crease_ranges(editor, cx),
vec![
Point::new(0, 6)..Point::new(0, 37),
Point::new(0, 44)..Point::new(0, 71),
Point::new(1, 0)..Point::new(1, 35)
Point::new(0, 44)..Point::new(0, 79),
Point::new(1, 0)..Point::new(1, 31)
]
);
});

View File

@@ -58,7 +58,7 @@ pub struct FileContextPickerDelegate {
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
confirm_behavior: ConfirmBehavior,
matches: Vec<PathMatch>,
matches: Vec<FileMatch>,
selected_index: usize,
}
@@ -114,7 +114,7 @@ impl PickerDelegate for FileContextPickerDelegate {
return Task::ready(());
};
let search_task = search_paths(query, Arc::<AtomicBool>::default(), &workspace, cx);
let search_task = search_files(query, Arc::<AtomicBool>::default(), &workspace, cx);
cx.spawn_in(window, async move |this, cx| {
// TODO: This should be probably be run in the background.
@@ -128,7 +128,7 @@ impl PickerDelegate for FileContextPickerDelegate {
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(mat) = self.matches.get(self.selected_index) else {
let Some(FileMatch { mat, .. }) = self.matches.get(self.selected_index) else {
return;
};
@@ -181,7 +181,7 @@ impl PickerDelegate for FileContextPickerDelegate {
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let path_match = &self.matches[ix];
let FileMatch { mat, .. } = &self.matches[ix];
Some(
ListItem::new(ix)
@@ -189,9 +189,10 @@ impl PickerDelegate for FileContextPickerDelegate {
.toggle_state(selected)
.child(render_file_context_entry(
ElementId::NamedInteger("file-ctx-picker".into(), ix),
&path_match.path,
&path_match.path_prefix,
path_match.is_dir,
WorktreeId::from_usize(mat.worktree_id),
&mat.path,
&mat.path_prefix,
mat.is_dir,
self.context_store.clone(),
cx,
)),
@@ -199,12 +200,17 @@ impl PickerDelegate for FileContextPickerDelegate {
}
}
pub(crate) fn search_paths(
pub struct FileMatch {
pub mat: PathMatch,
pub is_recent: bool,
}
pub(crate) fn search_files(
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &App,
) -> Task<Vec<PathMatch>> {
) -> Task<Vec<FileMatch>> {
if query.is_empty() {
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
@@ -213,28 +219,34 @@ pub(crate) fn search_paths(
.into_iter()
.filter_map(|(project_path, _)| {
let worktree = project.worktree_for_id(project_path.worktree_id, cx)?;
Some(PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix: worktree.read(cx).root_name().into(),
distance_to_relative_ancestor: 0,
is_dir: false,
Some(FileMatch {
mat: PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: project_path.worktree_id.to_usize(),
path: project_path.path,
path_prefix: worktree.read(cx).root_name().into(),
distance_to_relative_ancestor: 0,
is_dir: false,
},
is_recent: true,
})
});
let file_matches = project.worktrees(cx).flat_map(|worktree| {
let worktree = worktree.read(cx);
let path_prefix: Arc<str> = worktree.root_name().into();
worktree.entries(false, 0).map(move |entry| PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: entry.is_dir(),
worktree.entries(false, 0).map(move |entry| FileMatch {
mat: PathMatch {
score: 0.,
positions: Vec::new(),
worktree_id: worktree.id().to_usize(),
path: entry.path.clone(),
path_prefix: path_prefix.clone(),
distance_to_relative_ancestor: 0,
is_dir: entry.is_dir(),
},
is_recent: false,
})
});
@@ -269,6 +281,12 @@ pub(crate) fn search_paths(
executor,
)
.await
.into_iter()
.map(|mat| FileMatch {
mat,
is_recent: false,
})
.collect::<Vec<_>>()
})
}
}
@@ -311,19 +329,26 @@ pub fn extract_file_name_and_directory(
pub fn render_file_context_entry(
id: ElementId,
path: &Path,
worktree_id: WorktreeId,
path: &Arc<Path>,
path_prefix: &Arc<str>,
is_directory: bool,
context_store: WeakEntity<ContextStore>,
cx: &App,
) -> Stateful<Div> {
let (file_name, directory) = extract_file_name_and_directory(path, path_prefix);
let (file_name, directory) = extract_file_name_and_directory(&path, path_prefix);
let added = context_store.upgrade().and_then(|context_store| {
let project_path = ProjectPath {
worktree_id,
path: path.clone(),
};
if is_directory {
context_store.read(cx).includes_directory(path)
context_store.read(cx).includes_directory(&project_path)
} else {
context_store.read(cx).will_include_file_path(path, cx)
context_store
.read(cx)
.will_include_file_path(&project_path, cx)
}
});
@@ -363,8 +388,9 @@ pub fn render_file_context_entry(
)
.child(Label::new("Added").size(LabelSize::Small)),
),
FileInclusion::InDirectory(dir_name) => {
let dir_name = dir_name.to_string_lossy().into_owned();
FileInclusion::InDirectory(directory_project_path) => {
// TODO: Consider using worktree full_path to include worktree name.
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
el.child(
h_flex()
@@ -378,7 +404,7 @@ pub fn render_file_context_entry(
)
.child(Label::new("Included").size(LabelSize::Small)),
)
.tooltip(Tooltip::text(format!("in {dir_name}")))
.tooltip(Tooltip::text(format!("in {directory_path}")))
}
})
}

View File

@@ -2,7 +2,7 @@ use std::cmp::Reverse;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::{Context as _, Result};
use anyhow::Result;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
App, AppContext, DismissEvent, Entity, FocusHandle, Focusable, Stateful, Task, WeakEntity,
@@ -119,11 +119,7 @@ impl PickerDelegate for SymbolContextPickerDelegate {
let search_task = search_symbols(query, Arc::<AtomicBool>::default(), &workspace, cx);
let context_store = self.context_store.clone();
cx.spawn_in(window, async move |this, cx| {
let symbols = search_task
.await
.context("Failed to load symbols")
.log_err()
.unwrap_or_default();
let symbols = search_task.await;
let symbol_entries = context_store
.read_with(cx, |context_store, cx| {
@@ -285,12 +281,16 @@ fn find_matching_symbol(symbol: &Symbol, candidates: &[DocumentSymbol]) -> Optio
}
}
pub struct SymbolMatch {
pub symbol: Symbol,
}
pub(crate) fn search_symbols(
query: String,
cancellation_flag: Arc<AtomicBool>,
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Task<Result<Vec<(StringMatch, Symbol)>>> {
) -> Task<Vec<SymbolMatch>> {
let symbols_task = workspace.update(cx, |workspace, cx| {
workspace
.project()
@@ -298,19 +298,28 @@ pub(crate) fn search_symbols(
});
let project = workspace.read(cx).project().clone();
cx.spawn(async move |cx| {
let symbols = symbols_task.await?;
let (visible_match_candidates, external_match_candidates): (Vec<_>, Vec<_>) = project
.update(cx, |project, cx| {
symbols
.iter()
.enumerate()
.map(|(id, symbol)| StringMatchCandidate::new(id, &symbol.label.filter_text()))
.partition(|candidate| {
project
.entry_for_path(&symbols[candidate.id].path, cx)
.map_or(false, |e| !e.is_ignored)
})
})?;
let Some(symbols) = symbols_task.await.log_err() else {
return Vec::new();
};
let Some((visible_match_candidates, external_match_candidates)): Option<(Vec<_>, Vec<_>)> =
project
.update(cx, |project, cx| {
symbols
.iter()
.enumerate()
.map(|(id, symbol)| {
StringMatchCandidate::new(id, &symbol.label.filter_text())
})
.partition(|candidate| {
project
.entry_for_path(&symbols[candidate.id].path, cx)
.map_or(false, |e| !e.is_ignored)
})
})
.log_err()
else {
return Vec::new();
};
const MAX_MATCHES: usize = 100;
let mut visible_matches = cx.background_executor().block(fuzzy::match_strings(
@@ -339,7 +348,7 @@ pub(crate) fn search_symbols(
let mut matches = visible_matches;
matches.append(&mut external_matches);
Ok(matches
matches
.into_iter()
.map(|mut mat| {
let symbol = symbols[mat.candidate_id].clone();
@@ -347,19 +356,19 @@ pub(crate) fn search_symbols(
for position in &mut mat.positions {
*position += filter_start;
}
(mat, symbol)
SymbolMatch { symbol }
})
.collect())
.collect()
})
}
fn compute_symbol_entries(
symbols: Vec<(StringMatch, Symbol)>,
symbols: Vec<SymbolMatch>,
context_store: &ContextStore,
cx: &App,
) -> Vec<SymbolEntry> {
let mut symbol_entries = Vec::with_capacity(symbols.len());
for (_, symbol) in symbols {
for SymbolMatch { symbol, .. } in symbols {
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
let is_included = if let Some(symbols_for_path) = symbols_for_path {
let mut is_included = false;

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use fuzzy::StringMatchCandidate;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
@@ -114,11 +115,11 @@ impl PickerDelegate for ThreadContextPickerDelegate {
return Task::ready(());
};
let search_task = search_threads(query, threads, cx);
let search_task = search_threads(query, Arc::new(AtomicBool::default()), threads, cx);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
this.delegate.matches = matches;
this.delegate.matches = matches.into_iter().map(|mat| mat.thread).collect();
this.delegate.selected_index = 0;
cx.notify();
})
@@ -217,11 +218,18 @@ pub fn render_thread_context_entry(
})
}
#[derive(Clone)]
pub struct ThreadMatch {
pub thread: ThreadContextEntry,
pub is_recent: bool,
}
pub(crate) fn search_threads(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
cx: &mut App,
) -> Task<Vec<ThreadContextEntry>> {
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store.update(cx, |this, _cx| {
this.threads()
.into_iter()
@@ -236,6 +244,12 @@ pub(crate) fn search_threads(
cx.background_spawn(async move {
if query.is_empty() {
threads
.into_iter()
.map(|thread| ThreadMatch {
thread,
is_recent: false,
})
.collect()
} else {
let candidates = threads
.iter()
@@ -247,14 +261,17 @@ pub(crate) fn search_threads(
&query,
false,
100,
&Default::default(),
&cancellation_flag,
executor,
)
.await;
matches
.into_iter()
.map(|mat| threads[mat.candidate_id].clone())
.map(|mat| ThreadMatch {
thread: threads[mat.candidate_id].clone(),
is_recent: false,
})
.collect()
}
})

View File

@@ -1,5 +1,5 @@
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::sync::Arc;
use anyhow::{Context as _, Result, anyhow};
@@ -28,7 +28,7 @@ pub struct ContextStore {
// TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId.
next_context_id: ContextId,
files: BTreeMap<BufferId, ContextId>,
directories: HashMap<PathBuf, ContextId>,
directories: HashMap<ProjectPath, ContextId>,
symbols: HashMap<ContextSymbolId, ContextId>,
symbol_buffers: HashMap<ContextSymbolId, Entity<Buffer>>,
symbols_by_path: HashMap<ProjectPath, Vec<ContextSymbolId>>,
@@ -93,7 +93,7 @@ impl ContextStore {
let buffer_id = this.update(cx, |_, cx| buffer.read(cx).remote_id())?;
let already_included = this.update(cx, |this, cx| {
match this.will_include_buffer(buffer_id, &project_path.path) {
match this.will_include_buffer(buffer_id, &project_path) {
Some(FileInclusion::Direct(context_id)) => {
if remove_if_exists {
this.remove_context(context_id, cx);
@@ -159,7 +159,7 @@ impl ContextStore {
return Task::ready(Err(anyhow!("failed to read project")));
};
let already_included = match self.includes_directory(&project_path.path) {
let already_included = match self.includes_directory(&project_path) {
Some(FileInclusion::Direct(context_id)) => {
if remove_if_exists {
self.remove_context(context_id, cx);
@@ -223,14 +223,12 @@ impl ContextStore {
.collect::<Vec<_>>();
if context_buffers.is_empty() {
return Err(anyhow!(
"No text files found in {}",
&project_path.path.display()
));
let full_path = cx.update(|cx| worktree.read(cx).full_path(&project_path.path))?;
return Err(anyhow!("No text files found in {}", &full_path.display()));
}
this.update(cx, |this, cx| {
this.insert_directory(project_path, context_buffers, cx);
this.insert_directory(worktree, project_path, context_buffers, cx);
})?;
anyhow::Ok(())
@@ -239,17 +237,20 @@ impl ContextStore {
fn insert_directory(
&mut self,
worktree: Entity<Worktree>,
project_path: ProjectPath,
context_buffers: Vec<ContextBuffer>,
cx: &mut Context<Self>,
) {
let id = self.next_context_id.post_inc();
self.directories.insert(project_path.path.to_path_buf(), id);
let path = project_path.path.clone();
self.directories.insert(project_path, id);
self.context
.push(AssistantContext::Directory(DirectoryContext {
id,
project_path,
worktree,
path,
context_buffers,
}));
cx.notify();
@@ -478,23 +479,31 @@ impl ContextStore {
/// Returns whether the buffer is already included directly in the context, or if it will be
/// included in the context via a directory. Directory inclusion is based on paths rather than
/// buffer IDs as the directory will be re-scanned.
pub fn will_include_buffer(&self, buffer_id: BufferId, path: &Path) -> Option<FileInclusion> {
pub fn will_include_buffer(
&self,
buffer_id: BufferId,
project_path: &ProjectPath,
) -> Option<FileInclusion> {
if let Some(context_id) = self.files.get(&buffer_id) {
return Some(FileInclusion::Direct(*context_id));
}
self.will_include_file_path_via_directory(path)
self.will_include_file_path_via_directory(project_path)
}
/// Returns whether this file path is already included directly in the context, or if it will be
/// included in the context via a directory.
pub fn will_include_file_path(&self, path: &Path, cx: &App) -> Option<FileInclusion> {
pub fn will_include_file_path(
&self,
project_path: &ProjectPath,
cx: &App,
) -> Option<FileInclusion> {
if !self.files.is_empty() {
let found_file_context = self.context.iter().find(|context| match &context {
AssistantContext::File(file_context) => {
let buffer = file_context.context_buffer.buffer.read(cx);
if let Some(file_path) = buffer_path_log_err(buffer, cx) {
*file_path == *path
if let Some(context_path) = buffer.project_path(cx) {
&context_path == project_path
} else {
false
}
@@ -506,31 +515,40 @@ impl ContextStore {
}
}
self.will_include_file_path_via_directory(path)
self.will_include_file_path_via_directory(project_path)
}
fn will_include_file_path_via_directory(&self, path: &Path) -> Option<FileInclusion> {
fn will_include_file_path_via_directory(
&self,
project_path: &ProjectPath,
) -> Option<FileInclusion> {
if self.directories.is_empty() {
return None;
}
let mut buf = path.to_path_buf();
let mut path_buf = project_path.path.to_path_buf();
while buf.pop() {
if let Some(_) = self.directories.get(&buf) {
return Some(FileInclusion::InDirectory(buf));
while path_buf.pop() {
// TODO: This isn't very efficient. Consider using a better representation of the
// directories map.
let directory_project_path = ProjectPath {
worktree_id: project_path.worktree_id,
path: path_buf.clone().into(),
};
if let Some(_) = self.directories.get(&directory_project_path) {
return Some(FileInclusion::InDirectory(directory_project_path));
}
}
None
}
pub fn includes_directory(&self, path: &Path) -> Option<FileInclusion> {
if let Some(context_id) = self.directories.get(path) {
pub fn includes_directory(&self, project_path: &ProjectPath) -> Option<FileInclusion> {
if let Some(context_id) = self.directories.get(project_path) {
return Some(FileInclusion::Direct(*context_id));
}
self.will_include_file_path_via_directory(path)
self.will_include_file_path_via_directory(project_path)
}
pub fn included_symbol(&self, symbol_id: &ContextSymbolId) -> Option<ContextId> {
@@ -564,13 +582,13 @@ impl ContextStore {
}
}
pub fn file_paths(&self, cx: &App) -> HashSet<PathBuf> {
pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
self.context
.iter()
.filter_map(|context| match context {
AssistantContext::File(file) => {
let buffer = file.context_buffer.buffer.read(cx);
buffer_path_log_err(buffer, cx).map(|p| p.to_path_buf())
buffer.project_path(cx)
}
AssistantContext::Directory(_)
| AssistantContext::Symbol(_)
@@ -587,7 +605,7 @@ impl ContextStore {
pub enum FileInclusion {
Direct(ContextId),
InDirectory(PathBuf),
InDirectory(ProjectPath),
}
// ContextBuffer without text.
@@ -654,19 +672,6 @@ fn collect_buffer_info_and_text(
Ok((buffer_info, text_task))
}
pub fn buffer_path_log_err(buffer: &Buffer, cx: &App) -> Option<Arc<Path>> {
if let Some(file) = buffer.file() {
let mut path = file.path().clone();
if path.as_os_str().is_empty() {
path = file.full_path(cx).into();
}
Some(path)
} else {
log::error!("Buffer that had a path unexpectedly no longer has a path.");
None
}
}
fn to_fenced_codeblock(path: &Path, content: Rope) -> SharedString {
let path_extension = path.extension().and_then(|ext| ext.to_str());
let path_string = path.to_string_lossy();
@@ -742,13 +747,13 @@ pub fn refresh_context_store_text(
}
}
AssistantContext::Directory(directory_context) => {
let directory_path = directory_context.project_path(cx);
let should_refresh = changed_buffers.is_empty()
|| changed_buffers.iter().any(|buffer| {
let buffer = buffer.read(cx);
buffer_path_log_err(&buffer, cx).map_or(false, |path| {
path.starts_with(&directory_context.project_path.path)
})
let Some(buffer_path) = buffer.read(cx).project_path(cx) else {
return false;
};
buffer_path.starts_with(&directory_path)
});
if should_refresh {
@@ -835,14 +840,16 @@ fn refresh_directory_text(
let context_buffers = future::join_all(futures);
let id = directory_context.id;
let project_path = directory_context.project_path.clone();
let worktree = directory_context.worktree.clone();
let path = directory_context.path.clone();
Some(cx.spawn(async move |cx| {
let context_buffers = context_buffers.await;
context_store
.update(cx, |context_store, _| {
let new_directory_context = DirectoryContext {
id,
project_path,
worktree,
path,
context_buffers,
};
context_store.replace_context(AssistantContext::Directory(new_directory_context));

View File

@@ -1,3 +1,4 @@
use std::path::Path;
use std::rc::Rc;
use collections::HashSet;
@@ -9,6 +10,7 @@ use gpui::{
};
use itertools::Itertools;
use language::Buffer;
use project::ProjectItem;
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
use workspace::{Workspace, notifications::NotifyResultExt};
@@ -93,26 +95,23 @@ impl ContextStrip {
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
let active_buffer = active_buffer_entity.read(cx);
let path = active_buffer.file()?.full_path(cx);
let project_path = active_buffer.project_path(cx)?;
if self
.context_store
.read(cx)
.will_include_buffer(active_buffer.remote_id(), &path)
.will_include_buffer(active_buffer.remote_id(), &project_path)
.is_some()
{
return None;
}
let name = match path.file_name() {
Some(name) => name.to_string_lossy().into_owned().into(),
None => path.to_string_lossy().into_owned().into(),
};
let file_name = active_buffer.file()?.file_name(cx);
let icon_path = FileIcons::get_icon(&path, cx);
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
Some(SuggestedContext::File {
name,
name: file_name.to_string_lossy().into_owned().into(),
buffer: active_buffer_entity.downgrade(),
icon_path,
})

View File

@@ -28,6 +28,7 @@ use language_model::{LanguageModelRegistry, report_assistant_event};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::LspAction;
use project::Project;
use project::{CodeAction, ProjectTransaction};
use prompt_store::PromptBuilder;
use settings::{Settings, SettingsStore};
@@ -254,6 +255,7 @@ impl InlineAssistant {
assistant.assist(
&active_editor,
cx.entity().downgrade(),
workspace.project().downgrade(),
thread_store,
window,
cx,
@@ -262,7 +264,14 @@ impl InlineAssistant {
}
InlineAssistTarget::Terminal(active_terminal) => {
TerminalInlineAssistant::update_global(cx, |assistant, cx| {
assistant.assist(&active_terminal, cx.entity(), thread_store, window, cx)
assistant.assist(
&active_terminal,
cx.entity().downgrade(),
workspace.project().downgrade(),
thread_store,
window,
cx,
)
})
}
};
@@ -312,17 +321,11 @@ impl InlineAssistant {
&mut self,
editor: &Entity<Editor>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
) {
let Some(project) = workspace
.upgrade()
.map(|workspace| workspace.read(cx).project().downgrade())
else {
return;
};
let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
(
editor.snapshot(window, cx),

View File

@@ -3,14 +3,16 @@ use std::sync::Arc;
use crate::assistant_model_selector::ModelType;
use collections::HashSet;
use editor::actions::MoveUp;
use editor::{ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle};
use editor::{
ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorStyle, MultiBuffer,
};
use file_icons::FileIcons;
use fs::Fs;
use gpui::{
Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle,
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::Buffer;
use language::{Buffer, Language};
use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
@@ -30,8 +32,8 @@ use crate::profile_selector::ProfileSelector;
use crate::thread::{RequestKind, Thread, TokenUsageRatio};
use crate::thread_store::ThreadStore;
use crate::{
AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ThreadEvent,
ToggleContextPicker, ToggleProfileSelector,
AgentDiff, Chat, ChatMode, NewThread, OpenAgentDiff, RemoveAllContext, ToggleContextPicker,
ToggleProfileSelector,
};
pub struct MessageEditor {
@@ -66,8 +68,24 @@ impl MessageEditor {
let inline_context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let editor = cx.new(|cx| {
let mut editor = Editor::auto_height(10, window, cx);
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight { max_lines: 10 },
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_context_menu_options(ContextMenuOptions {
@@ -75,7 +93,6 @@ impl MessageEditor {
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
@@ -184,7 +201,7 @@ impl MessageEditor {
}
fn is_editor_empty(&self, cx: &App) -> bool {
self.editor.read(cx).text(cx).is_empty()
self.editor.read(cx).text(cx).trim().is_empty()
}
fn is_model_selected(&self, cx: &App) -> bool {
@@ -218,8 +235,6 @@ impl MessageEditor {
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let system_prompt_context_task = self.thread.read(cx).load_system_prompt_context(cx);
let thread = self.thread.clone();
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
@@ -228,16 +243,6 @@ impl MessageEditor {
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
refresh_task.await;
let (system_prompt_context, load_error) = system_prompt_context_task.await;
thread
.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})
.log_err();
thread
.update(cx, |thread, cx| {

View File

@@ -16,6 +16,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use project::Project;
use prompt_store::PromptBuilder;
use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
@@ -66,7 +67,8 @@ impl TerminalInlineAssistant {
pub fn assist(
&mut self,
terminal_view: &Entity<TerminalView>,
workspace: Entity<Workspace>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
thread_store: Option<WeakEntity<ThreadStore>>,
window: &mut Window,
cx: &mut App,
@@ -75,7 +77,6 @@ impl TerminalInlineAssistant {
let assist_id = self.next_assist_id.post_inc();
let prompt_buffer =
cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
let project = workspace.read(cx).project().downgrade();
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
@@ -87,7 +88,7 @@ impl TerminalInlineAssistant {
codegen,
self.fs.clone(),
context_store.clone(),
workspace.downgrade(),
workspace.clone(),
thread_store.clone(),
window,
cx,
@@ -106,7 +107,7 @@ impl TerminalInlineAssistant {
assist_id,
terminal_view,
prompt_editor,
workspace.downgrade(),
workspace.clone(),
context_store,
window,
cx,

View File

@@ -3,13 +3,11 @@ use std::io::Write;
use std::ops::Range;
use std::sync::Arc;
use agent_rules::load_worktree_rules_file;
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::AssistantSettings;
use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
@@ -20,19 +18,20 @@ use language_model::{
LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
PaymentRequiredError, Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
use project::{Project, Worktree};
use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
@@ -182,7 +181,7 @@ pub struct ThreadCheckpoint {
git_checkpoint: GitStoreCheckpoint,
}
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
Positive,
Negative,
@@ -246,7 +245,7 @@ pub struct Thread {
next_message_id: MessageId,
context: BTreeMap<ContextId, AssistantContext>,
context_by_message: HashMap<MessageId, Vec<ContextId>>,
system_prompt_context: Option<AssistantSystemPromptContext>,
project_context: SharedProjectContext,
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
@@ -260,6 +259,7 @@ pub struct Thread {
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
cumulative_token_usage: TokenUsage,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
}
impl Thread {
@@ -267,6 +267,7 @@ impl Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
system_prompt: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
Self {
@@ -279,7 +280,7 @@ impl Thread {
next_message_id: MessageId(0),
context: BTreeMap::default(),
context_by_message: HashMap::default(),
system_prompt_context: None,
project_context: system_prompt,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -298,6 +299,7 @@ impl Thread {
},
cumulative_token_usage: TokenUsage::default(),
feedback: None,
message_feedback: HashMap::default(),
}
}
@@ -307,6 +309,7 @@ impl Thread {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
project_context: SharedProjectContext,
cx: &mut Context<Self>,
) -> Self {
let next_message_id = MessageId(
@@ -347,7 +350,7 @@ impl Thread {
next_message_id,
context: BTreeMap::default(),
context_by_message: HashMap::default(),
system_prompt_context: None,
project_context,
checkpoints_by_message: HashMap::default(),
completion_count: 0,
pending_completions: Vec::new(),
@@ -361,6 +364,7 @@ impl Thread {
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
cumulative_token_usage: serialized.cumulative_token_usage,
feedback: None,
message_feedback: HashMap::default(),
}
}
@@ -384,6 +388,10 @@ impl Thread {
self.summary.clone()
}
pub fn project_context(&self) -> SharedProjectContext {
self.project_context.clone()
}
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@@ -805,86 +813,6 @@ impl Thread {
})
}
pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
self.system_prompt_context = Some(context);
}
pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
&self.system_prompt_context
}
pub fn load_system_prompt_context(
&self,
cx: &App,
) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async |_cx| {
let results = futures::future::join_all(tasks).await;
let mut first_err = None;
let worktrees = results
.into_iter()
.map(|(worktree, err)| {
if first_err.is_none() && err.is_some() {
first_err = err;
}
worktree
})
.collect::<Vec<_>>();
(AssistantSystemPromptContext::new(worktrees), first_err)
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
let rules_task = load_worktree_rules_file(fs, worktree, cx);
let Some(rules_task) = rules_task else {
return Task::ready((
WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file: None,
},
None,
));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(ThreadError::Message {
header: "Error loading rules file".into(),
message: format!("{err}").into(),
}),
),
};
let worktree_info = WorktreeInfoForSystemPrompt {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
@@ -934,10 +862,10 @@ impl Thread {
temperature: None,
};
if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
if let Some(project_context) = self.project_context.borrow().as_ref() {
if let Some(system_prompt) = self
.prompt_builder
.generate_assistant_system_prompt(system_prompt_context)
.generate_assistant_system_prompt(project_context)
.context("failed to generate assistant system prompt")
.log_err()
{
@@ -948,7 +876,7 @@ impl Thread {
});
}
} else {
log::error!("system_prompt_context not set.")
log::error!("project_context not set.")
}
for message in &self.messages {
@@ -1178,7 +1106,8 @@ impl Thread {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
cx.emit(ThreadEvent::UsePendingTools);
let tool_uses = thread.use_pending_tools(cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
StopReason::MaxTokens => {}
@@ -1205,7 +1134,7 @@ impl Thread {
thread.cancel_last_completion(cx);
}
}
cx.emit(ThreadEvent::DoneStreaming);
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
if let Ok(initial_usage) = initial_token_usage {
let usage = thread.cumulative_token_usage.clone() - initial_usage;
@@ -1366,10 +1295,7 @@ impl Thread {
)
}
pub fn use_pending_tools(
&mut self,
cx: &mut Context<Self>,
) -> impl IntoIterator<Item = PendingToolUse> + use<> {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
let request = self.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
let pending_tool_uses = self
@@ -1457,18 +1383,36 @@ impl Thread {
output,
cx,
);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
canceled: false,
});
thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
})
.ok();
}
})
}
fn tool_finished(
&mut self,
tool_use_id: LanguageModelToolUseId,
pending_tool_use: Option<PendingToolUse>,
canceled: bool,
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
self.attach_tool_results(cx);
if !canceled {
self.send_to_model(model, RequestKind::Chat, cx);
}
}
}
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
});
}
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
// Insert a user message to contain the tool results.
self.insert_user_message(
@@ -1492,11 +1436,12 @@ impl Thread {
let mut canceled = false;
for pending_tool_use in self.tool_use.cancel_pending() {
canceled = true;
cx.emit(ThreadEvent::ToolFinished {
tool_use_id: pending_tool_use.id.clone(),
pending_tool_use: Some(pending_tool_use),
canceled: true,
});
self.tool_finished(
pending_tool_use.id.clone(),
Some(pending_tool_use),
true,
cx,
);
}
canceled
};
@@ -1504,24 +1449,45 @@ impl Thread {
canceled
}
/// Returns the feedback given to the thread, if any.
pub fn feedback(&self) -> Option<ThreadFeedback> {
self.feedback
}
/// Reports feedback about the thread and stores it in our telemetry backend.
pub fn report_feedback(
pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
self.message_feedback.get(&message_id).copied()
}
pub fn report_message_feedback(
&mut self,
message_id: MessageId,
feedback: ThreadFeedback,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
if self.message_feedback.get(&message_id) == Some(&feedback) {
return Task::ready(Ok(()));
}
let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
let serialized_thread = self.serialize(cx);
let thread_id = self.id().clone();
let client = self.project.read(cx).client();
self.feedback = Some(feedback);
let enabled_tool_names: Vec<String> = self
.tools()
.enabled_tools(cx)
.iter()
.map(|tool| tool.name().to_string())
.collect();
self.message_feedback.insert(message_id, feedback);
cx.notify();
let message_content = self
.message(message_id)
.map(|msg| msg.to_string())
.unwrap_or_default();
cx.background_spawn(async move {
let final_project_snapshot = final_project_snapshot.await;
let serialized_thread = serialized_thread.await?;
@@ -1536,6 +1502,9 @@ impl Thread {
"Assistant Thread Rated",
rating,
thread_id,
enabled_tool_names,
message_id = message_id.0,
message_content,
thread_data,
final_project_snapshot
);
@@ -1545,6 +1514,52 @@ impl Thread {
})
}
pub fn report_feedback(
&mut self,
feedback: ThreadFeedback,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let last_assistant_message_id = self
.messages
.iter()
.rev()
.find(|msg| msg.role == Role::Assistant)
.map(|msg| msg.id);
if let Some(message_id) = last_assistant_message_id {
self.report_message_feedback(message_id, feedback, cx)
} else {
let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
let serialized_thread = self.serialize(cx);
let thread_id = self.id().clone();
let client = self.project.read(cx).client();
self.feedback = Some(feedback);
cx.notify();
cx.background_spawn(async move {
let final_project_snapshot = final_project_snapshot.await;
let serialized_thread = serialized_thread.await?;
let thread_data = serde_json::to_value(serialized_thread)
.unwrap_or_else(|_| serde_json::Value::Null);
let rating = match feedback {
ThreadFeedback::Positive => "positive",
ThreadFeedback::Negative => "negative",
};
telemetry::event!(
"Assistant Thread Rated",
rating,
thread_id,
thread_data,
final_project_snapshot
);
client.telemetry().flush_events();
Ok(())
})
}
}
/// Create a snapshot of the current project state including git information and unsaved buffers.
fn project_snapshot(
project: Entity<Project>,
@@ -1801,19 +1816,17 @@ impl Thread {
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use: None,
canceled: true,
});
self.tool_finished(tool_use_id.clone(), None, true, cx);
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
#[error("Payment required")]
PaymentRequired,
#[error("Max monthly spend reached")]
MaxMonthlySpendReached,
#[error("Message {header}: {message}")]
Message {
header: SharedString,
message: SharedString,
@@ -1826,20 +1839,20 @@ pub enum ThreadEvent {
StreamedCompletion,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
DoneStreaming,
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
SummaryGenerated,
SummaryChanged,
UsePendingTools,
UsePendingTools {
tool_uses: Vec<PendingToolUse>,
},
ToolFinished {
#[allow(unused)]
tool_use_id: LanguageModelToolUseId,
/// The pending tool use that corresponds to this tool.
pending_tool_use: Option<PendingToolUse>,
/// Whether the tool was canceled by the user.
canceled: bool,
},
CheckpointChanged,
ToolConfirmationNeeded,
@@ -1932,9 +1945,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 1);
assert_eq!(request.messages.len(), 2);
let expected_full_message = format!("{}Please explain this code", expected_context);
assert_eq!(request.messages[0].string_contents(), expected_full_message);
assert_eq!(request.messages[1].string_contents(), expected_full_message);
}
#[gpui::test]
@@ -2025,20 +2038,20 @@ fn main() {{
});
// The request should contain all 3 messages
assert_eq!(request.messages.len(), 3);
assert_eq!(request.messages.len(), 4);
// Check that the contexts are properly formatted in each message
assert!(request.messages[0].string_contents().contains("file1.rs"));
assert!(!request.messages[0].string_contents().contains("file2.rs"));
assert!(!request.messages[0].string_contents().contains("file3.rs"));
assert!(!request.messages[1].string_contents().contains("file1.rs"));
assert!(request.messages[1].string_contents().contains("file2.rs"));
assert!(request.messages[1].string_contents().contains("file1.rs"));
assert!(!request.messages[1].string_contents().contains("file2.rs"));
assert!(!request.messages[1].string_contents().contains("file3.rs"));
assert!(!request.messages[2].string_contents().contains("file1.rs"));
assert!(!request.messages[2].string_contents().contains("file2.rs"));
assert!(request.messages[2].string_contents().contains("file3.rs"));
assert!(request.messages[2].string_contents().contains("file2.rs"));
assert!(!request.messages[2].string_contents().contains("file3.rs"));
assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs"));
}
#[gpui::test]
@@ -2076,9 +2089,9 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 1);
assert_eq!(request.messages.len(), 2);
assert_eq!(
request.messages[0].string_contents(),
request.messages[1].string_contents(),
"What is the best way to learn Rust?"
);
@@ -2096,13 +2109,13 @@ fn main() {{
thread.to_completion_request(RequestKind::Chat, cx)
});
assert_eq!(request.messages.len(), 2);
assert_eq!(request.messages.len(), 3);
assert_eq!(
request.messages[0].string_contents(),
request.messages[1].string_contents(),
"What is the best way to learn Rust?"
);
assert_eq!(
request.messages[1].string_contents(),
request.messages[2].string_contents(),
"Are there any good books?"
);
}
@@ -2223,15 +2236,16 @@ fn main() {{
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store = cx.update(|_, cx| {
ThreadStore::new(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
.unwrap()
});
let thread_store = cx
.update(|_, cx| {
ThreadStore::load(
project.clone(),
Arc::default(),
Arc::new(PromptBuilder::new(None).unwrap()),
cx,
)
})
.await;
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

View File

@@ -1,37 +1,57 @@
use std::borrow::Cow;
use std::path::PathBuf;
use std::cell::{Ref, RefCell};
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::Arc;
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result, anyhow};
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use fs::Fs;
use futures::FutureExt as _;
use futures::future::{self, BoxFuture, Shared};
use gpui::{
App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
prelude::*,
App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString,
Subscription, Task, prelude::*,
};
use heed::Database;
use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::Project;
use prompt_store::PromptBuilder;
use project::{Project, Worktree};
use prompt_store::{ProjectContext, PromptBuilder, RulesFileContext, WorktreeContext};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
use crate::thread::{
DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
};
use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
];
pub fn init(cx: &mut App) {
ThreadsDatabase::init(cx);
}
/// A system prompt shared by all threads created by this ThreadStore
#[derive(Clone, Default)]
pub struct SharedProjectContext(Rc<RefCell<Option<ProjectContext>>>);
impl SharedProjectContext {
pub fn borrow(&self) -> Ref<Option<ProjectContext>> {
self.0.borrow()
}
}
pub struct ThreadStore {
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
@@ -39,43 +59,187 @@ pub struct ThreadStore {
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
project_context: SharedProjectContext,
_subscriptions: Vec<Subscription>,
}
pub struct RulesLoadingError {
pub message: SharedString,
}
impl EventEmitter<RulesLoadingError> for ThreadStore {}
impl ThreadStore {
pub fn new(
pub fn load(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut App,
) -> Result<Entity<Self>> {
let this = cx.new(|cx| {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
});
) -> Task<Entity<Self>> {
let thread_store = cx.new(|cx| Self::new(project, tools, prompt_builder, cx));
let reload = thread_store.update(cx, |store, cx| store.reload_system_prompt(cx));
cx.foreground_executor().spawn(async move {
reload.await;
thread_store
})
}
let this = Self {
project,
tools,
prompt_builder,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
_subscriptions: vec![settings_subscription],
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
this
fn new(
project: Entity<Project>,
tools: Arc<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
cx: &mut Context<Self>,
) -> Self {
let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});
let settings_subscription =
cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
this.load_default_profile(cx);
});
let project_subscription = cx.subscribe(&project, Self::handle_project_event);
Ok(this)
let this = Self {
project,
tools,
prompt_builder,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
project_context: SharedProjectContext::default(),
_subscriptions: vec![settings_subscription, project_subscription],
};
this.load_default_profile(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
this
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,
event: &project::Event,
cx: &mut Context<Self>,
) {
match event {
project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
self.reload_system_prompt(cx).detach();
}
project::Event::WorktreeUpdatedEntries(_, items) => {
if items.iter().any(|(path, _, _)| {
RULES_FILE_NAMES
.iter()
.any(|name| path.as_ref() == Path::new(name))
}) {
self.reload_system_prompt(cx).detach();
}
}
_ => {}
}
}
pub fn reload_system_prompt(&self, cx: &mut Context<Self>) -> Task<()> {
let project = self.project.read(cx);
let tasks = project
.visible_worktrees(cx)
.map(|worktree| {
Self::load_worktree_info_for_system_prompt(
project.fs().clone(),
worktree.read(cx),
cx,
)
})
.collect::<Vec<_>>();
cx.spawn(async move |this, cx| {
let results = futures::future::join_all(tasks).await;
let worktrees = results
.into_iter()
.map(|(worktree, rules_error)| {
if let Some(rules_error) = rules_error {
this.update(cx, |_, cx| cx.emit(rules_error)).ok();
}
worktree
})
.collect::<Vec<_>>();
this.update(cx, |this, _cx| {
*this.project_context.0.borrow_mut() = Some(ProjectContext::new(worktrees));
})
.ok();
})
}
fn load_worktree_info_for_system_prompt(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
let root_name = worktree.root_name().into();
let abs_path = worktree.abs_path();
let rules_task = Self::load_worktree_rules_file(fs, worktree, cx);
let Some(rules_task) = rules_task else {
return Task::ready((
WorktreeContext {
root_name,
abs_path,
rules_file: None,
},
None,
));
};
cx.spawn(async move |_| {
let (rules_file, rules_file_error) = match rules_task.await {
Ok(rules_file) => (Some(rules_file), None),
Err(err) => (
None,
Some(RulesLoadingError {
message: format!("{err}").into(),
}),
),
};
let worktree_info = WorktreeContext {
root_name,
abs_path,
rules_file,
};
(worktree_info, rules_file_error)
})
}
fn load_worktree_rules_file(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Option<Task<Result<RulesFileContext>>> {
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|(path_in_worktree, abs_path)| {
let fs = fs.clone();
cx.background_spawn(async move {
let abs_path = abs_path?;
let text = fs.load(&abs_path).await.with_context(|| {
format!("Failed to load assistant rules file {:?}", abs_path)
})?;
anyhow::Ok(RulesFileContext {
path_in_worktree,
abs_path: abs_path.into(),
text: text.trim().to_string(),
})
})
})
}
pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
@@ -107,6 +271,7 @@ impl ThreadStore {
self.project.clone(),
self.tools.clone(),
self.prompt_builder.clone(),
self.project_context.clone(),
cx,
)
})
@@ -134,21 +299,12 @@ impl ThreadStore {
this.project.clone(),
this.tools.clone(),
this.prompt_builder.clone(),
this.project_context.clone(),
cx,
)
})
})?;
let (system_prompt_context, load_error) = thread
.update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
.await;
thread.update(cx, |thread, cx| {
thread.set_system_prompt_context(system_prompt_context);
if let Some(load_error) = load_error {
cx.emit(ThreadEvent::ShowError(load_error));
}
})?;
Ok(thread)
})
}
@@ -491,7 +647,7 @@ impl ThreadsDatabase {
let database_future = executor
.spawn({
let executor = executor.clone();
let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
let database_path = paths::data_dir().join("threads/threads-db.1.mdb");
async move { ThreadsDatabase::new(database_path, executor) }
})
.then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))

View File

@@ -1,5 +1,7 @@
mod agent_notification;
mod context_pill;
mod user_spending;
pub use agent_notification::*;
pub use context_pill::*;
// pub use user_spending::*;

View File

@@ -280,9 +280,10 @@ impl AddedContext {
}
AssistantContext::Directory(directory_context) => {
// TODO: handle worktree disambiguation. Maybe by storing an `Arc<dyn File>` to also
// handle renames?
let full_path = &directory_context.project_path.path;
let full_path = directory_context
.worktree
.read(cx)
.full_path(&directory_context.path);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path

View File

@@ -0,0 +1,186 @@
use gpui::{Entity, Render};
use ui::{ProgressBar, prelude::*};
#[derive(RegisterComponent)]
pub struct UserSpending {
free_tier_current: u32,
free_tier_cap: u32,
over_tier_current: u32,
over_tier_cap: u32,
free_tier_progress: Entity<ProgressBar>,
over_tier_progress: Entity<ProgressBar>,
}
impl UserSpending {
pub fn new(
free_tier_current: u32,
free_tier_cap: u32,
over_tier_current: u32,
over_tier_cap: u32,
cx: &mut App,
) -> Self {
let free_tier_capped = free_tier_current == free_tier_cap;
let free_tier_near_capped =
free_tier_current as f32 / 100.0 >= free_tier_cap as f32 / 100.0 * 0.9;
let over_tier_capped = over_tier_current == over_tier_cap;
let over_tier_near_capped =
over_tier_current as f32 / 100.0 >= over_tier_cap as f32 / 100.0 * 0.9;
let free_tier_progress = cx.new(|cx| {
ProgressBar::new(
"free_tier",
free_tier_current as f32,
free_tier_cap as f32,
cx,
)
});
let over_tier_progress = cx.new(|cx| {
ProgressBar::new(
"over_tier",
over_tier_current as f32,
over_tier_cap as f32,
cx,
)
});
if free_tier_capped {
free_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().error);
});
} else if free_tier_near_capped {
free_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().warning);
});
}
if over_tier_capped {
over_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().error);
});
} else if over_tier_near_capped {
over_tier_progress.update(cx, |progress_bar, cx| {
progress_bar.fg_color(cx.theme().status().warning);
});
}
Self {
free_tier_current,
free_tier_cap,
over_tier_current,
over_tier_cap,
free_tier_progress,
over_tier_progress,
}
}
}
impl Render for UserSpending {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let formatted_free_tier = format!(
"${} / ${}",
self.free_tier_current as f32 / 100.0,
self.free_tier_cap as f32 / 100.0
);
let formatted_over_tier = format!(
"${} / ${}",
self.over_tier_current as f32 / 100.0,
self.over_tier_cap as f32 / 100.0
);
v_group()
.elevation_2(cx)
.py_1p5()
.px_2p5()
.w(px(360.))
.child(
v_flex()
.child(
v_flex()
.p_1p5()
.gap_0p5()
.child(
h_flex()
.justify_between()
.child(Label::new("Free Tier Usage").size(LabelSize::Small))
.child(
Label::new(formatted_free_tier)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(self.free_tier_progress.clone()),
)
.child(
v_flex()
.p_1p5()
.gap_0p5()
.child(
h_flex()
.justify_between()
.child(Label::new("Current Spending").size(LabelSize::Small))
.child(
Label::new(formatted_over_tier)
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(self.over_tier_progress.clone()),
),
)
}
}
impl Component for UserSpending {
fn scope() -> ComponentScope {
ComponentScope::None
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let new_user = cx.new(|cx| UserSpending::new(0, 2000, 0, 2000, cx));
let free_capped = cx.new(|cx| UserSpending::new(2000, 2000, 0, 2000, cx));
let free_near_capped = cx.new(|cx| UserSpending::new(1800, 2000, 0, 2000, cx));
let over_near_capped = cx.new(|cx| UserSpending::new(2000, 2000, 1800, 2000, cx));
let over_capped = cx.new(|cx| UserSpending::new(1000, 2000, 2000, 2000, cx));
Some(
v_flex()
.gap_6()
.p_4()
.children(vec![example_group(vec![
single_example(
"New User",
div().size_full().child(new_user.clone()).into_any_element(),
),
single_example(
"Free Tier Capped",
div()
.size_full()
.child(free_capped.clone())
.into_any_element(),
),
single_example(
"Free Tier Near Capped",
div()
.size_full()
.child(free_near_capped.clone())
.into_any_element(),
),
single_example(
"Over Tier Near Capped",
div()
.size_full()
.child(over_near_capped.clone())
.into_any_element(),
),
single_example(
"Over Tier Capped",
div()
.size_full()
.child(over_capped.clone())
.into_any_element(),
),
])])
.into_any_element(),
)
}
}

View File

@@ -1,52 +0,0 @@
// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
use std::process::Command;
fn main() {
if cfg!(target_os = "macos") {
println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
// Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
// Seems to be required to enable Swift concurrency
println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
// Register exported Objective-C selectors, protocols, etc
println!("cargo:rustc-link-arg=-Wl,-ObjC");
}
// Populate git sha environment variable if git is available
println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
println!(
"cargo:rustc-env=TARGET={}",
std::env::var("TARGET").unwrap()
);
if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
if output.status.success() {
let git_sha = String::from_utf8_lossy(&output.stdout);
let git_sha = git_sha.trim();
println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
if let Ok(build_profile) = std::env::var("PROFILE") {
if build_profile == "release" {
// This is currently the best way to make `cargo build ...`'s build script
// to print something to stdout without extra verbosity.
println!(
"cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
);
}
}
}
}
#[cfg(target_os = "windows")]
{
#[cfg(target_env = "msvc")]
{
// todo(windows): This is to avoid stack overflow. Remove it when solved.
println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
}
}
}

View File

@@ -1,384 +0,0 @@
use crate::git_commands::{run_git, setup_temp_repo};
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use crate::{get_exercise_language, get_exercise_name};
use agent::RequestKind;
use anyhow::{Result, anyhow};
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
fs,
io::Write,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, SystemTime},
};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResult {
pub exercise_name: String,
pub diff: String,
pub assistant_response: String,
pub elapsed_time_ms: u128,
pub timestamp: u128,
// Token usage fields
pub input_tokens: usize,
pub output_tokens: usize,
pub total_tokens: usize,
pub tool_use_counts: usize,
}
pub struct EvalOutput {
pub diff: String,
pub last_message: String,
pub elapsed_time: Duration,
pub assistant_response_count: usize,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub token_usage: TokenUsage,
}
#[derive(Deserialize)]
pub struct EvalSetup {
pub url: String,
pub base_sha: String,
}
pub struct Eval {
pub repo_path: PathBuf,
pub eval_setup: EvalSetup,
pub user_prompt: String,
}
impl Eval {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
let prompt_path = path.join("prompt.txt");
let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
let setup_path = path.join("setup.json");
let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
// Move this internal function inside the load method since it's only used here
fn repo_dir_name(url: &str) -> String {
url.trim_start_matches("https://")
.replace(|c: char| !c.is_alphanumeric(), "_")
}
let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
Ok(Eval {
repo_path,
eval_setup,
user_prompt,
})
}
pub fn run(
self,
app_state: Arc<HeadlessAppState>,
model: Arc<dyn LanguageModel>,
cx: &mut App,
) -> Task<Result<EvalOutput>> {
cx.spawn(async move |cx| {
run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;
let (assistant, done_rx) =
cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
let _worktree = assistant
.update(cx, |assistant, cx| {
assistant.project.update(cx, |project, cx| {
project.create_worktree(&self.repo_path, true, cx)
})
})?
.await?;
let start_time = std::time::SystemTime::now();
let (system_prompt_context, load_error) = cx
.update(|cx| {
assistant
.read(cx)
.thread
.read(cx)
.load_system_prompt_context(cx)
})?
.await;
if let Some(load_error) = load_error {
return Err(anyhow!("{:?}", load_error));
};
assistant.update(cx, |assistant, cx| {
assistant.thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
thread.set_system_prompt_context(system_prompt_context);
thread.send_to_model(model, RequestKind::Chat, cx);
});
})?;
done_rx.recv().await??;
// Add this section to check untracked files
println!("Checking for untracked files:");
let untracked = run_git(
&self.repo_path,
&["ls-files", "--others", "--exclude-standard"],
)
.await?;
if untracked.is_empty() {
println!("No untracked files found");
} else {
// Add all files to git so they appear in the diff
println!("Adding untracked files to git");
run_git(&self.repo_path, &["add", "."]).await?;
}
// get git status
let _status = run_git(&self.repo_path, &["status", "--short"]).await?;
let elapsed_time = start_time.elapsed()?;
// Get diff of staged changes (the files we just added)
let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;
// Get diff of unstaged changes
let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;
// Combine both diffs
let diff = if unstaged_diff.is_empty() {
staged_diff
} else if staged_diff.is_empty() {
unstaged_diff
} else {
format!(
"# Staged changes\n{}\n\n# Unstaged changes\n{}",
staged_diff, unstaged_diff
)
};
assistant.update(cx, |assistant, cx| {
let thread = assistant.thread.read(cx);
let last_message = thread.messages().last().unwrap();
if last_message.role != language_model::Role::Assistant {
return Err(anyhow!("Last message is not from assistant"));
}
let assistant_response_count = thread
.messages()
.filter(|message| message.role == language_model::Role::Assistant)
.count();
Ok(EvalOutput {
diff,
last_message: last_message.to_string(),
elapsed_time,
assistant_response_count,
tool_use_counts: assistant.tool_use_counts.clone(),
token_usage: thread.cumulative_token_usage(),
})
})?
})
}
}
impl EvalOutput {
// Keep this method for potential future use, but mark it as intentionally unused
#[allow(dead_code)]
pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
// Create the output directory if it doesn't exist
fs::create_dir_all(&output_dir)?;
// Save the diff to a file
let diff_path = output_dir.join("diff.patch");
let mut diff_file = fs::File::create(&diff_path)?;
diff_file.write_all(self.diff.as_bytes())?;
// Save the last message to a file
let message_path = output_dir.join("assistant_response.txt");
let mut message_file = fs::File::create(&message_path)?;
message_file.write_all(self.last_message.as_bytes())?;
// Current metrics for this run
let current_metrics = serde_json::json!({
"elapsed_time_ms": self.elapsed_time.as_millis(),
"assistant_response_count": self.assistant_response_count,
"tool_use_counts": self.tool_use_counts,
"token_usage": self.token_usage,
"eval_output_value": eval_output_value,
});
// Get current timestamp in milliseconds
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)?
.as_millis()
.to_string();
// Path to metrics file
let metrics_path = output_dir.join("metrics.json");
// Load existing metrics if the file exists, or create a new object
let mut historical_metrics = if metrics_path.exists() {
let metrics_content = fs::read_to_string(&metrics_path)?;
serde_json::from_str::<serde_json::Value>(&metrics_content)
.unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Add new run with timestamp as key
if let serde_json::Value::Object(ref mut map) = historical_metrics {
map.insert(timestamp, current_metrics);
}
// Write updated metrics back to file
let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
let mut metrics_file = fs::File::create(&metrics_path)?;
metrics_file.write_all(metrics_json.as_bytes())?;
Ok(())
}
}
pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
let instructions_path = exercise_path.join(".docs").join("instructions.md");
println!("Reading instructions from: {}", instructions_path.display());
let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
Ok(instructions)
}
pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
let eval_dir = exercise_path.join("evaluation");
fs::create_dir_all(&eval_dir)?;
let eval_file = eval_dir.join("evals.json");
println!("Saving evaluation results to: {}", eval_file.display());
println!(
"Results to save: {} evaluations for exercise path: {}",
results.len(),
exercise_path.display()
);
// Check file existence before reading/writing
if eval_file.exists() {
println!("Existing evals.json file found, will update it");
} else {
println!("No existing evals.json file found, will create new one");
}
// Structure to organize evaluations by test name and timestamp
let mut eval_data: serde_json::Value = if eval_file.exists() {
let content = fs::read_to_string(&eval_file)?;
serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
} else {
serde_json::json!({})
};
// Get current timestamp for this batch of results
let timestamp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis()
.to_string();
// Group the new results by test name (exercise name)
for result in results {
let exercise_name = &result.exercise_name;
println!("Adding result: exercise={}", exercise_name);
// Ensure the exercise entry exists
if eval_data.get(exercise_name).is_none() {
eval_data[exercise_name] = serde_json::json!({});
}
// Ensure the timestamp entry exists as an object
if eval_data[exercise_name].get(&timestamp).is_none() {
eval_data[exercise_name][&timestamp] = serde_json::json!({});
}
// Add this result under the timestamp with template name as key
eval_data[exercise_name][&timestamp] = serde_json::to_value(&result)?;
}
// Write back to file with pretty formatting
let json_content = serde_json::to_string_pretty(&eval_data)?;
match fs::write(&eval_file, json_content) {
Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
Err(e) => println!("✗ Failed to write results file: {}", e),
}
Ok(())
}
pub async fn run_exercise_eval(
exercise_path: PathBuf,
model: Arc<dyn LanguageModel>,
app_state: Arc<HeadlessAppState>,
base_sha: String,
_framework_path: PathBuf,
cx: gpui::AsyncApp,
) -> Result<EvalResult> {
let exercise_name = get_exercise_name(&exercise_path);
let language = get_exercise_language(&exercise_path)?;
let mut instructions = read_instructions(&exercise_path).await?;
instructions.push_str(&format!(
"\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
language
));
println!("Running evaluation for exercise: {}", exercise_name);
// Create temporary directory with exercise files
let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
let temp_path = temp_dir.path().to_path_buf();
let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;
let start_time = SystemTime::now();
// Create a basic eval struct to work with the existing system
let eval = Eval {
repo_path: temp_path.clone(),
eval_setup: EvalSetup {
url: format!("file://{}", temp_path.display()),
base_sha: local_commit_sha, // Use the local commit SHA instead of the framework base SHA
},
user_prompt: instructions.clone(),
};
// Run the evaluation
let eval_output = cx
.update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
.await?;
// Get diff from git
let diff = eval_output.diff.clone();
let elapsed_time = start_time.elapsed()?;
// Calculate total tokens as the sum of input and output tokens
let input_tokens = eval_output.token_usage.input_tokens;
let output_tokens = eval_output.token_usage.output_tokens;
let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
let total_tokens = input_tokens + output_tokens;
// Save results to evaluation directory
let result = EvalResult {
exercise_name: exercise_name.clone(),
diff,
assistant_response: eval_output.last_message.clone(),
elapsed_time_ms: elapsed_time.as_millis(),
timestamp: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)?
.as_millis(),
// Convert u32 token counts to usize
input_tokens: input_tokens.try_into().unwrap(),
output_tokens: output_tokens.try_into().unwrap(),
total_tokens: total_tokens.try_into().unwrap(),
tool_use_counts: tool_use_counts.try_into().unwrap(),
};
Ok(result)
}

View File

@@ -1,149 +0,0 @@
use anyhow::{Result, anyhow};
use std::{
fs,
path::{Path, PathBuf},
};
pub fn get_exercise_name(exercise_path: &Path) -> String {
exercise_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string()
}
pub fn get_exercise_language(exercise_path: &Path) -> Result<String> {
// Extract the language from path (data/python/exercises/... => python)
let parts: Vec<_> = exercise_path.components().collect();
for (i, part) in parts.iter().enumerate() {
if i > 0 && part.as_os_str() == "eval_code" {
if i + 1 < parts.len() {
let language = parts[i + 1].as_os_str().to_string_lossy().to_string();
return Ok(language);
}
}
}
Err(anyhow!(
"Could not determine language from path: {:?}",
exercise_path
))
}
pub fn find_exercises(
framework_path: &Path,
languages: &[&str],
max_per_language: Option<usize>,
) -> Result<Vec<PathBuf>> {
let mut all_exercises = Vec::new();
println!("Searching for exercises in languages: {:?}", languages);
for language in languages {
let language_dir = framework_path
.join("eval_code")
.join(language)
.join("exercises")
.join("practice");
println!("Checking language directory: {:?}", language_dir);
if !language_dir.exists() {
println!("Warning: Language directory not found: {:?}", language_dir);
continue;
}
let mut exercises = Vec::new();
match fs::read_dir(&language_dir) {
Ok(entries) => {
for entry_result in entries {
match entry_result {
Ok(entry) => {
let path = entry.path();
if path.is_dir() {
// Special handling for "internal" directory
if *language == "internal" {
// Check for repo_info.json to validate it's an internal exercise
let repo_info_path = path.join(".meta").join("repo_info.json");
let instructions_path =
path.join(".docs").join("instructions.md");
if repo_info_path.exists() && instructions_path.exists() {
exercises.push(path);
}
} else {
// Map the language to the file extension - original code
let language_extension = match *language {
"python" => "py",
"go" => "go",
"rust" => "rs",
"typescript" => "ts",
"javascript" => "js",
"ruby" => "rb",
"php" => "php",
"bash" => "sh",
"multi" => "diff",
_ => continue, // Skip unsupported languages
};
// Check if this is a valid exercise with instructions and example
let instructions_path =
path.join(".docs").join("instructions.md");
let has_instructions = instructions_path.exists();
let example_path = path
.join(".meta")
.join(format!("example.{}", language_extension));
let has_example = example_path.exists();
if has_instructions && has_example {
exercises.push(path);
}
}
}
}
Err(err) => println!("Error reading directory entry: {}", err),
}
}
}
Err(err) => println!(
"Error reading directory {}: {}",
language_dir.display(),
err
),
}
// Sort exercises by name for consistent selection
exercises.sort_by(|a, b| {
let a_name = a.file_name().unwrap_or_default().to_string_lossy();
let b_name = b.file_name().unwrap_or_default().to_string_lossy();
a_name.cmp(&b_name)
});
// Apply the limit if specified
if let Some(limit) = max_per_language {
if exercises.len() > limit {
println!(
"Limiting {} exercises to {} for language {}",
exercises.len(),
limit,
language
);
exercises.truncate(limit);
}
}
println!(
"Found {} exercises for language {}: {:?}",
exercises.len(),
language,
exercises
.iter()
.map(|p| p.file_name().unwrap_or_default().to_string_lossy())
.collect::<Vec<_>>()
);
all_exercises.extend(exercises);
}
Ok(all_exercises)
}

View File

@@ -1,125 +0,0 @@
use anyhow::{Result, anyhow};
use serde::Deserialize;
use std::{fs, path::Path};
use tempfile::TempDir;
use util::command::new_smol_command;
use walkdir::WalkDir;
#[derive(Debug, Deserialize)]
pub struct SetupConfig {
#[serde(rename = "base.sha")]
pub base_sha: String,
}
#[derive(Debug, Deserialize)]
pub struct RepoInfo {
pub remote_url: String,
pub head_sha: String,
}
pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = new_smol_command("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
if output.status.success() {
Ok(String::from_utf8(output.stdout)?.trim().to_string())
} else {
Err(anyhow!(
"Git command failed: {} with status: {}",
args.join(" "),
output.status
))
}
}
pub async fn read_base_sha(framework_path: &Path) -> Result<String> {
let setup_path = framework_path.join("setup.json");
let setup_content = smol::unblock(move || std::fs::read_to_string(&setup_path)).await?;
let setup_config: SetupConfig = serde_json_lenient::from_str_lenient(&setup_content)?;
Ok(setup_config.base_sha)
}
pub async fn read_repo_info(exercise_path: &Path) -> Result<RepoInfo> {
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
println!("Reading repo info from: {}", repo_info_path.display());
let repo_info_content = smol::unblock(move || std::fs::read_to_string(&repo_info_path)).await?;
let repo_info: RepoInfo = serde_json_lenient::from_str_lenient(&repo_info_content)?;
// Remove any quotes from the strings
let remote_url = repo_info.remote_url.trim_matches('"').to_string();
let head_sha = repo_info.head_sha.trim_matches('"').to_string();
Ok(RepoInfo {
remote_url,
head_sha,
})
}
pub async fn setup_temp_repo(exercise_path: &Path, _base_sha: &str) -> Result<TempDir> {
let temp_dir = TempDir::new()?;
// Check if this is an internal exercise by looking for repo_info.json
let repo_info_path = exercise_path.join(".meta").join("repo_info.json");
if repo_info_path.exists() {
// This is an internal exercise, handle it differently
let repo_info = read_repo_info(exercise_path).await?;
// Clone the repository to the temp directory
let url = repo_info.remote_url;
let clone_path = temp_dir.path();
println!(
"Cloning repository from {} to {}",
url,
clone_path.display()
);
run_git(
&std::env::current_dir()?,
&["clone", &url, &clone_path.to_string_lossy()],
)
.await?;
// Checkout the specified commit
println!("Checking out commit: {}", repo_info.head_sha);
run_git(temp_dir.path(), &["checkout", &repo_info.head_sha]).await?;
println!("Successfully set up internal repository");
} else {
// Original code for regular exercises
// Copy the exercise files to the temp directory, excluding .docs and .meta
for entry in WalkDir::new(exercise_path).min_depth(0).max_depth(10) {
let entry = entry?;
let source_path = entry.path();
// Skip .docs and .meta directories completely
if source_path.starts_with(exercise_path.join(".docs"))
|| source_path.starts_with(exercise_path.join(".meta"))
{
continue;
}
if source_path.is_file() {
let relative_path = source_path.strip_prefix(exercise_path)?;
let dest_path = temp_dir.path().join(relative_path);
// Make sure parent directories exist
if let Some(parent) = dest_path.parent() {
fs::create_dir_all(parent)?;
}
fs::copy(source_path, dest_path)?;
}
}
// Initialize git repo in the temp directory
run_git(temp_dir.path(), &["init"]).await?;
run_git(temp_dir.path(), &["add", "."]).await?;
run_git(temp_dir.path(), &["commit", "-m", "Initial commit"]).await?;
println!("Created temp repo without .docs and .meta directories");
}
Ok(temp_dir)
}

View File

@@ -1,246 +0,0 @@
use agent::{RequestKind, Thread, ThreadEvent, ThreadStore};
use anyhow::anyhow;
use assistant_tool::ToolWorkingSet;
use client::{Client, UserStore};
use collections::HashMap;
use dap::DapRegistry;
use gpui::{App, Entity, SemanticVersion, Subscription, Task, prelude::*};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use node_runtime::NodeRuntime;
use project::{Project, RealFs};
use prompt_store::PromptBuilder;
use settings::SettingsStore;
use smol::channel;
use std::sync::Arc;
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct HeadlessAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub struct HeadlessAssistant {
pub thread: Entity<Thread>,
pub project: Entity<Project>,
#[allow(dead_code)]
pub thread_store: Entity<ThreadStore>,
pub tool_use_counts: HashMap<Arc<str>, u32>,
pub done_tx: channel::Sender<anyhow::Result<()>>,
_subscription: Subscription,
}
impl HeadlessAssistant {
pub fn new(
app_state: Arc<HeadlessAppState>,
cx: &mut App,
) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
let env = None;
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
Arc::new(DapRegistry::default()),
app_state.fs.clone(),
env,
cx,
);
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
let headless_thread = cx.new(move |cx| Self {
_subscription: cx.subscribe(&thread, Self::handle_thread_event),
thread,
project,
thread_store,
tool_use_counts: HashMap::default(),
done_tx,
});
Ok((headless_thread, done_rx))
}
fn handle_thread_event(
&mut self,
thread: Entity<Thread>,
event: &ThreadEvent,
cx: &mut Context<Self>,
) {
match event {
ThreadEvent::ShowError(err) => self
.done_tx
.send_blocking(Err(anyhow!("{:?}", err)))
.unwrap(),
ThreadEvent::DoneStreaming => {
let thread = thread.read(cx);
if let Some(message) = thread.messages().last() {
println!("Message: {}", message.to_string());
}
if thread.all_tools_finished() {
self.done_tx.send_blocking(Ok(())).unwrap()
}
}
ThreadEvent::UsePendingTools => {
thread.update(cx, |thread, cx| {
thread.use_pending_tools(cx);
});
}
ThreadEvent::ToolConfirmationNeeded => {
// Automatically approve all tools that need confirmation in headless mode
println!("Tool confirmation needed - automatically approving in headless mode");
// Get the tools needing confirmation
let tools_needing_confirmation: Vec<_> = thread
.read(cx)
.tools_needing_confirmation()
.cloned()
.collect();
// Run each tool that needs confirmation
for tool_use in tools_needing_confirmation {
if let Some(tool) = thread.read(cx).tools().tool(&tool_use.name, cx) {
thread.update(cx, |thread, cx| {
println!("Auto-approving tool: {}", tool_use.name);
// Create a request to send to the tool
let request = thread.to_completion_request(RequestKind::Chat, cx);
let messages = Arc::new(request.messages);
// Run the tool
thread.run_tool(
tool_use.id.clone(),
tool_use.ui_text.clone(),
tool_use.input.clone(),
&messages,
tool,
cx,
);
});
}
}
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(pending_tool_use) = pending_tool_use {
println!(
"Used tool {} with input: {}",
pending_tool_use.name, pending_tool_use.input
);
*self
.tool_use_counts
.entry(pending_tool_use.name.clone())
.or_insert(0) += 1;
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("Tool result: {:?}", tool_result);
}
if thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.default_model() {
thread.update(cx, |thread, cx| {
thread.attach_tool_results(cx);
thread.send_to_model(model.model, RequestKind::Chat, cx);
});
} else {
println!(
"Warning: No active language model available to continue conversation"
);
}
}
}
_ => {}
}
}
}
pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
cx.background_executor().clone(),
));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
Arc::new(HeadlessAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}

View File

@@ -1,205 +0,0 @@
mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
use clap::Parser;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{path::PathBuf, sync::Arc};
#[derive(Parser, Debug)]
#[command(
name = "agent_eval",
disable_version_flag = true,
before_help = "Tool eval runner"
)]
struct Args {
/// Match the names of evals to run.
#[arg(long)]
exercise_names: Vec<String>,
/// Runs all exercises, causes the exercise_names to be ignored.
#[arg(long)]
all: bool,
/// Supported language types to evaluate (default: internal).
/// Internal is data generated from the agent panel
#[arg(long, default_value = "internal")]
languages: String,
/// Name of the model (default: "claude-3-7-sonnet-latest")
#[arg(long, default_value = "claude-3-7-sonnet-latest")]
model_name: String,
/// Name of the editor model (default: value of `--model_name`).
#[arg(long)]
editor_model_name: Option<String>,
/// Number of evaluations to run concurrently (default: 3)
#[arg(short, long, default_value = "5")]
concurrency: usize,
/// Maximum number of exercises to evaluate per language
#[arg(long)]
max_exercises_per_language: Option<usize>,
}
fn main() {
env_logger::init();
let args = Args::parse();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
// Path to the zed-ace-framework repo
let framework_path = PathBuf::from("../zed-ace-framework")
.canonicalize()
.unwrap();
// Fix the 'languages' lifetime issue by creating owned Strings instead of slices
let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
println!("Using zed-ace-framework at: {:?}", framework_path);
println!("Evaluating languages: {:?}", languages);
app.run(move |cx| {
let app_state = headless_assistant::init(cx);
let model = find_model(&args.model_name, cx).unwrap();
let editor_model = if let Some(model_name) = &args.editor_model_name {
find_model(model_name, cx).unwrap()
} else {
model.clone()
};
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
});
let model_provider_id = model.provider_id();
let editor_model_provider_id = editor_model.provider_id();
let framework_path_clone = framework_path.clone();
let languages_clone = languages.clone();
let exercise_names = args.exercise_names.clone();
let all_flag = args.all;
cx.spawn(async move |cx| {
// Authenticate all model providers first
cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
.unwrap()
.await
.unwrap();
println!("framework path: {}", framework_path_clone.display());
let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
println!("base sha: {}", base_sha);
let all_exercises = find_exercises(
&framework_path_clone,
&languages_clone
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>(),
args.max_exercises_per_language,
)
.unwrap();
println!("Found {} exercises total", all_exercises.len());
// Filter exercises if specific ones were requested
let exercises_to_run = if !exercise_names.is_empty() {
// If exercise names are specified, filter by them regardless of --all flag
all_exercises
.into_iter()
.filter(|path| {
let name = get_exercise_name(path);
exercise_names.iter().any(|filter| name.contains(filter))
})
.collect()
} else if all_flag {
// Only use all_flag if no exercise names are specified
all_exercises
} else {
// Default behavior (no filters)
all_exercises
};
println!("Will run {} exercises", exercises_to_run.len());
// Create exercise eval tasks - each exercise is a single task that will run templates sequentially
let exercise_tasks: Vec<_> = exercises_to_run
.into_iter()
.map(|exercise_path| {
let exercise_name = get_exercise_name(&exercise_path);
let model_clone = model.clone();
let app_state_clone = app_state.clone();
let base_sha_clone = base_sha.clone();
let framework_path_clone = framework_path_clone.clone();
let cx_clone = cx.clone();
async move {
println!("Processing exercise: {}", exercise_name);
let mut exercise_results = Vec::new();
match run_exercise_eval(
exercise_path.clone(),
model_clone.clone(),
app_state_clone.clone(),
base_sha_clone.clone(),
framework_path_clone.clone(),
cx_clone.clone(),
)
.await
{
Ok(result) => {
println!("Completed {}", exercise_name);
exercise_results.push(result);
}
Err(err) => {
println!("Error running {}: {}", exercise_name, err);
}
}
// Save results for this exercise
if !exercise_results.is_empty() {
if let Err(err) =
save_eval_results(&exercise_path, exercise_results.clone()).await
{
println!("Error saving results for {}: {}", exercise_name, err);
} else {
println!("Saved results for {}", exercise_name);
}
}
exercise_results
}
})
.collect();
println!(
"Running {} exercises with concurrency: {}",
exercise_tasks.len(),
args.concurrency
);
// Run exercises concurrently, with each exercise running its templates sequentially
let all_results = stream::iter(exercise_tasks)
.buffer_unordered(args.concurrency)
.flat_map(stream::iter)
.collect::<Vec<_>>()
.await;
println!("Completed {} evaluation runs", all_results.len());
cx.update(|cx| cx.quit()).unwrap();
})
.detach();
});
println!("Done running evals");
}

View File

@@ -1,25 +0,0 @@
[package]
name = "agent_rules"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/agent_rules.rs"
doctest = false
[dependencies]
anyhow.workspace = true
fs.workspace = true
gpui.workspace = true
prompt_store.workspace = true
util.workspace = true
worktree.workspace = true
workspace-hack = { version = "0.1", path = "../../tooling/workspace-hack" }
[dev-dependencies]
indoc.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-GPL

View File

@@ -1,51 +0,0 @@
use std::sync::Arc;
use anyhow::{Context as _, Result};
use fs::Fs;
use gpui::{App, AppContext, Task};
use prompt_store::SystemPromptRulesFile;
use util::maybe;
use worktree::Worktree;
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
".cursorrules",
".windsurfrules",
".clinerules",
".github/copilot-instructions.md",
"CLAUDE.md",
];
pub fn load_worktree_rules_file(
fs: Arc<dyn Fs>,
worktree: &Worktree,
cx: &App,
) -> Option<Task<Result<SystemPromptRulesFile>>> {
let selected_rules_file = RULES_FILE_NAMES
.into_iter()
.filter_map(|name| {
worktree
.entry_for_path(name)
.filter(|entry| entry.is_file())
.map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
})
.next();
// Note that Cline supports `.clinerules` being a directory, but that is not currently
// supported. This doesn't seem to occur often in GitHub repositories.
selected_rules_file.map(|(path_in_worktree, abs_path)| {
let fs = fs.clone();
cx.background_spawn(maybe!(async move {
let abs_path = abs_path?;
let text = fs
.load(&abs_path)
.await
.with_context(|| format!("Failed to load assistant rules file {:?}", abs_path))?;
anyhow::Ok(SystemPromptRulesFile {
path_in_worktree,
abs_path: abs_path.into(),
text: text.trim().to_string(),
})
}))
})
}

View File

@@ -37,9 +37,9 @@ pub enum AnthropicModelMode {
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
#[default]
#[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
Claude3_5Sonnet,
#[default]
#[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
Claude3_7Sonnet,
#[serde(

View File

@@ -69,7 +69,7 @@ pub enum AssistantProviderContentV1 {
},
}
#[derive(Debug, Default)]
#[derive(Clone, Debug, Default)]
pub struct AssistantSettings {
pub enabled: bool,
pub button: bool,
@@ -742,7 +742,7 @@ mod tests {
AssistantSettings::get_global(cx).default_model,
LanguageModelSelection {
provider: "zed.dev".into(),
model: "claude-3-5-sonnet-latest".into(),
model: "claude-3-7-sonnet-latest".into(),
}
);
});

View File

@@ -1,5 +1,5 @@
mod bash_tool;
mod batch_tool;
mod code_action_tool;
mod code_symbols_tool;
mod copy_path_tool;
mod create_directory_tool;
@@ -15,9 +15,11 @@ mod open_tool;
mod path_search_tool;
mod read_file_tool;
mod regex_search_tool;
mod rename_tool;
mod replace;
mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
use std::sync::Arc;
@@ -28,8 +30,8 @@ use gpui::App;
use http_client::HttpClientWithUrl;
use move_path_tool::MovePathTool;
use crate::bash_tool::BashTool;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
use crate::create_directory_tool::CreateDirectoryTool;
use crate::create_file_tool::CreateFileTool;
@@ -43,14 +45,16 @@ use crate::open_tool::OpenTool;
use crate::path_search_tool::PathSearchTool;
use crate::read_file_tool::ReadFileTool;
use crate::regex_search_tool::RegexSearchTool;
use crate::rename_tool::RenameTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
let registry = ToolRegistry::global(cx);
registry.register_tool(BashTool);
registry.register_tool(TerminalTool);
registry.register_tool(BatchTool);
registry.register_tool(CreateDirectoryTool);
registry.register_tool(CreateFileTool);
@@ -58,6 +62,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(DeletePathTool);
registry.register_tool(FindReplaceFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
registry.register_tool(DiagnosticsTool);
registry.register_tool(ListDirectoryTool);
@@ -67,6 +72,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(RegexSearchTool);
registry.register_tool(RenameTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
}

View File

@@ -1,230 +0,0 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
use ui::IconName;
use util::command::new_smol_command;
use util::markdown::MarkdownString;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct BashToolInput {
/// The bash one-liner command to execute.
command: String,
/// Working directory for the command. This must be one of the root directories of the project.
cd: String,
}
pub struct BashTool;
impl Tool for BashTool {
fn name(&self) -> String {
"bash".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
fn description(&self) -> String {
include_str!("./bash_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Terminal
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<BashToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<BashToolInput>(input.clone()) {
Ok(input) => {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownString::inline_code(&first_line).0,
1 => {
MarkdownString::inline_code(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.0
}
n => {
MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
}
}
}
Err(_) => "Run bash command".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input: BashToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let project = project.read(cx);
let input_path = Path::new(&input.cd);
let working_dir = if input.cd == "." {
// Accept "." as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
let only_worktree = match worktrees.next() {
Some(worktree) => worktree,
None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
};
if worktrees.next().is_some() {
return Task::ready(Err(anyhow!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
)));
}
only_worktree.read(cx).abs_path()
} else if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
return Task::ready(Err(anyhow!(
"The absolute path must be within one of the project's worktrees"
)));
}
input_path.into()
} else {
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
return Task::ready(Err(anyhow!(
"`cd` directory {} not found in the project",
&input.cd
)));
};
worktree.read(cx).abs_path()
};
cx.spawn(async move |_| {
// Add 2>&1 to merge stderr into stdout for proper interleaving.
let command = format!("({}) 2>&1", input.command);
let mut cmd = new_smol_command("bash")
.arg("-c")
.arg(&command)
.current_dir(working_dir)
.stdout(std::process::Stdio::piped())
.spawn()
.context("Failed to execute bash command")?;
// Capture stdout with a limit
let stdout = cmd.stdout.take().unwrap();
let mut reader = BufReader::new(stdout);
const MESSAGE_1: &str = "Command output too long. The first ";
const MESSAGE_2: &str = " bytes:\n\n";
const ERR_MESSAGE_1: &str = "Command failed with exit code ";
const ERR_MESSAGE_2: &str = "\n\n";
const STDOUT_LIMIT: usize = 8192;
const LIMIT: usize = STDOUT_LIMIT
- (MESSAGE_1.len()
+ (STDOUT_LIMIT.ilog10() as usize + 1) // byte count
+ MESSAGE_2.len()
+ ERR_MESSAGE_1.len()
+ 3 // status code
+ ERR_MESSAGE_2.len());
// Read one more byte to determine whether the output was truncated
let mut buffer = vec![0; LIMIT + 1];
let mut bytes_read = 0;
// Read until we reach the limit
loop {
let read = reader.read(&mut buffer).await?;
if read == 0 {
break;
}
bytes_read += read;
if bytes_read > LIMIT {
bytes_read = LIMIT + 1;
break;
}
}
// Repeatedly fill the output reader's buffer without copying it.
loop {
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() {
break;
}
let skipped_bytes_len = skipped_bytes.len();
reader.consume_unpin(skipped_bytes_len);
}
let output_bytes = &buffer[..bytes_read];
// Let the process continue running
let status = cmd.status().await.context("Failed to get command status")?;
let output_string = if bytes_read > LIMIT {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters.
let last_line_ix = output_bytes.iter().rposition(|b| *b == b'\n');
let output_string = String::from_utf8_lossy(
&output_bytes[..last_line_ix.unwrap_or(output_bytes.len())],
);
format!(
"{}{}{}{}",
MESSAGE_1,
output_string.len(),
MESSAGE_2,
output_string
)
} else {
String::from_utf8_lossy(&output_bytes).into()
};
let output_with_status = if status.success() {
if output_string.is_empty() {
"Command executed successfully.".to_string()
} else {
output_string.to_string()
}
} else {
format!(
"{}{}{}{}",
ERR_MESSAGE_1,
status.code().unwrap_or(-1),
ERR_MESSAGE_2,
output_string,
)
};
debug_assert!(output_with_status.len() <= STDOUT_LIMIT);
Ok(output_with_status)
})
}
}

View File

@@ -1,7 +0,0 @@
Executes a bash one-liner and returns the combined output.
This tool spawns a bash process, combines stdout and stderr into one interleaved stream as they are produced (preserving the order of writes), and captures that stream into a string which is returned.
Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error.
Remember that each invocation of this tool will spawn a new bash process, so you can't rely on any state from previous invocations.

View File

@@ -0,0 +1,389 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::LanguageModelRequestMessage;
use project::{self, LspAction, Project};
use regex::Regex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{ops::Range, sync::Arc};
use ui::IconName;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeActionToolInput {
/// The relative path to the file containing the text range.
///
/// WARNING: you MUST start this path with one of the project's root directories.
pub path: String,
/// The specific code action to execute.
///
/// If this field is provided, the tool will execute the specified action.
/// If omitted, the tool will list all available code actions for the text range.
///
/// Here are some actions that are commonly supported (but may not be for this particular
/// text range; you can omit this field to list all the actions, if you want to know
/// what your options are, or you can just try an action and if it fails I'll tell you
/// what the available actions were instead):
/// - "quickfix.all" - applies all available quick fixes in the range
/// - "source.organizeImports" - sorts and cleans up import statements
/// - "source.fixAll" - applies all available auto fixes
/// - "refactor.extract" - extracts selected code into a new function or variable
/// - "refactor.inline" - inlines a variable by replacing references with its value
/// - "refactor.rewrite" - general code rewriting operations
/// - "source.addMissingImports" - adds imports for references that lack them
/// - "source.removeUnusedImports" - removes imports that aren't being used
/// - "source.implementInterface" - generates methods required by an interface/trait
/// - "source.generateAccessors" - creates getter/setter methods
/// - "source.convertToAsyncFunction" - converts callback-style code to async/await
///
/// Also, there is a special case: if you specify exactly "textDocument/rename" as the action,
/// then this will rename the symbol to whatever string you specified for the `arguments` field.
pub action: Option<String>,
/// Optional arguments to pass to the code action.
///
/// For rename operations (when action="textDocument/rename"), this should contain the new name.
/// For other code actions, these arguments may be passed to the language server.
pub arguments: Option<serde_json::Value>,
/// The text that comes immediately before the text range in the file.
pub context_before_range: String,
/// The text range. This text must appear in the file right between `context_before_range`
/// and `context_after_range`.
///
/// The file must contain exactly one occurrence of `context_before_range` followed by
/// `text_range` followed by `context_after_range`. If the file contains zero occurrences,
/// or if it contains more than one occurrence, the tool will fail, so it is absolutely
/// critical that you verify ahead of time that the string is unique. You can search
/// the file's contents to verify this ahead of time.
///
/// To make the string more likely to be unique, include a minimum of 1 line of context
/// before the text range, as well as a minimum of 1 line of context after the text range.
/// If these lines of context are not enough to obtain a string that appears only once
/// in the file, then double the number of context lines until the string becomes unique.
/// (Start with 1 line before and 1 line after though, because too much context is
/// needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. The combined string must be exactly
/// as it appears in the file, or else this tool call will fail.
pub text_range: String,
/// The text that comes immediately after the text range in the file.
pub context_after_range: String,
}
pub struct CodeActionTool;
impl Tool for CodeActionTool {
fn name(&self) -> String {
"code_actions".into()
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./code_action_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Wand
}
fn input_schema(
&self,
_format: language_model::LanguageModelToolSchemaFormat,
) -> serde_json::Value {
let schema = schemars::schema_for!(CodeActionToolInput);
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<CodeActionToolInput>(input.clone()) {
Ok(input) => {
if let Some(action) = &input.action {
if action == "textDocument/rename" {
let new_name = match &input.arguments {
Some(serde_json::Value::String(new_name)) => new_name.clone(),
Some(value) => {
if let Ok(new_name) =
serde_json::from_value::<String>(value.clone())
{
new_name
} else {
"invalid name".to_string()
}
}
None => "missing name".to_string(),
};
format!("Rename '{}' to '{}'", input.text_range, new_name)
} else {
format!(
"Execute code action '{}' for '{}'",
action, input.text_range
)
}
} else {
format!("List available code actions for '{}'", input.text_range)
}
}
Err(_) => "Perform code action".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
cx.spawn(async move |cx| {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
};
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
})?;
let range = {
let Some(range) = buffer.read_with(cx, |buffer, _cx| {
find_text_range(&buffer, &input.context_before_range, &input.text_range, &input.context_after_range)
})? else {
return Err(anyhow!(
"Failed to locate the text specified by context_before_range, text_range, and context_after_range. Make sure context_before_range and context_after_range each match exactly once in the file."
));
};
range
};
if let Some(action_type) = &input.action {
// Special-case the `rename` operation
let response = if action_type == "textDocument/rename" {
let Some(new_name) = input.arguments.and_then(|args| serde_json::from_value::<String>(args).ok()) else {
return Err(anyhow!("For rename operations, 'arguments' must be a string containing the new name"));
};
let position = buffer.read_with(cx, |buffer, _| {
range.start.to_point_utf16(&buffer.snapshot())
})?;
project
.update(cx, |project, cx| {
project.perform_rename(buffer.clone(), position, new_name.clone(), cx)
})?
.await?;
format!("Renamed '{}' to '{}'", input.text_range, new_name)
} else {
// Get code actions for the range
let actions = project
.update(cx, |project, cx| {
project.code_actions(&buffer, range.clone(), None, cx)
})?
.await?;
if actions.is_empty() {
return Err(anyhow!("No code actions available for this range"));
}
// Find all matching actions
let regex = match Regex::new(action_type) {
Ok(regex) => regex,
Err(err) => return Err(anyhow!("Invalid regex pattern: {}", err)),
};
let mut matching_actions = actions
.into_iter()
.filter(|action| { regex.is_match(action.lsp_action.title()) });
let Some(action) = matching_actions.next() else {
return Err(anyhow!("No code actions match the pattern: {}", action_type));
};
// There should have been exactly one matching action.
if let Some(second) = matching_actions.next() {
let mut all_matches = vec![action, second];
all_matches.extend(matching_actions);
return Err(anyhow!(
"Pattern '{}' matches multiple code actions: {}",
action_type,
all_matches.into_iter().map(|action| action.lsp_action.title().to_string()).collect::<Vec<_>>().join(", ")
));
}
let title = action.lsp_action.title().to_string();
project
.update(cx, |project, cx| {
project.apply_code_action(buffer.clone(), action, true, cx)
})?
.await?;
format!("Completed code action: {}", title)
};
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
})?;
Ok(response)
} else {
// No action specified, so list the available ones.
let (position_start, position_end) = buffer.read_with(cx, |buffer, _| {
let snapshot = buffer.snapshot();
(
range.start.to_point_utf16(&snapshot),
range.end.to_point_utf16(&snapshot)
)
})?;
// Convert position to display coordinates (1-based)
let position_start_display = language::Point {
row: position_start.row + 1,
column: position_start.column + 1,
};
let position_end_display = language::Point {
row: position_end.row + 1,
column: position_end.column + 1,
};
// Get code actions for the range
let actions = project
.update(cx, |project, cx| {
project.code_actions(&buffer, range.clone(), None, cx)
})?
.await?;
let mut response = format!(
"Available code actions for text range '{}' at position {}:{} to {}:{} (UTF-16 coordinates):\n\n",
input.text_range,
position_start_display.row, position_start_display.column,
position_end_display.row, position_end_display.column
);
if actions.is_empty() {
response.push_str("No code actions available for this range.");
} else {
for (i, action) in actions.iter().enumerate() {
let title = match &action.lsp_action {
LspAction::Action(code_action) => code_action.title.as_str(),
LspAction::Command(command) => command.title.as_str(),
LspAction::CodeLens(code_lens) => {
if let Some(cmd) = &code_lens.command {
cmd.title.as_str()
} else {
"Unknown code lens"
}
},
};
let kind = match &action.lsp_action {
LspAction::Action(code_action) => {
if let Some(kind) = &code_action.kind {
kind.as_str()
} else {
"unknown"
}
},
LspAction::Command(_) => "command",
LspAction::CodeLens(_) => "code_lens",
};
response.push_str(&format!("{}. {title} ({kind})\n", i + 1));
}
}
Ok(response)
}
})
}
}
/// Finds the range of the text in the buffer, if it appears between context_before_range
/// and context_after_range, and if that combined string has one unique result in the buffer.
///
/// If an exact match fails, it tries adding a newline to the end of context_before_range and
/// to the beginning of context_after_range to accommodate line-based context matching.
fn find_text_range(
buffer: &Buffer,
context_before_range: &str,
text_range: &str,
context_after_range: &str,
) -> Option<Range<Anchor>> {
let snapshot = buffer.snapshot();
let text = snapshot.text();
// First try with exact match
let search_string = format!("{context_before_range}{text_range}{context_after_range}");
let mut positions = text.match_indices(&search_string);
let position_result = positions.next();
if let Some(position) = position_result {
// Check if the matched string is unique
if positions.next().is_none() {
let range_start = position.0 + context_before_range.len();
let range_end = range_start + text_range.len();
let range_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_start));
let range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_end));
return Some(range_start_anchor..range_end_anchor);
}
}
// If exact match fails or is not unique, try with line-based context
// Add a newline to the end of before context and beginning of after context
let line_based_before = if context_before_range.ends_with('\n') {
context_before_range.to_string()
} else {
format!("{context_before_range}\n")
};
let line_based_after = if context_after_range.starts_with('\n') {
context_after_range.to_string()
} else {
format!("\n{context_after_range}")
};
let line_search_string = format!("{line_based_before}{text_range}{line_based_after}");
let mut line_positions = text.match_indices(&line_search_string);
let line_position = line_positions.next()?;
// The line-based search string must also appear exactly once
if line_positions.next().is_some() {
return None;
}
let line_range_start = line_position.0 + line_based_before.len();
let line_range_end = line_range_start + text_range.len();
let line_range_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(line_range_start));
let line_range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(line_range_end));
Some(line_range_start_anchor..line_range_end_anchor)
}

View File

@@ -0,0 +1,19 @@
A tool for applying code actions to specific sections of your code. It uses language servers to provide refactoring capabilities similar to what you'd find in an IDE.
This tool can:
- List all available code actions for a selected text range
- Execute a specific code action on that range
- Rename symbols across your codebase. This tool is the preferred way to rename things, and you should always prefer to rename code symbols using this tool rather than using textual find/replace when both are available.
Use this tool when you want to:
- Discover what code actions are available for a piece of code
- Apply automatic fixes and code transformations
- Rename variables, functions, or other symbols consistently throughout your project
- Clean up imports, implement interfaces, or perform other language-specific operations
- If unsure what actions are available, call the tool without specifying an action to get a list
- For common operations, you can directly specify actions like "quickfix.all" or "source.organizeImports"
- For renaming, use the special "textDocument/rename" action and provide the new name in the arguments field
- Be specific with your text range and context to ensure the tool identifies the correct code location
The tool will automatically save any changes it makes to your files.

View File

@@ -179,11 +179,9 @@ pub async fn file_outline(
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while parse_status
.recv()
.await
.map_or(false, |status| status != ParseStatus::Idle)
{}
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {

View File

@@ -63,6 +63,16 @@ pub struct FindReplaceFileToolInput {
/// even one character in this string is different in any way from how it appears
/// in the file, then the tool call will fail.
///
/// If you get an error that the `find` string was not found, this means that either
/// you made a mistake, or that the file has changed since you last looked at it.
/// Either way, when this happens, you should retry doing this tool call until it
/// succeeds, up to 3 times. Each time you retry, you should take another look at
/// the exact text of the file in question, to make sure that you are searching for
/// exactly the right string. Regardless of whether it was because you made a mistake
/// or because the file changed since you last looked at it, you should be extra
/// careful when retrying in this way. It's a bad experience for the user if
/// this `find` string isn't found, so be super careful to get it exactly right!
///
/// <example>
/// If a file contains this code:
///

View File

@@ -1,8 +1,10 @@
Find one unique part of a file in the project and replace that text with new text.
This tool is the preferred way to make edits to files. If you have multiple edits to make, including edits across multiple files, then make a plan to respond with a single message containing multiple calls to this tool - one call for each find/replace operation.
This tool is the preferred way to make edits to files *except* when making a rename. When making a rename specifically, the rename tool must always be used instead.
You should use this tool when you want to edit a subset of a file's contents, but not the entire file. You should not use this tool when you want to replace the entire contents of a file with completely different contents. You also should not use this tool when you want to move or rename a file. You absolutely must NEVER use this tool to create new files from scratch. If you ever consider using this tool to create a new file from scratch, for any reason, instead you must reconsider and choose a different approach.
If you have multiple edits to make, including edits across multiple files, then make a plan to respond with a single message containing a batch of calls to this tool - one call for each find/replace operation.
You should only use this tool when you want to edit a subset of a file's contents, but not the entire file. You should not use this tool when you want to replace the entire contents of a file with completely different contents. You also should not use this tool when you want to move or rename a file. You absolutely must NEVER use this tool to create new files from scratch. If you ever consider using this tool to create a new file from scratch, for any reason, instead you must reconsider and choose a different approach.
DO NOT call this tool until the code to be edited appears in the conversation! You must use another tool to read the file's contents into the conversation, or ask the user to add it to context first.

View File

@@ -26,6 +26,10 @@ pub struct RegexSearchToolInput {
/// When not provided, starts from the beginning.
#[serde(default)]
pub offset: u32,
/// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
#[serde(default)]
pub case_sensitive: bool,
}
impl RegexSearchToolInput {
@@ -64,12 +68,17 @@ impl Tool for RegexSearchTool {
match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
Ok(input) => {
let page = input.page();
let regex = MarkdownString::inline_code(&input.regex);
let regex_str = MarkdownString::inline_code(&input.regex);
let case_info = if input.case_sensitive {
" (case-sensitive)"
} else {
""
};
if page > 1 {
format!("Get page {page} of search results for regex {regex}")
format!("Get page {page} of search results for regex {regex_str}{case_info}")
} else {
format!("Search files for regex {regex}")
format!("Search files for regex {regex_str}{case_info}")
}
}
Err(_) => "Search with regex".to_string(),
@@ -86,15 +95,16 @@ impl Tool for RegexSearchTool {
) -> Task<Result<String>> {
const CONTEXT_LINES: u32 = 2;
let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
Ok(input) => (input.offset, input.regex),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let (offset, regex, case_sensitive) =
match serde_json::from_value::<RegexSearchToolInput>(input) {
Ok(input) => (input.offset, input.regex, input.case_sensitive),
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let query = match SearchQuery::regex(
&regex,
false,
false,
case_sensitive,
false,
PathMatcher::default(),
PathMatcher::default(),

View File

@@ -0,0 +1,205 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use gpui::{App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::IconName;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RenameToolInput {
/// The relative path to the file containing the symbol to rename.
///
/// WARNING: you MUST start this path with one of the project's root directories.
pub path: String,
/// The new name to give to the symbol.
pub new_name: String,
/// The text that comes immediately before the symbol in the file.
pub context_before_symbol: String,
/// The symbol to rename. This text must appear in the file right between
/// `context_before_symbol` and `context_after_symbol`.
///
/// The file must contain exactly one occurrence of `context_before_symbol` followed by
/// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
/// or if it contains more than one occurrence, the tool will fail, so it is absolutely
/// critical that you verify ahead of time that the string is unique. You can search
/// the file's contents to verify this ahead of time.
///
/// To make the string more likely to be unique, include a minimum of 1 line of context
/// before the symbol, as well as a minimum of 1 line of context after the symbol.
/// If these lines of context are not enough to obtain a string that appears only once
/// in the file, then double the number of context lines until the string becomes unique.
/// (Start with 1 line before and 1 line after though, because too much context is
/// needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. The combined string must be exactly
/// as it appears in the file, or else this tool call will fail.
pub symbol: String,
/// The text that comes immediately after the symbol in the file.
pub context_after_symbol: String,
}
pub struct RenameTool;
impl Tool for RenameTool {
fn name(&self) -> String {
"rename".into()
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("./rename_tool/description.md").into()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(
&self,
_format: language_model::LanguageModelToolSchemaFormat,
) -> serde_json::Value {
let schema = schemars::schema_for!(RenameToolInput);
serde_json::to_value(&schema).unwrap()
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<RenameToolInput>(input.clone()) {
Ok(input) => {
format!("Rename '{}' to '{}'", input.symbol, input.new_name)
}
Err(_) => "Rename symbol".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input = match serde_json::from_value::<RenameToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
cx.spawn(async move |cx| {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
};
action_log.update(cx, |action_log, cx| {
action_log.buffer_read(buffer.clone(), cx);
})?;
let position = {
let Some(position) = buffer.read_with(cx, |buffer, _cx| {
find_symbol_position(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
})? else {
return Err(anyhow!(
"Failed to locate the symbol specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
));
};
buffer.read_with(cx, |buffer, _| {
position.to_point_utf16(&buffer.snapshot())
})?
};
project
.update(cx, |project, cx| {
project.perform_rename(buffer.clone(), position, input.new_name.clone(), cx)
})?
.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
})?;
Ok(format!("Renamed '{}' to '{}'", input.symbol, input.new_name))
})
}
}
/// Finds the position of the symbol in the buffer, if it appears between context_before_symbol
/// and context_after_symbol, and if that combined string has one unique result in the buffer.
///
/// If an exact match fails, it tries adding a newline to the end of context_before_symbol and
/// to the beginning of context_after_symbol to accommodate line-based context matching.
fn find_symbol_position(
buffer: &Buffer,
context_before_symbol: &str,
symbol: &str,
context_after_symbol: &str,
) -> Option<language::Anchor> {
let snapshot = buffer.snapshot();
let text = snapshot.text();
// First try with exact match
let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
let mut positions = text.match_indices(&search_string);
let position_result = positions.next();
if let Some(position) = position_result {
// Check if the matched string is unique
if positions.next().is_none() {
let symbol_start = position.0 + context_before_symbol.len();
let symbol_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
return Some(symbol_start_anchor);
}
}
// If exact match fails or is not unique, try with line-based context
// Add a newline to the end of before context and beginning of after context
let line_based_before = if context_before_symbol.ends_with('\n') {
context_before_symbol.to_string()
} else {
format!("{context_before_symbol}\n")
};
let line_based_after = if context_after_symbol.starts_with('\n') {
context_after_symbol.to_string()
} else {
format!("\n{context_after_symbol}")
};
let line_search_string = format!("{line_based_before}{symbol}{line_based_after}");
let mut line_positions = text.match_indices(&line_search_string);
let line_position = line_positions.next()?;
// The line-based search string must also appear exactly once
if line_positions.next().is_some() {
return None;
}
let line_symbol_start = line_position.0 + line_based_before.len();
let line_symbol_start_anchor =
snapshot.anchor_before(snapshot.offset_to_point(line_symbol_start));
Some(line_symbol_start_anchor)
}

View File

@@ -0,0 +1,15 @@
Renames a symbol across your codebase using the language server's semantic knowledge.
This tool performs a rename refactoring operation on a specified symbol. It uses the project's language server to analyze the code and perform the rename correctly across all files where the symbol is referenced.
Unlike a simple find and replace, this tool understands the semantic meaning of the code, so it only renames the specific symbol you specify and not unrelated text that happens to have the same name.
Examples of symbols you can rename:
- Variables
- Functions
- Classes/structs
- Fields/properties
- Methods
- Interfaces/traits
The language server handles updating all references to the renamed symbol throughout the codebase.

View File

@@ -0,0 +1,366 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::future;
use util::get_system_shell;
use std::path::Path;
use std::sync::Arc;
use ui::IconName;
use util::command::new_smol_command;
use util::markdown::MarkdownString;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
/// The one-liner command to execute.
command: String,
/// Working directory for the command. This must be one of the root directories of the project.
cd: String,
}
pub struct TerminalTool;
impl Tool for TerminalTool {
fn name(&self) -> String {
"terminal".to_string()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
true
}
fn description(&self) -> String {
include_str!("./terminal_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Terminal
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
json_schema_for::<TerminalToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<TerminalToolInput>(input.clone()) {
Ok(input) => {
let mut lines = input.command.lines();
let first_line = lines.next().unwrap_or_default();
let remaining_line_count = lines.count();
match remaining_line_count {
0 => MarkdownString::inline_code(&first_line).0,
1 => {
MarkdownString::inline_code(&format!(
"{} - {} more line",
first_line, remaining_line_count
))
.0
}
n => {
MarkdownString::inline_code(&format!("{} - {} more lines", first_line, n)).0
}
}
}
Err(_) => "Run terminal command".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
cx: &mut App,
) -> Task<Result<String>> {
let input: TerminalToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let project = project.read(cx);
let input_path = Path::new(&input.cd);
let working_dir = if input.cd == "." {
// Accept "." as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
let only_worktree = match worktrees.next() {
Some(worktree) => worktree,
None => return Task::ready(Err(anyhow!("No worktrees found in the project"))),
};
if worktrees.next().is_some() {
return Task::ready(Err(anyhow!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
)));
}
only_worktree.read(cx).abs_path()
} else if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
return Task::ready(Err(anyhow!(
"The absolute path must be within one of the project's worktrees"
)));
}
input_path.into()
} else {
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
return Task::ready(Err(anyhow!(
"`cd` directory {} not found in the project",
&input.cd
)));
};
worktree.read(cx).abs_path()
};
cx.background_spawn(run_command_limited(working_dir, input.command))
}
}
const LIMIT: usize = 16 * 1024;
async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
let shell = get_system_shell();
let mut cmd = new_smol_command(&shell)
.arg("-c")
.arg(&command)
.current_dir(working_dir)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.context("Failed to execute terminal command")?;
let mut combined_buffer = String::with_capacity(LIMIT + 1);
let mut out_reader = BufReader::new(cmd.stdout.take().context("Failed to get stdout")?);
let mut out_tmp_buffer = String::with_capacity(512);
let mut err_reader = BufReader::new(cmd.stderr.take().context("Failed to get stderr")?);
let mut err_tmp_buffer = String::with_capacity(512);
let mut out_line = Box::pin(
out_reader
.read_line(&mut out_tmp_buffer)
.left_future()
.fuse(),
);
let mut err_line = Box::pin(
err_reader
.read_line(&mut err_tmp_buffer)
.left_future()
.fuse(),
);
let mut has_stdout = true;
let mut has_stderr = true;
while (has_stdout || has_stderr) && combined_buffer.len() < LIMIT + 1 {
futures::select_biased! {
read = out_line => {
drop(out_line);
combined_buffer.extend(out_tmp_buffer.drain(..));
if read? == 0 {
out_line = Box::pin(future::pending().right_future().fuse());
has_stdout = false;
} else {
out_line = Box::pin(out_reader.read_line(&mut out_tmp_buffer).left_future().fuse());
}
}
read = err_line => {
drop(err_line);
combined_buffer.extend(err_tmp_buffer.drain(..));
if read? == 0 {
err_line = Box::pin(future::pending().right_future().fuse());
has_stderr = false;
} else {
err_line = Box::pin(err_reader.read_line(&mut err_tmp_buffer).left_future().fuse());
}
}
};
}
drop((out_line, err_line));
let truncated = combined_buffer.len() > LIMIT;
combined_buffer.truncate(LIMIT);
consume_reader(out_reader, truncated).await?;
consume_reader(err_reader, truncated).await?;
let status = cmd.status().await.context("Failed to get command status")?;
let output_string = if truncated {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters.
let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
let combined_buffer = &combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
format!(
"Command output too long. The first {} bytes:\n\n{}",
combined_buffer.len(),
output_block(&combined_buffer),
)
} else {
output_block(&combined_buffer)
};
let output_with_status = if status.success() {
if output_string.is_empty() {
"Command executed successfully.".to_string()
} else {
output_string.to_string()
}
} else {
format!(
"Command failed with exit code {} (shell: {}).\n\n{}",
status.code().unwrap_or(-1),
shell,
output_string,
)
};
Ok(output_with_status)
}
async fn consume_reader<T: AsyncReadExt + Unpin>(
mut reader: BufReader<T>,
truncated: bool,
) -> Result<(), std::io::Error> {
loop {
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() {
break;
}
let skipped_bytes_len = skipped_bytes.len();
reader.consume_unpin(skipped_bytes_len);
// Should only skip if we went over the limit
debug_assert!(truncated);
}
Ok(())
}
fn output_block(output: &str) -> String {
format!(
"```\n{}{}```",
output,
if output.ends_with('\n') { "" } else { "\n" }
)
}
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
use gpui::TestAppContext;
use super::*;
#[gpui::test(iterations = 10)]
async fn test_run_command_simple(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result =
run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\nHello, World!\n```");
}
#[gpui::test(iterations = 10)]
async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let command = "echo 'stdout 1' && sleep 0.01 && echo 'stderr 1' >&2 && sleep 0.01 && echo 'stdout 2' && sleep 0.01 && echo 'stderr 2' >&2";
let result = run_command_limited(Path::new(".").into(), command.to_string()).await;
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
"```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
);
}
#[gpui::test(iterations = 10)]
async fn test_multiple_output_reads(cx: &mut TestAppContext) {
cx.executor().allow_parking();
// Command with multiple outputs that might require multiple reads
let result = run_command_limited(
Path::new(".").into(),
"echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
}
#[gpui::test(iterations = 10)]
async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; sleep 0.01;", "X".repeat(LIMIT * 2));
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
// Output should be exactly the limit
assert_eq!(content_length, LIMIT);
}
#[gpui::test(iterations = 10)]
async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
assert!(content_length <= LIMIT);
}
#[gpui::test(iterations = 10)]
async fn test_command_failure(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result = run_command_limited(Path::new(".").into(), "exit 42".to_string()).await;
assert!(result.is_ok());
let output = result.unwrap();
// Extract the shell name from path for cleaner test output
let shell_path = std::env::var("SHELL").unwrap_or("bash".to_string());
let expected_output = format!(
"Command failed with exit code 42 (shell: {}).\n\n```\n\n```",
shell_path
);
assert_eq!(output, expected_output);
}
}

View File

@@ -0,0 +1,9 @@
Executes a shell one-liner and returns the combined output.
This tool spawns a process using the user's current shell, combines stdout and stderr into one interleaved stream as they are produced (preserving the order of writes), and captures that stream into a string which is returned.
Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error.
Do not use this tool for commands that run indefinitely, such as servers (e.g., `python -m http.server`) or file watchers that don't terminate on their own.
Remember that each invocation of this tool will spawn a new shell process, so you can't rely on any state from previous invocations.

View File

@@ -16,6 +16,7 @@ pub enum CliRequest {
wait: bool,
open_new_workspace: Option<bool>,
env: Option<HashMap<String, String>>,
user_data_dir: Option<String>,
},
}

View File

@@ -26,7 +26,11 @@ struct Detect;
trait InstalledApp {
fn zed_version_string(&self) -> String;
fn launch(&self, ipc_url: String) -> anyhow::Result<()>;
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus>;
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus>;
fn path(&self) -> PathBuf;
}
@@ -58,6 +62,13 @@ struct Args {
/// Create a new workspace
#[arg(short, long, overrides_with = "add")]
new: bool,
/// Sets a custom directory for all user data (e.g., database, extensions, logs).
/// This overrides the default platform-specific data directory location.
/// On macOS, the default is `~/Library/Application Support/Zed`.
/// On Linux/FreeBSD, the default is `$XDG_DATA_HOME/zed`.
/// On Windows, the default is `%LOCALAPPDATA%\Zed`.
#[arg(long, value_name = "DIR")]
user_data_dir: Option<String>,
/// The paths to open in Zed (space-separated).
///
/// Use `path:line:column` syntax to open a file at the given line and column.
@@ -135,6 +146,12 @@ fn main() -> Result<()> {
}
let args = Args::parse();
// Set custom data directory before any path operations
let user_data_dir = args.user_data_dir.clone();
if let Some(dir) = &user_data_dir {
paths::set_custom_data_dir(dir);
}
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
let args = flatpak::set_bin_if_no_escape(args);
@@ -246,6 +263,7 @@ fn main() -> Result<()> {
let sender: JoinHandle<anyhow::Result<()>> = thread::spawn({
let exit_status = exit_status.clone();
let user_data_dir_for_thread = user_data_dir.clone();
move || {
let (_, handshake) = server.accept().context("Handshake after Zed spawn")?;
let (tx, rx) = (handshake.requests, handshake.responses);
@@ -256,6 +274,7 @@ fn main() -> Result<()> {
wait: args.wait,
open_new_workspace,
env,
user_data_dir: user_data_dir_for_thread,
})?;
while let Ok(response) = rx.recv() {
@@ -291,7 +310,7 @@ fn main() -> Result<()> {
.collect();
if args.foreground {
app.run_foreground(url)?;
app.run_foreground(url, user_data_dir.as_deref())?;
} else {
app.launch(url)?;
sender.join().unwrap()?;
@@ -437,7 +456,7 @@ mod linux {
}
fn launch(&self, ipc_url: String) -> anyhow::Result<()> {
let sock_path = paths::support_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL));
let sock_path = paths::data_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL));
let sock = UnixDatagram::unbound()?;
if sock.connect(&sock_path).is_err() {
self.boot_background(ipc_url)?;
@@ -447,10 +466,17 @@ mod linux {
Ok(())
}
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
std::process::Command::new(self.0.clone())
.arg(ipc_url)
.status()
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus> {
let mut cmd = std::process::Command::new(self.0.clone());
cmd.arg(ipc_url);
if let Some(dir) = user_data_dir {
cmd.arg("--user-data-dir").arg(dir);
}
cmd.status()
}
fn path(&self) -> PathBuf {
@@ -688,12 +714,17 @@ mod windows {
Ok(())
}
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
std::process::Command::new(self.0.clone())
.arg(ipc_url)
.arg("--foreground")
.spawn()?
.wait()
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus> {
let mut cmd = std::process::Command::new(self.0.clone());
cmd.arg(ipc_url).arg("--foreground");
if let Some(dir) = user_data_dir {
cmd.arg("--user-data-dir").arg(dir);
}
cmd.spawn()?.wait()
}
fn path(&self) -> PathBuf {
@@ -875,13 +906,22 @@ mod mac_os {
Ok(())
}
fn run_foreground(&self, ipc_url: String) -> io::Result<ExitStatus> {
fn run_foreground(
&self,
ipc_url: String,
user_data_dir: Option<&str>,
) -> io::Result<ExitStatus> {
let path = match self {
Bundle::App { app_bundle, .. } => app_bundle.join("Contents/MacOS/zed"),
Bundle::LocalPath { executable, .. } => executable.clone(),
};
std::process::Command::new(path).arg(ipc_url).status()
let mut cmd = std::process::Command::new(path);
cmd.arg(ipc_url);
if let Some(dir) = user_data_dir {
cmd.arg("--user-data-dir").arg(dir);
}
cmd.status()
}
fn path(&self) -> PathBuf {

View File

@@ -274,7 +274,7 @@ async fn create_billing_subscription(
customer.id
};
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-7-sonnet")?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!(
"{}/account?checkout_complete=1",

View File

@@ -187,22 +187,20 @@ impl ComponentPreview {
let mut entries = Vec::new();
let known_scopes = [
ComponentScope::Layout,
ComponentScope::Input,
ComponentScope::Editor,
ComponentScope::Notification,
ComponentScope::Collaboration,
ComponentScope::VersionControl,
ComponentScope::None,
];
// Always show all components first
entries.push(PreviewEntry::AllComponents);
entries.push(PreviewEntry::Separator);
for scope in known_scopes.iter() {
if let Some(components) = scope_groups.remove(scope) {
let mut scopes: Vec<_> = scope_groups
.keys()
.filter(|scope| !matches!(**scope, ComponentScope::None))
.cloned()
.collect();
scopes.sort_by_key(|s| s.to_string());
for scope in scopes {
if let Some(components) = scope_groups.remove(&scope) {
if !components.is_empty() {
entries.push(PreviewEntry::SectionHeader(scope.to_string().into()));
let mut sorted_components = components;
@@ -215,6 +213,7 @@ impl ComponentPreview {
}
}
// Add uncategorized components last
if let Some(components) = scope_groups.get(&ComponentScope::None) {
if !components.is_empty() {
entries.push(PreviewEntry::Separator);
@@ -272,7 +271,12 @@ impl ComponentPreview {
.into_any_element()
}
PreviewEntry::Separator => ListItem::new(ix)
.child(h_flex().pt_3().child(Divider::horizontal_dashed()))
.child(
h_flex()
.occlude()
.pt_3()
.child(Divider::horizontal_dashed()),
)
.into_any_element(),
}
}

View File

@@ -85,6 +85,10 @@ pub fn lsp_tasks(
.map(|(name, buffer_ids)| {
let buffers = buffer_ids
.iter()
.filter(|&&buffer_id| match for_position {
Some(for_position) => for_position.buffer_id == Some(buffer_id),
None => true,
})
.filter_map(|&buffer_id| project.read(cx).buffer_for_id(buffer_id, cx))
.collect::<Vec<_>>();
language_server_for_buffers(project.clone(), name.clone(), buffers, cx)

View File

@@ -1,25 +1,16 @@
[package]
name = "agent_eval"
name = "eval"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[[bin]]
name = "agent_eval"
path = "src/main.rs"
edition.workspace = true
[dependencies]
agent.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
clap.workspace = true
assistant_settings.workspace = true
client.workspace = true
collections.workspace = true
context_server.workspace = true
dap.workspace = true
env_logger.workspace = true
@@ -36,11 +27,13 @@ prompt_store.workspace = true
release_channel.workspace = true
reqwest_client.workspace = true
serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smol.workspace = true
tempfile.workspace = true
util.workspace = true
walkdir.workspace = true
toml.workspace = true
workspace-hack.workspace = true
[[bin]]
name = "eval"
path = "src/eval.rs"
[lints]
workspace = true

7
crates/eval/README.md Normal file
View File

@@ -0,0 +1,7 @@
# Eval
This eval assumes the working directory is the root of the repository. Run it with:
```sh
cargo run -p eval
```

View File

@@ -0,0 +1,2 @@
path = "../zed_worktree"
revision = "38fcadf9481d018543c65f36ac3bafeba190179b"

View File

@@ -0,0 +1,3 @@
Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation.
The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line.

145
crates/eval/src/eval.rs Normal file
View File

@@ -0,0 +1,145 @@
mod example;
use assistant_settings::AssistantSettings;
use client::{Client, UserStore};
pub(crate) use example::*;
use ::fs::RealFs;
use anyhow::anyhow;
use gpui::{App, AppContext, Application, Entity, SemanticVersion, Task};
use language::LanguageRegistry;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
};
use node_runtime::NodeRuntime;
use project::Project;
use prompt_store::PromptBuilder;
use reqwest_client::ReqwestClient;
use settings::{Settings, SettingsStore};
use std::sync::Arc;
fn main() {
env_logger::init();
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client.clone());
app.run(move |cx| {
let app_state = init(cx);
let model = find_model("claude-3-7-sonnet-thinking-latest", cx).unwrap();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.set_default_model(Some(model.clone()), cx);
});
let model_provider_id = model.provider_id();
let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
cx.spawn(async move |cx| {
authenticate.await.unwrap();
let example =
Example::load_from_directory("./crates/eval/examples/find_and_replace_diff_card")?;
example.setup()?;
cx.update(|cx| example.run(model, app_state, cx))?.await?;
anyhow::Ok(())
})
.detach_and_log_err(cx);
});
}
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
pub fs: Arc<dyn fs::Fs>,
pub node_runtime: NodeRuntime,
// Additional fields not present in `workspace::AppState`.
pub prompt_builder: Arc<PromptBuilder>,
}
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
release_channel::init(SemanticVersion::default(), cx);
gpui_tokio::init(cx);
let mut settings_store = SettingsStore::new(cx);
settings_store
.set_default_settings(settings::default_settings().as_ref(), cx)
.unwrap();
cx.set_global(settings_store);
client::init_settings(cx);
Project::init_settings(cx);
let client = Client::production(cx);
cx.set_http_client(client.http_client().clone());
let git_binary_path = None;
let fs = Arc::new(RealFs::new(
git_binary_path,
cx.background_executor().clone(),
));
let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language::init(cx);
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
assistant_tools::init(client.http_client().clone(), cx);
context_server::init(cx);
let stdout_is_a_pty = false;
let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
AssistantSettings::override_global(
AssistantSettings {
always_allow_tool_actions: true,
..AssistantSettings::get_global(cx).clone()
},
cx,
);
Arc::new(AgentAppState {
languages,
client,
user_store,
fs,
node_runtime: NodeRuntime::unavailable(),
prompt_builder,
})
}
pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model = model_registry
.available_models(cx)
.find(|model| model.id().0 == model_name);
let Some(model) = model else {
return Err(anyhow!(
"No language model named {} was available. Available models: {}",
model_name,
model_registry
.available_models(cx)
.map(|model| model.id().0.clone())
.collect::<Vec<_>>()
.join(", ")
));
};
Ok(model)
}
pub fn authenticate_model_provider(
provider_id: LanguageModelProviderId,
cx: &mut App,
) -> Task<std::result::Result<(), AuthenticateError>> {
let model_registry = LanguageModelRegistry::read_global(cx);
let model_provider = model_registry.provider(&provider_id).unwrap();
model_provider.authenticate(cx)
}

178
crates/eval/src/example.rs Normal file
View File

@@ -0,0 +1,178 @@
use agent::{RequestKind, ThreadEvent, ThreadStore};
use anyhow::{Result, anyhow};
use assistant_tool::ToolWorkingSet;
use dap::DapRegistry;
use futures::channel::oneshot;
use gpui::{App, Task};
use language_model::{LanguageModel, StopReason};
use project::Project;
use serde::Deserialize;
use std::process::Command;
use std::sync::Arc;
use std::{
fs,
path::{Path, PathBuf},
};
use crate::AgentAppState;
#[derive(Debug, Deserialize)]
pub struct ExampleBase {
pub path: PathBuf,
pub revision: String,
}
#[derive(Debug)]
pub struct Example {
pub base: ExampleBase,
/// Content of the prompt.md file
pub prompt: String,
/// Content of the rubric.md file
pub _rubric: String,
}
impl Example {
/// Load an example from a directory containing base.toml, prompt.md, and rubric.md
pub fn load_from_directory<P: AsRef<Path>>(dir_path: P) -> Result<Self> {
let base_path = dir_path.as_ref().join("base.toml");
let prompt_path = dir_path.as_ref().join("prompt.md");
let rubric_path = dir_path.as_ref().join("rubric.md");
let mut base: ExampleBase = toml::from_str(&fs::read_to_string(&base_path)?)?;
base.path = base.path.canonicalize()?;
Ok(Example {
base,
prompt: fs::read_to_string(prompt_path)?,
_rubric: fs::read_to_string(rubric_path)?,
})
}
/// Set up the example by checking out the specified Git revision
pub fn setup(&self) -> Result<()> {
// Check if the directory exists
let path = Path::new(&self.base.path);
anyhow::ensure!(path.exists(), "Path does not exist: {:?}", self.base.path);
// Change to the project directory and checkout the specified revision
let output = Command::new("git")
.current_dir(&self.base.path)
.arg("checkout")
.arg(&self.base.revision)
.output()?;
anyhow::ensure!(
output.status.success(),
"Failed to checkout revision {}: {}",
self.base.revision,
String::from_utf8_lossy(&output.stderr),
);
Ok(())
}
pub fn run(
self,
model: Arc<dyn LanguageModel>,
app_state: Arc<AgentAppState>,
cx: &mut App,
) -> Task<Result<()>> {
let project = Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
Arc::new(DapRegistry::default()),
app_state.fs.clone(),
None,
cx,
);
let worktree = project.update(cx, |project, cx| {
project.create_worktree(self.base.path, true, cx)
});
let tools = Arc::new(ToolWorkingSet::default());
let thread_store =
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
println!("USER:");
println!("{}", self.prompt);
println!("ASSISTANT:");
cx.spawn(async move |cx| {
worktree.await?;
let thread_store = thread_store.await;
let thread =
thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
let (tx, rx) = oneshot::channel();
let mut tx = Some(tx);
let _subscription =
cx.subscribe(
&thread,
move |thread, event: &ThreadEvent, cx| match event {
ThreadEvent::Stopped(reason) => match reason {
Ok(StopReason::EndTurn) => {
if let Some(tx) = tx.take() {
tx.send(Ok(())).ok();
}
}
Ok(StopReason::MaxTokens) => {
if let Some(tx) = tx.take() {
tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
}
}
Ok(StopReason::ToolUse) => {}
Err(error) => {
if let Some(tx) = tx.take() {
tx.send(Err(anyhow!(error.clone()))).ok();
}
}
},
ThreadEvent::ShowError(thread_error) => {
if let Some(tx) = tx.take() {
tx.send(Err(anyhow!(thread_error.clone()))).ok();
}
}
ThreadEvent::StreamedAssistantText(_, chunk) => {
print!("{}", chunk);
}
ThreadEvent::StreamedAssistantThinking(_, chunk) => {
print!("{}", chunk);
}
ThreadEvent::UsePendingTools { tool_uses } => {
println!("\n\nUSING TOOLS:");
for tool_use in tool_uses {
println!("{}: {}", tool_use.name, tool_use.input);
}
}
ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
..
} => {
if let Some(tool_use) = pending_tool_use {
println!("\nTOOL FINISHED: {}", tool_use.name);
}
if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
println!("\n{}\n", tool_result.content);
}
}
_ => {}
},
)?;
thread.update(cx, |thread, cx| {
let context = vec![];
thread.insert_user_message(self.prompt.clone(), context, None, cx);
thread.send_to_model(model, RequestKind::Chat, cx);
})?;
rx.await??;
Ok(())
})
}
}

View File

@@ -1,7 +1,7 @@
use crate::{
AnyView, AnyWindowHandle, App, AppCell, AppContext, BackgroundExecutor, BorrowAppContext,
Entity, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation, Result, Task,
VisualContext, Window, WindowHandle,
Entity, EventEmitter, Focusable, ForegroundExecutor, Global, PromptLevel, Render, Reservation,
Result, Subscription, Task, VisualContext, Window, WindowHandle,
};
use anyhow::{Context as _, anyhow};
use derive_more::{Deref, DerefMut};
@@ -154,6 +154,26 @@ impl AsyncApp {
Ok(lock.update(f))
}
/// Arrange for the given callback to be invoked whenever the given entity emits an event of a given type.
/// The callback is provided a handle to the emitting entity and a reference to the emitted event.
pub fn subscribe<T, Event>(
&mut self,
entity: &Entity<T>,
mut on_event: impl FnMut(Entity<T>, &Event, &mut App) + 'static,
) -> Result<Subscription>
where
T: 'static + EventEmitter<Event>,
Event: 'static,
{
let app = self
.app
.upgrade()
.ok_or_else(|| anyhow!("app was released"))?;
let mut lock = app.borrow_mut();
let subscription = lock.subscribe(entity, on_event);
Ok(subscription)
}
/// Open a window with the given options based on the root view returned by the given function.
pub fn open_window<V>(
&self,

View File

@@ -53,7 +53,7 @@ impl IndexedDocsProvider for LocalRustdocProvider {
}
fn database_path(&self) -> PathBuf {
paths::support_dir().join("docs/rust/rustdoc-db.1.mdb")
paths::data_dir().join("docs/rust/rustdoc-db.1.mdb")
}
async fn suggest_packages(&self) -> Result<Vec<PackageName>> {
@@ -144,7 +144,7 @@ impl IndexedDocsProvider for DocsDotRsProvider {
}
fn database_path(&self) -> PathBuf {
paths::support_dir().join("docs/rust/docs-rs-db.1.mdb")
paths::data_dir().join("docs/rust/docs-rs-db.1.mdb")
}
async fn suggest_packages(&self) -> Result<Vec<PackageName>> {

View File

@@ -254,6 +254,7 @@ impl LanguageModel for CopilotChatLanguageModel {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err)).boxed(),
};
let is_streaming = copilot_request.stream;
let request_limiter = self.request_limiter.clone();
let future = cx.spawn(async move |cx| {
@@ -261,7 +262,10 @@ impl LanguageModel for CopilotChatLanguageModel {
request_limiter
.stream(async move {
let response = request.await?;
Ok(map_to_language_model_completion_events(response))
Ok(map_to_language_model_completion_events(
response,
is_streaming,
))
})
.await
});
@@ -271,6 +275,7 @@ impl LanguageModel for CopilotChatLanguageModel {
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
is_streaming: bool,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
#[derive(Default)]
struct RawToolCall {
@@ -289,7 +294,7 @@ pub fn map_to_language_model_completion_events(
events,
tool_calls_by_index: HashMap::default(),
},
|mut state| async move {
move |mut state| async move {
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
@@ -300,7 +305,13 @@ pub fn map_to_language_model_completion_events(
));
};
let Some(delta) = choice.delta.as_ref() else {
let delta = if is_streaming {
choice.delta.as_ref()
} else {
choice.message.as_ref()
};
let Some(delta) = delta else {
return Some((
vec![Err(anyhow!("Response contained no delta"))],
state,
@@ -312,26 +323,26 @@ pub fn map_to_language_model_completion_events(
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
for tool_call in &delta.tool_calls {
let entry = state
.tool_calls_by_index
.entry(tool_call.index)
.or_default();
for tool_call in &delta.tool_calls {
let entry = state
.tool_calls_by_index
.entry(tool_call.index)
.or_default();
if let Some(tool_id) = tool_call.id.clone() {
entry.id = tool_id;
if let Some(tool_id) = tool_call.id.clone() {
entry.id = tool_id;
}
if let Some(function) = tool_call.function.as_ref() {
if let Some(name) = function.name.clone() {
entry.name = name;
}
if let Some(function) = tool_call.function.as_ref() {
if let Some(name) = function.name.clone() {
entry.name = name;
}
if let Some(arguments) = function.arguments.clone() {
entry.arguments.push_str(&arguments);
}
if let Some(arguments) = function.arguments.clone() {
entry.arguments.push_str(&arguments);
}
}
}
match choice.finish_reason.as_deref() {
Some("stop") => {
@@ -361,7 +372,7 @@ pub fn map_to_language_model_completion_events(
)));
}
Some(stop_reason) => {
log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
events.push(Ok(LanguageModelCompletionEvent::Stop(
StopReason::EndTurn,
)));

View File

@@ -18,6 +18,7 @@ use gpui::{
TextStyleRefinement, actions, point, quad,
};
use language::{Language, LanguageRegistry, Rope};
use parser::CodeBlockMetadata;
use parser::{MarkdownEvent, MarkdownTag, MarkdownTagEnd, parse_links_only, parse_markdown};
use pulldown_cmark::Alignment;
use sum_tree::TreeMap;
@@ -88,12 +89,30 @@ struct Options {
}
pub enum CodeBlockRenderer {
Default { copy_button: bool },
Custom { render: CodeBlockRenderFn },
Default {
copy_button: bool,
},
Custom {
render: CodeBlockRenderFn,
/// A function that can modify the parent container after the code block
/// content has been appended as a child element.
transform: Option<CodeBlockTransformFn>,
},
}
pub type CodeBlockRenderFn =
Arc<dyn Fn(usize, &CodeBlockKind, &ParsedMarkdown, Range<usize>, &mut Window, &App) -> Div>;
pub type CodeBlockRenderFn = Arc<
dyn Fn(
&CodeBlockKind,
&ParsedMarkdown,
Range<usize>,
CodeBlockMetadata,
&mut Window,
&App,
) -> Div,
>;
pub type CodeBlockTransformFn =
Arc<dyn Fn(AnyDiv, Range<usize>, CodeBlockMetadata, &mut Window, &App) -> AnyDiv>;
actions!(markdown, [Copy, CopyAsMarkdown]);
@@ -594,7 +613,9 @@ impl Element for MarkdownElement {
0
};
for (index, (range, event)) in parsed_markdown.events.iter().enumerate() {
let mut current_code_block_metadata = None;
for (range, event) in parsed_markdown.events.iter() {
match event {
MarkdownEvent::Start(tag) => {
match tag {
@@ -632,7 +653,7 @@ impl Element for MarkdownElement {
markdown_end,
);
}
MarkdownTag::CodeBlock(kind) => {
MarkdownTag::CodeBlock { kind, metadata } => {
let language = match kind {
CodeBlockKind::Fenced => None,
CodeBlockKind::FencedLang(language) => {
@@ -645,6 +666,8 @@ impl Element for MarkdownElement {
_ => None,
};
current_code_block_metadata = Some(metadata.clone());
let is_indented = matches!(kind, CodeBlockKind::Indented);
match (&self.code_block_renderer, is_indented) {
@@ -676,12 +699,12 @@ impl Element for MarkdownElement {
builder.push_code_block(language);
builder.push_div(code_block, range, markdown_end);
}
(CodeBlockRenderer::Custom { render }, _) => {
(CodeBlockRenderer::Custom { render, .. }, _) => {
let parent_container = render(
index,
kind,
&parsed_markdown,
range.clone(),
metadata.clone(),
window,
cx,
);
@@ -695,9 +718,12 @@ impl Element for MarkdownElement {
if self.style.code_block_overflow_x_scroll {
code_block.style().restrict_scroll_to_axis =
Some(true);
code_block.flex().overflow_x_scroll()
code_block
.flex()
.overflow_x_scroll()
.overflow_y_hidden()
} else {
code_block.w_full()
code_block.w_full().overflow_hidden()
}
});
@@ -846,15 +872,37 @@ impl Element for MarkdownElement {
builder.pop_text_style();
}
let metadata = current_code_block_metadata.take();
if let CodeBlockRenderer::Custom {
transform: Some(transform),
..
} = &self.code_block_renderer
{
builder.modify_current_div(|el| {
transform(
el,
range.clone(),
metadata.clone().unwrap_or_default(),
window,
cx,
)
});
}
if matches!(
&self.code_block_renderer,
CodeBlockRenderer::Default { copy_button: true }
) {
builder.flush_text();
builder.modify_current_div(|el| {
let code =
without_fences(parsed_markdown.source()[range.clone()].trim())
.to_string();
let content_range = parser::extract_code_block_content_range(
parsed_markdown.source()[range.clone()].trim(),
);
let content_range = content_range.start + range.start
..content_range.end + range.start;
let code = parsed_markdown.source()[content_range].to_string();
let codeblock = render_copy_code_block_button(
range.end,
code,
@@ -1049,7 +1097,7 @@ impl IntoElement for MarkdownElement {
}
}
enum AnyDiv {
pub enum AnyDiv {
Div(Div),
Stateful(Stateful<Div>),
}
@@ -1493,43 +1541,3 @@ impl RenderedText {
.find(|link| link.source_range.contains(&source_index))
}
}
/// Some markdown blocks are indented, and others have e.g. ```rust … ``` around them.
/// If this block is fenced with backticks, strip them off (and the language name).
/// We use this when copying code blocks to the clipboard.
pub fn without_fences(mut markdown: &str) -> &str {
if let Some(opening_backticks) = markdown.find("```") {
markdown = &markdown[opening_backticks..];
// Trim off the next newline. This also trims off a language name if it's there.
if let Some(newline) = markdown.find('\n') {
markdown = &markdown[newline + 1..];
}
};
if let Some(closing_backticks) = markdown.rfind("```") {
markdown = &markdown[..closing_backticks];
};
markdown
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_without_fences() {
let input = "```rust\nlet x = 5;\n```";
assert_eq!(without_fences(input), "let x = 5;\n");
let input = " ```\nno language\n``` ";
assert_eq!(without_fences(input), "no language\n");
let input = "plain text";
assert_eq!(without_fences(input), "plain text");
let input = "```python\nprint('hello')\nprint('world')\n```";
assert_eq!(without_fences(input), "print('hello')\nprint('world')\n");
}
}

View File

@@ -65,11 +65,33 @@ pub fn parse_markdown(
within_metadata = true;
MarkdownTag::MetadataBlock(kind)
}
pulldown_cmark::Tag::CodeBlock(pulldown_cmark::CodeBlockKind::Indented) => {
MarkdownTag::CodeBlock {
kind: CodeBlockKind::Indented,
metadata: CodeBlockMetadata {
content_range: range.start + 1..range.end + 1,
line_count: 1,
},
}
}
pulldown_cmark::Tag::CodeBlock(pulldown_cmark::CodeBlockKind::Fenced(
ref info,
)) => {
let content_range = extract_code_block_content_range(&text[range.clone()]);
let content_range =
content_range.start + range.start..content_range.end + range.start;
let line_count = text[content_range.clone()]
.bytes()
.filter(|c| *c == b'\n')
.count();
let metadata = CodeBlockMetadata {
content_range,
line_count,
};
let info = info.trim();
MarkdownTag::CodeBlock(if info.is_empty() {
let kind = if info.is_empty() {
CodeBlockKind::Fenced
// Languages should never contain a slash, and PathRanges always should.
// (Models are told to specify them relative to a workspace root.)
@@ -81,9 +103,68 @@ pub fn parse_markdown(
let language = SharedString::from(info.to_string());
language_names.insert(language.clone());
CodeBlockKind::FencedLang(language)
})
};
MarkdownTag::CodeBlock { kind, metadata }
}
pulldown_cmark::Tag::Paragraph => MarkdownTag::Paragraph,
pulldown_cmark::Tag::Heading {
level,
id,
classes,
attrs,
} => {
let id = id.map(|id| SharedString::from(id.into_string()));
let classes = classes
.into_iter()
.map(|c| SharedString::from(c.into_string()))
.collect();
let attrs = attrs
.into_iter()
.map(|(key, value)| {
(
SharedString::from(key.into_string()),
value.map(|v| SharedString::from(v.into_string())),
)
})
.collect();
MarkdownTag::Heading {
level,
id,
classes,
attrs,
}
}
pulldown_cmark::Tag::BlockQuote(_kind) => MarkdownTag::BlockQuote,
pulldown_cmark::Tag::List(start_number) => MarkdownTag::List(start_number),
pulldown_cmark::Tag::Item => MarkdownTag::Item,
pulldown_cmark::Tag::FootnoteDefinition(label) => {
MarkdownTag::FootnoteDefinition(SharedString::from(label.to_string()))
}
pulldown_cmark::Tag::Table(alignments) => MarkdownTag::Table(alignments),
pulldown_cmark::Tag::TableHead => MarkdownTag::TableHead,
pulldown_cmark::Tag::TableRow => MarkdownTag::TableRow,
pulldown_cmark::Tag::TableCell => MarkdownTag::TableCell,
pulldown_cmark::Tag::Emphasis => MarkdownTag::Emphasis,
pulldown_cmark::Tag::Strong => MarkdownTag::Strong,
pulldown_cmark::Tag::Strikethrough => MarkdownTag::Strikethrough,
pulldown_cmark::Tag::Image {
link_type,
dest_url,
title,
id,
} => MarkdownTag::Image {
link_type,
dest_url: SharedString::from(dest_url.into_string()),
title: SharedString::from(title.into_string()),
id: SharedString::from(id.into_string()),
},
pulldown_cmark::Tag::HtmlBlock => MarkdownTag::HtmlBlock,
pulldown_cmark::Tag::DefinitionList => MarkdownTag::DefinitionList,
pulldown_cmark::Tag::DefinitionListTitle => MarkdownTag::DefinitionListTitle,
pulldown_cmark::Tag::DefinitionListDefinition => {
MarkdownTag::DefinitionListDefinition
}
tag => tag.into(),
};
events.push((range, MarkdownEvent::Start(tag)))
}
@@ -252,7 +333,10 @@ pub enum MarkdownTag {
BlockQuote,
/// A code block.
CodeBlock(CodeBlockKind),
CodeBlock {
kind: CodeBlockKind,
metadata: CodeBlockMetadata,
},
/// A HTML block.
HtmlBlock,
@@ -323,96 +407,26 @@ pub enum CodeBlockKind {
FencedSrc(PathWithRange),
}
impl From<pulldown_cmark::Tag<'_>> for MarkdownTag {
fn from(tag: pulldown_cmark::Tag) -> Self {
match tag {
pulldown_cmark::Tag::Paragraph => MarkdownTag::Paragraph,
pulldown_cmark::Tag::Heading {
level,
id,
classes,
attrs,
} => {
let id = id.map(|id| SharedString::from(id.into_string()));
let classes = classes
.into_iter()
.map(|c| SharedString::from(c.into_string()))
.collect();
let attrs = attrs
.into_iter()
.map(|(key, value)| {
(
SharedString::from(key.into_string()),
value.map(|v| SharedString::from(v.into_string())),
)
})
.collect();
MarkdownTag::Heading {
level,
id,
classes,
attrs,
}
}
pulldown_cmark::Tag::BlockQuote(_kind) => MarkdownTag::BlockQuote,
pulldown_cmark::Tag::CodeBlock(kind) => match kind {
pulldown_cmark::CodeBlockKind::Indented => {
MarkdownTag::CodeBlock(CodeBlockKind::Indented)
}
pulldown_cmark::CodeBlockKind::Fenced(info) => {
let info = info.trim();
MarkdownTag::CodeBlock(if info.is_empty() {
CodeBlockKind::Fenced
} else if info.contains('/') {
// Languages should never contain a slash, and PathRanges always should.
// (Models are told to specify them relative to a workspace root.)
CodeBlockKind::FencedSrc(PathWithRange::new(info))
} else {
CodeBlockKind::FencedLang(SharedString::from(info.to_string()))
})
}
},
pulldown_cmark::Tag::List(start_number) => MarkdownTag::List(start_number),
pulldown_cmark::Tag::Item => MarkdownTag::Item,
pulldown_cmark::Tag::FootnoteDefinition(label) => {
MarkdownTag::FootnoteDefinition(SharedString::from(label.to_string()))
}
pulldown_cmark::Tag::Table(alignments) => MarkdownTag::Table(alignments),
pulldown_cmark::Tag::TableHead => MarkdownTag::TableHead,
pulldown_cmark::Tag::TableRow => MarkdownTag::TableRow,
pulldown_cmark::Tag::TableCell => MarkdownTag::TableCell,
pulldown_cmark::Tag::Emphasis => MarkdownTag::Emphasis,
pulldown_cmark::Tag::Strong => MarkdownTag::Strong,
pulldown_cmark::Tag::Strikethrough => MarkdownTag::Strikethrough,
pulldown_cmark::Tag::Link {
link_type,
dest_url,
title,
id,
} => MarkdownTag::Link {
link_type,
dest_url: SharedString::from(dest_url.into_string()),
title: SharedString::from(title.into_string()),
id: SharedString::from(id.into_string()),
},
pulldown_cmark::Tag::Image {
link_type,
dest_url,
title,
id,
} => MarkdownTag::Image {
link_type,
dest_url: SharedString::from(dest_url.into_string()),
title: SharedString::from(title.into_string()),
id: SharedString::from(id.into_string()),
},
pulldown_cmark::Tag::HtmlBlock => MarkdownTag::HtmlBlock,
pulldown_cmark::Tag::MetadataBlock(kind) => MarkdownTag::MetadataBlock(kind),
pulldown_cmark::Tag::DefinitionList => MarkdownTag::DefinitionList,
pulldown_cmark::Tag::DefinitionListTitle => MarkdownTag::DefinitionListTitle,
pulldown_cmark::Tag::DefinitionListDefinition => MarkdownTag::DefinitionListDefinition,
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CodeBlockMetadata {
pub content_range: Range<usize>,
pub line_count: usize,
}
pub(crate) fn extract_code_block_content_range(text: &str) -> Range<usize> {
let mut range = 0..text.len();
if text.starts_with("```") {
range.start += 3;
if let Some(newline_ix) = text[range.clone()].find('\n') {
range.start += newline_ix + 1;
}
}
if !range.is_empty() && text.ends_with("```") {
range.end -= 3;
}
range
}
/// Represents either an owned or inline string. Motivation for this is to make `SubstitutedText`
@@ -570,4 +584,41 @@ mod tests {
)
)
}
#[test]
fn test_code_block_metadata() {
assert_eq!(
parse_markdown("```rust\nfn main() {\n let a = 1;\n}\n```"),
(
vec![
(
0..37,
Start(CodeBlock {
kind: CodeBlockKind::FencedLang("rust".into()),
metadata: CodeBlockMetadata {
content_range: 8..34,
line_count: 3
}
})
),
(8..34, Text),
(0..37, End(MarkdownTagEnd::CodeBlock)),
],
HashSet::from(["rust".into()]),
HashSet::new()
)
)
}
#[test]
fn test_extract_code_block_content_range() {
let input = "```rust\nlet x = 5;\n```";
assert_eq!(extract_code_block_content_range(input), 8..19);
let input = "plain text";
assert_eq!(extract_code_block_content_range(input), 0..10);
let input = "```python\nprint('hello')\nprint('world')\n```";
assert_eq!(extract_code_block_content_range(input), 10..40);
}
}

View File

@@ -37,3 +37,9 @@ pub(crate) mod m_2025_03_29 {
pub(crate) use settings::SETTINGS_PATTERNS;
}
pub(crate) mod m_2025_04_15 {
mod settings;
pub(crate) use settings::SETTINGS_PATTERNS;
}

View File

@@ -0,0 +1,29 @@
use std::ops::Range;
use tree_sitter::{Query, QueryMatch};
use crate::MigrationPatterns;
use crate::patterns::SETTINGS_ASSISTANT_TOOLS_PATTERN;
pub const SETTINGS_PATTERNS: MigrationPatterns = &[(
SETTINGS_ASSISTANT_TOOLS_PATTERN,
replace_bash_with_terminal_in_profiles,
)];
fn replace_bash_with_terminal_in_profiles(
contents: &str,
mat: &QueryMatch,
query: &Query,
) -> Option<(Range<usize>, String)> {
let tool_name_capture_ix = query.capture_index_for_name("tool_name")?;
let tool_name_range = mat
.nodes_for_capture_index(tool_name_capture_ix)
.next()?
.byte_range();
let tool_name = contents.get(tool_name_range.clone())?;
if tool_name != "bash" {
return None;
}
Some((tool_name_range, "terminal".to_string()))
}

View File

@@ -120,6 +120,10 @@ pub fn migrate_settings(text: &str) -> Result<Option<String>> {
migrations::m_2025_03_29::SETTINGS_PATTERNS,
&SETTINGS_QUERY_2025_03_29,
),
(
migrations::m_2025_04_15::SETTINGS_PATTERNS,
&SETTINGS_QUERY_2025_04_15,
),
];
run_migrations(text, migrations)
}
@@ -190,6 +194,10 @@ define_query!(
SETTINGS_QUERY_2025_03_29,
migrations::m_2025_03_29::SETTINGS_PATTERNS
);
define_query!(
SETTINGS_QUERY_2025_04_15,
migrations::m_2025_04_15::SETTINGS_PATTERNS
);
// custom query
static EDIT_PREDICTION_SETTINGS_MIGRATION_QUERY: LazyLock<Query> = LazyLock::new(|| {
@@ -527,4 +535,103 @@ mod tests {
),
)
}
#[test]
fn test_replace_bash_with_terminal_in_profiles() {
assert_migrate_settings(
r#"
{
"assistant": {
"profiles": {
"custom": {
"name": "Custom",
"tools": {
"bash": true,
"diagnostics": true
}
}
}
}
}
"#,
Some(
r#"
{
"assistant": {
"profiles": {
"custom": {
"name": "Custom",
"tools": {
"terminal": true,
"diagnostics": true
}
}
}
}
}
"#,
),
)
}
#[test]
fn test_replace_bash_false_with_terminal_in_profiles() {
assert_migrate_settings(
r#"
{
"assistant": {
"profiles": {
"custom": {
"name": "Custom",
"tools": {
"bash": false,
"diagnostics": true
}
}
}
}
}
"#,
Some(
r#"
{
"assistant": {
"profiles": {
"custom": {
"name": "Custom",
"tools": {
"terminal": false,
"diagnostics": true
}
}
}
}
}
"#,
),
)
}
#[test]
fn test_no_bash_in_profiles() {
assert_migrate_settings(
r#"
{
"assistant": {
"profiles": {
"custom": {
"name": "Custom",
"tools": {
"diagnostics": true,
"path_search": true,
"read_file": true
}
}
}
}
}
"#,
None,
)
}
}

View File

@@ -7,5 +7,6 @@ pub(crate) use keymap::{
};
pub(crate) use settings::{
SETTINGS_LANGUAGES_PATTERN, SETTINGS_NESTED_KEY_VALUE_PATTERN, SETTINGS_ROOT_KEY_VALUE_PATTERN,
SETTINGS_ASSISTANT_TOOLS_PATTERN, SETTINGS_LANGUAGES_PATTERN,
SETTINGS_NESTED_KEY_VALUE_PATTERN, SETTINGS_ROOT_KEY_VALUE_PATTERN,
};

View File

@@ -39,3 +39,35 @@ pub const SETTINGS_LANGUAGES_PATTERN: &str = r#"(document
)
(#eq? @languages "languages")
)"#;
pub const SETTINGS_ASSISTANT_TOOLS_PATTERN: &str = r#"(document
(object
(pair
key: (string (string_content) @assistant)
value: (object
(pair
key: (string (string_content) @profiles)
value: (object
(pair
key: (_)
value: (object
(pair
key: (string (string_content) @tools_key)
value: (object
(pair
key: (string (string_content) @tool_name)
value: (_) @tool_value
)
)
)
)
)
)
)
)
)
)
(#eq? @assistant "assistant")
(#eq? @profiles "profiles")
(#eq? @tools_key "tools")
)"#;

View File

@@ -312,7 +312,7 @@ impl ManagedNodeRuntime {
let version = Self::VERSION;
let folder_name = format!("node-{version}-{os}-{arch}");
let node_containing_dir = paths::support_dir().join("node");
let node_containing_dir = paths::data_dir().join("node");
let node_dir = node_containing_dir.join(folder_name);
let node_binary = node_dir.join(Self::NODE_PATH);
let npm_file = node_dir.join(Self::NPM_PATH);
@@ -498,7 +498,7 @@ impl SystemNodeRuntime {
)
}
let scratch_dir = paths::support_dir().join("node");
let scratch_dir = paths::data_dir().join("node");
fs::create_dir(&scratch_dir).await.ok();
fs::create_dir(scratch_dir.join("cache")).await.ok();

View File

@@ -5,61 +5,109 @@ use std::sync::OnceLock;
pub use util::paths::home_dir;
/// A default editorconfig file name to use when resolving project settings.
pub const EDITORCONFIG_NAME: &str = ".editorconfig";
/// A custom data directory override, set only by `set_custom_data_dir`.
/// This is used to override the default data directory location.
/// The directory will be created if it doesn't exist when set.
static CUSTOM_DATA_DIR: OnceLock<PathBuf> = OnceLock::new();
/// The resolved data directory, combining custom override or platform defaults.
/// This is set once and cached for subsequent calls.
/// On macOS, this is `~/Library/Application Support/Zed`.
/// On Linux/FreeBSD, this is `$XDG_DATA_HOME/zed`.
/// On Windows, this is `%LOCALAPPDATA%\Zed`.
static CURRENT_DATA_DIR: OnceLock<PathBuf> = OnceLock::new();
/// The resolved config directory, combining custom override or platform defaults.
/// This is set once and cached for subsequent calls.
/// On macOS, this is `~/.config/zed`.
/// On Linux/FreeBSD, this is `$XDG_CONFIG_HOME/zed`.
/// On Windows, this is `%APPDATA%\Zed`.
static CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
/// Returns the relative path to the zed_server directory on the ssh host.
pub fn remote_server_dir_relative() -> &'static Path {
Path::new(".zed_server")
}
/// Sets a custom directory for all user data, overriding the default data directory.
/// This function must be called before any other path operations that depend on the data directory.
/// The directory will be created if it doesn't exist.
///
/// # Arguments
///
/// * `dir` - The path to use as the custom data directory. This will be used as the base
/// directory for all user data, including databases, extensions, and logs.
///
/// # Returns
///
/// A reference to the static `PathBuf` containing the custom data directory path.
///
/// # Panics
///
/// Panics if:
/// * Called after the data directory has been initialized (e.g., via `data_dir` or `config_dir`)
/// * The directory cannot be created
pub fn set_custom_data_dir(dir: &str) -> &'static PathBuf {
if CURRENT_DATA_DIR.get().is_some() || CONFIG_DIR.get().is_some() {
panic!("set_custom_data_dir called after data_dir or config_dir was initialized");
}
CUSTOM_DATA_DIR.get_or_init(|| {
let path = PathBuf::from(dir);
std::fs::create_dir_all(&path).expect("failed to create custom data directory");
path
})
}
/// Returns the path to the configuration directory used by Zed.
pub fn config_dir() -> &'static PathBuf {
static CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
CONFIG_DIR.get_or_init(|| {
if cfg!(target_os = "windows") {
return dirs::config_dir()
if let Some(custom_dir) = CUSTOM_DATA_DIR.get() {
custom_dir.join("config")
} else if cfg!(target_os = "windows") {
dirs::config_dir()
.expect("failed to determine RoamingAppData directory")
.join("Zed");
}
if cfg!(any(target_os = "linux", target_os = "freebsd")) {
return if let Ok(flatpak_xdg_config) = std::env::var("FLATPAK_XDG_CONFIG_HOME") {
.join("Zed")
} else if cfg!(any(target_os = "linux", target_os = "freebsd")) {
if let Ok(flatpak_xdg_config) = std::env::var("FLATPAK_XDG_CONFIG_HOME") {
flatpak_xdg_config.into()
} else {
dirs::config_dir().expect("failed to determine XDG_CONFIG_HOME directory")
dirs::config_dir()
.expect("failed to determine XDG_CONFIG_HOME directory")
.join("zed")
}
.join("zed");
} else {
home_dir().join(".config").join("zed")
}
home_dir().join(".config").join("zed")
})
}
/// Returns the path to the support directory used by Zed.
pub fn support_dir() -> &'static PathBuf {
static SUPPORT_DIR: OnceLock<PathBuf> = OnceLock::new();
SUPPORT_DIR.get_or_init(|| {
if cfg!(target_os = "macos") {
return home_dir().join("Library/Application Support/Zed");
}
if cfg!(any(target_os = "linux", target_os = "freebsd")) {
return if let Ok(flatpak_xdg_data) = std::env::var("FLATPAK_XDG_DATA_HOME") {
/// Returns the path to the data directory used by Zed.
pub fn data_dir() -> &'static PathBuf {
CURRENT_DATA_DIR.get_or_init(|| {
if let Some(custom_dir) = CUSTOM_DATA_DIR.get() {
custom_dir.clone()
} else if cfg!(target_os = "macos") {
home_dir().join("Library/Application Support/Zed")
} else if cfg!(any(target_os = "linux", target_os = "freebsd")) {
if let Ok(flatpak_xdg_data) = std::env::var("FLATPAK_XDG_DATA_HOME") {
flatpak_xdg_data.into()
} else {
dirs::data_local_dir().expect("failed to determine XDG_DATA_HOME directory")
dirs::data_local_dir()
.expect("failed to determine XDG_DATA_HOME directory")
.join("zed")
}
.join("zed");
}
if cfg!(target_os = "windows") {
return dirs::data_local_dir()
} else if cfg!(target_os = "windows") {
dirs::data_local_dir()
.expect("failed to determine LocalAppData directory")
.join("Zed");
.join("Zed")
} else {
config_dir().clone() // Fallback
}
config_dir().clone()
})
}
/// Returns the path to the temp directory used by Zed.
pub fn temp_dir() -> &'static PathBuf {
static TEMP_DIR: OnceLock<PathBuf> = OnceLock::new();
@@ -96,7 +144,7 @@ pub fn logs_dir() -> &'static PathBuf {
if cfg!(target_os = "macos") {
home_dir().join("Library/Logs/Zed")
} else {
support_dir().join("logs")
data_dir().join("logs")
}
})
}
@@ -104,7 +152,7 @@ pub fn logs_dir() -> &'static PathBuf {
/// Returns the path to the Zed server directory on this SSH host.
pub fn remote_server_state_dir() -> &'static PathBuf {
static REMOTE_SERVER_STATE: OnceLock<PathBuf> = OnceLock::new();
REMOTE_SERVER_STATE.get_or_init(|| support_dir().join("server_state"))
REMOTE_SERVER_STATE.get_or_init(|| data_dir().join("server_state"))
}
/// Returns the path to the `Zed.log` file.
@@ -122,7 +170,7 @@ pub fn old_log_file() -> &'static PathBuf {
/// Returns the path to the database directory.
pub fn database_dir() -> &'static PathBuf {
static DATABASE_DIR: OnceLock<PathBuf> = OnceLock::new();
DATABASE_DIR.get_or_init(|| support_dir().join("db"))
DATABASE_DIR.get_or_init(|| data_dir().join("db"))
}
/// Returns the path to the crashes directory, if it exists for the current platform.
@@ -180,7 +228,7 @@ pub fn debug_tasks_file() -> &'static PathBuf {
/// This is where installed extensions are stored.
pub fn extensions_dir() -> &'static PathBuf {
static EXTENSIONS_DIR: OnceLock<PathBuf> = OnceLock::new();
EXTENSIONS_DIR.get_or_init(|| support_dir().join("extensions"))
EXTENSIONS_DIR.get_or_init(|| data_dir().join("extensions"))
}
/// Returns the path to the extensions directory.
@@ -188,7 +236,7 @@ pub fn extensions_dir() -> &'static PathBuf {
/// This is where installed extensions are stored on a remote.
pub fn remote_extensions_dir() -> &'static PathBuf {
static EXTENSIONS_DIR: OnceLock<PathBuf> = OnceLock::new();
EXTENSIONS_DIR.get_or_init(|| support_dir().join("remote_extensions"))
EXTENSIONS_DIR.get_or_init(|| data_dir().join("remote_extensions"))
}
/// Returns the path to the extensions directory.
@@ -222,7 +270,7 @@ pub fn contexts_dir() -> &'static PathBuf {
if cfg!(target_os = "macos") {
config_dir().join("conversations")
} else {
support_dir().join("conversations")
data_dir().join("conversations")
}
})
}
@@ -236,7 +284,7 @@ pub fn prompts_dir() -> &'static PathBuf {
if cfg!(target_os = "macos") {
config_dir().join("prompts")
} else {
support_dir().join("prompts")
data_dir().join("prompts")
}
})
}
@@ -262,7 +310,7 @@ pub fn prompt_overrides_dir(repo_path: Option<&Path>) -> PathBuf {
if cfg!(target_os = "macos") {
config_dir().join("prompt_overrides")
} else {
support_dir().join("prompt_overrides")
data_dir().join("prompt_overrides")
}
})
.clone()
@@ -277,7 +325,7 @@ pub fn embeddings_dir() -> &'static PathBuf {
if cfg!(target_os = "macos") {
config_dir().join("embeddings")
} else {
support_dir().join("embeddings")
data_dir().join("embeddings")
}
})
}
@@ -287,7 +335,7 @@ pub fn embeddings_dir() -> &'static PathBuf {
/// This is where language servers are downloaded to for languages built-in to Zed.
pub fn languages_dir() -> &'static PathBuf {
static LANGUAGES_DIR: OnceLock<PathBuf> = OnceLock::new();
LANGUAGES_DIR.get_or_init(|| support_dir().join("languages"))
LANGUAGES_DIR.get_or_init(|| data_dir().join("languages"))
}
/// Returns the path to the debug adapters directory
@@ -295,31 +343,31 @@ pub fn languages_dir() -> &'static PathBuf {
/// This is where debug adapters are downloaded to for DAPs that are built-in to Zed.
pub fn debug_adapters_dir() -> &'static PathBuf {
static DEBUG_ADAPTERS_DIR: OnceLock<PathBuf> = OnceLock::new();
DEBUG_ADAPTERS_DIR.get_or_init(|| support_dir().join("debug_adapters"))
DEBUG_ADAPTERS_DIR.get_or_init(|| data_dir().join("debug_adapters"))
}
/// Returns the path to the Copilot directory.
pub fn copilot_dir() -> &'static PathBuf {
static COPILOT_DIR: OnceLock<PathBuf> = OnceLock::new();
COPILOT_DIR.get_or_init(|| support_dir().join("copilot"))
COPILOT_DIR.get_or_init(|| data_dir().join("copilot"))
}
/// Returns the path to the Supermaven directory.
pub fn supermaven_dir() -> &'static PathBuf {
static SUPERMAVEN_DIR: OnceLock<PathBuf> = OnceLock::new();
SUPERMAVEN_DIR.get_or_init(|| support_dir().join("supermaven"))
SUPERMAVEN_DIR.get_or_init(|| data_dir().join("supermaven"))
}
/// Returns the path to the default Prettier directory.
pub fn default_prettier_dir() -> &'static PathBuf {
static DEFAULT_PRETTIER_DIR: OnceLock<PathBuf> = OnceLock::new();
DEFAULT_PRETTIER_DIR.get_or_init(|| support_dir().join("prettier"))
DEFAULT_PRETTIER_DIR.get_or_init(|| data_dir().join("prettier"))
}
/// Returns the path to the remote server binaries directory.
pub fn remote_servers_dir() -> &'static PathBuf {
static REMOTE_SERVERS_DIR: OnceLock<PathBuf> = OnceLock::new();
REMOTE_SERVERS_DIR.get_or_init(|| support_dir().join("remote_servers"))
REMOTE_SERVERS_DIR.get_or_init(|| data_dir().join("remote_servers"))
}
/// Returns the relative path to a `.zed` folder within a project.
@@ -359,6 +407,3 @@ pub fn local_debug_file_relative_path() -> &'static Path {
pub fn local_vscode_launch_file_relative_path() -> &'static Path {
Path::new(".vscode/launch.json")
}
/// A default editorconfig file name to use when resolving project settings.
pub const EDITORCONFIG_NAME: &str = ".editorconfig";

View File

@@ -68,7 +68,7 @@ impl ProjectEnvironment {
}
if let Some(cli_environment) = self.get_cli_environment() {
log::info!("using project environment variables from CLI");
log::debug!("using project environment variables from CLI");
return Task::ready(Some(cli_environment)).shared();
}
@@ -94,7 +94,7 @@ impl ProjectEnvironment {
}
if let Some(cli_environment) = self.get_cli_environment() {
log::info!("using project environment variables from CLI");
log::debug!("using project environment variables from CLI");
return Task::ready(Some(cli_environment)).shared();
}
@@ -128,7 +128,7 @@ impl ProjectEnvironment {
}
if let Some(cli_environment) = self.get_cli_environment() {
log::info!("using project environment variables from CLI");
log::debug!("using project environment variables from CLI");
return Task::ready(Some(cli_environment)).shared();
}

View File

@@ -2630,9 +2630,7 @@ impl RepositorySnapshot {
}
pub fn has_conflict(&self, repo_path: &RepoPath) -> bool {
self.statuses_by_path
.get(&PathKey(repo_path.0.clone()), &())
.map_or(false, |entry| entry.status.is_conflicted())
self.merge_conflicts.contains(repo_path)
}
/// This is the name that will be displayed in the repository selector for this repository.

View File

@@ -334,6 +334,10 @@ impl ProjectPath {
path: Path::new("").into(),
}
}
pub fn starts_with(&self, other: &ProjectPath) -> bool {
self.worktree_id == other.worktree_id && self.path.starts_with(&other.path)
}
}
#[derive(Debug, Default)]

View File

@@ -14,35 +14,41 @@ use std::{
time::Duration,
};
use text::LineEnding;
use util::ResultExt;
use util::{ResultExt, get_system_shell};
#[derive(Serialize)]
pub struct AssistantSystemPromptContext {
pub worktrees: Vec<WorktreeInfoForSystemPrompt>,
#[derive(Debug, Clone, Serialize)]
pub struct ProjectContext {
pub worktrees: Vec<WorktreeContext>,
pub has_rules: bool,
pub os: String,
pub arch: String,
pub shell: String,
}
impl AssistantSystemPromptContext {
pub fn new(worktrees: Vec<WorktreeInfoForSystemPrompt>) -> Self {
impl ProjectContext {
pub fn new(worktrees: Vec<WorktreeContext>) -> Self {
let has_rules = worktrees
.iter()
.any(|worktree| worktree.rules_file.is_some());
Self {
worktrees,
has_rules,
os: std::env::consts::OS.to_string(),
arch: std::env::consts::ARCH.to_string(),
shell: get_system_shell(),
}
}
}
#[derive(Serialize)]
pub struct WorktreeInfoForSystemPrompt {
#[derive(Debug, Clone, Serialize)]
pub struct WorktreeContext {
pub root_name: String,
pub abs_path: Arc<Path>,
pub rules_file: Option<SystemPromptRulesFile>,
pub rules_file: Option<RulesFileContext>,
}
#[derive(Serialize)]
pub struct SystemPromptRulesFile {
#[derive(Debug, Clone, Serialize)]
pub struct RulesFileContext {
pub path_in_worktree: Arc<Path>,
pub abs_path: Arc<Path>,
pub text: String,
@@ -254,7 +260,7 @@ impl PromptBuilder {
pub fn generate_assistant_system_prompt(
&self,
context: &AssistantSystemPromptContext,
context: &ProjectContext,
) -> Result<String, RenderError> {
self.handlebars
.lock()

View File

@@ -467,7 +467,7 @@ impl ShellBuilder {
// `alacritty_terminal` uses this as default on Windows. See:
// https://github.com/alacritty/alacritty/blob/0d4ab7bca43213d96ddfe40048fc0f922543c6f8/alacritty_terminal/src/tty/windows/mod.rs#L130
// We could use `util::retrieve_system_shell()` here, but we are running tasks here, so leave it to `powershell.exe`
// We could use `util::get_windows_system_shell()` here, but we are running tasks here, so leave it to `powershell.exe`
// should be okay.
fn system_shell() -> String {
"powershell.exe".to_string()

View File

@@ -380,7 +380,7 @@ impl TerminalBuilder {
#[cfg(target_os = "windows")]
{
Some(alacritty_terminal::tty::Shell::new(
util::retrieve_system_shell(),
util::get_windows_system_shell(),
Vec::new(),
))
}

View File

@@ -22,6 +22,7 @@ mod notification;
mod numeric_stepper;
mod popover;
mod popover_menu;
mod progress;
mod radio;
mod right_click_menu;
mod scrollbar;
@@ -61,6 +62,7 @@ pub use notification::*;
pub use numeric_stepper::*;
pub use popover::*;
pub use popover_menu::*;
pub use progress::*;
pub use radio::*;
pub use right_click_menu::*;
pub use scrollbar::*;

View File

@@ -267,7 +267,7 @@ impl RenderOnce for IconWithIndicator {
impl Component for Icon {
fn scope() -> ComponentScope {
ComponentScope::None
ComponentScope::Images
}
fn description() -> Option<&'static str> {

View File

@@ -26,7 +26,7 @@ impl RenderOnce for DecoratedIcon {
impl Component for DecoratedIcon {
fn scope() -> ComponentScope {
ComponentScope::None
ComponentScope::Images
}
fn description() -> Option<&'static str> {

View File

@@ -199,7 +199,7 @@ impl RenderOnce for Label {
impl Component for Label {
fn scope() -> ComponentScope {
ComponentScope::None
ComponentScope::Typography
}
fn description() -> Option<&'static str> {

View File

@@ -0,0 +1,2 @@
mod progress_bar;
pub use progress_bar::*;

View File

@@ -0,0 +1,159 @@
use documented::Documented;
use gpui::{Hsla, point};
use crate::components::Label;
use crate::prelude::*;
/// A progress bar is a horizontal bar that communicates the status of a process.
///
/// A progress bar should not be used to represent indeterminate progress.
#[derive(RegisterComponent, Documented)]
pub struct ProgressBar {
id: ElementId,
value: f32,
max_value: f32,
bg_color: Hsla,
fg_color: Hsla,
}
impl ProgressBar {
/// Create a new progress bar with the given value and maximum value.
pub fn new(
id: impl Into<ElementId>,
value: f32,
max_value: f32,
cx: &mut Context<Self>,
) -> Self {
Self {
id: id.into(),
value,
max_value,
bg_color: cx.theme().colors().background,
fg_color: cx.theme().status().info,
}
}
/// Set the current value of the progress bar.
pub fn value(&mut self, value: f32) -> &mut Self {
self.value = value;
self
}
/// Set the maximum value of the progress bar.
pub fn max_value(&mut self, max_value: f32) -> &mut Self {
self.max_value = max_value;
self
}
/// Set the background color of the progress bar.
pub fn bg_color(&mut self, color: Hsla) -> &mut Self {
self.bg_color = color;
self
}
/// Set the foreground color of the progress bar.
pub fn fg_color(&mut self, color: Hsla) -> &mut Self {
self.fg_color = color;
self
}
}
impl Render for ProgressBar {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
let fill_width = (self.value / self.max_value).clamp(0.02, 1.0);
div()
.id(self.id.clone())
.w_full()
.h(px(8.0))
.rounded_full()
.py(px(2.0))
.px(px(4.0))
.bg(self.bg_color)
.shadow(smallvec::smallvec![gpui::BoxShadow {
color: gpui::black().opacity(0.08),
offset: point(px(0.), px(1.)),
blur_radius: px(0.),
spread_radius: px(0.),
}])
.child(
div()
.h_full()
.rounded_full()
.bg(self.fg_color)
.w(relative(fill_width)),
)
}
}
impl Component for ProgressBar {
fn scope() -> ComponentScope {
ComponentScope::Status
}
fn description() -> Option<&'static str> {
Some(Self::DOCS)
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let max_value = 180.0;
let empty_progress_bar = cx.new(|cx| ProgressBar::new("empty", 0.0, max_value, cx));
let partial_progress_bar =
cx.new(|cx| ProgressBar::new("partial", max_value * 0.35, max_value, cx));
let filled_progress_bar = cx.new(|cx| ProgressBar::new("filled", max_value, max_value, cx));
Some(
div()
.flex()
.flex_col()
.gap_4()
.p_4()
.w(px(240.0))
.child(div().child("Progress Bar"))
.child(
div()
.flex()
.flex_col()
.gap_2()
.child(
div()
.flex()
.justify_between()
.child(Label::new("0%"))
.child(Label::new("Empty")),
)
.child(empty_progress_bar.clone()),
)
.child(
div()
.flex()
.flex_col()
.gap_2()
.child(
div()
.flex()
.justify_between()
.child(Label::new("38%"))
.child(Label::new("Partial")),
)
.child(partial_progress_bar.clone()),
)
.child(
div()
.flex()
.flex_col()
.gap_2()
.child(
div()
.flex()
.justify_between()
.child(Label::new("100%"))
.child(Label::new("Complete")),
)
.child(filled_progress_bar.clone()),
)
.into_any_element(),
)
}
}

View File

@@ -235,7 +235,7 @@ impl Headline {
impl Component for Headline {
fn scope() -> ComponentScope {
ComponentScope::None
ComponentScope::Typography
}
fn description() -> Option<&'static str> {

View File

@@ -477,7 +477,7 @@ pub fn iterate_expanded_and_wrapped_usize_range(
}
#[cfg(target_os = "windows")]
pub fn retrieve_system_shell() -> String {
pub fn get_windows_system_shell() -> String {
use std::path::PathBuf;
fn find_pwsh_in_programfiles(find_alternate: bool, find_preview: bool) -> Option<PathBuf> {
@@ -994,6 +994,18 @@ pub fn default<D: Default>() -> D {
Default::default()
}
pub fn get_system_shell() -> String {
#[cfg(target_os = "windows")]
{
get_windows_system_shell()
}
#[cfg(not(target_os = "windows"))]
{
std::env::var("SHELL").unwrap_or("/bin/sh".to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1172,6 +1172,31 @@ impl Worktree {
pub fn is_single_file(&self) -> bool {
self.root_dir().is_none()
}
/// For visible worktrees, returns the path with the worktree name as the first component.
/// Otherwise, returns an absolute path.
pub fn full_path(&self, worktree_relative_path: &Path) -> PathBuf {
let mut full_path = PathBuf::new();
if self.is_visible() {
full_path.push(self.root_name());
} else {
let path = self.abs_path();
if self.is_local() && path.starts_with(home_dir().as_path()) {
full_path.push("~");
full_path.push(path.strip_prefix(home_dir().as_path()).unwrap());
} else {
full_path.push(path)
}
}
if worktree_relative_path.components().next().is_some() {
full_path.push(&worktree_relative_path);
}
full_path
}
}
impl LocalWorktree {
@@ -3229,27 +3254,7 @@ impl language::File for File {
}
fn full_path(&self, cx: &App) -> PathBuf {
let mut full_path = PathBuf::new();
let worktree = self.worktree.read(cx);
if worktree.is_visible() {
full_path.push(worktree.root_name());
} else {
let path = worktree.abs_path();
if worktree.is_local() && path.starts_with(home_dir().as_path()) {
full_path.push("~");
full_path.push(path.strip_prefix(home_dir().as_path()).unwrap());
} else {
full_path.push(path)
}
}
if self.path.components().next().is_some() {
full_path.push(&self.path);
}
full_path
self.worktree.read(cx).full_path(&self.path)
}
/// Returns the last component of this handle's absolute path. If this handle refers to the root

View File

@@ -2,7 +2,7 @@
description = "The fast, collaborative code editor."
edition.workspace = true
name = "zed"
version = "0.182.0"
version = "0.182.3"
publish.workspace = true
license = "GPL-3.0-or-later"
authors = ["Zed Team <hi@zed.dev>"]

View File

@@ -1 +1 @@
dev
preview

View File

@@ -172,6 +172,11 @@ fn fail_to_open_window(e: anyhow::Error, _cx: &mut App) {
fn main() {
let args = Args::parse();
// Set custom data directory.
if let Some(dir) = &args.user_data_dir {
paths::set_custom_data_dir(dir);
}
#[cfg(all(not(debug_assertions), target_os = "windows"))]
unsafe {
use windows::Win32::System::Console::{ATTACH_PARENT_PROCESS, AttachConsole};
@@ -962,6 +967,14 @@ struct Args {
/// URLs can either be `file://` or `zed://` scheme, or relative to <https://zed.dev>.
paths_or_urls: Vec<String>,
/// Sets a custom directory for all user data (e.g., database, extensions, logs).
/// This overrides the default platform-specific data directory location.
/// On macOS, the default is `~/Library/Application Support/Zed`.
/// On Linux/FreeBSD, the default is `$XDG_DATA_HOME/zed`.
/// On Windows, the default is `%LOCALAPPDATA%\Zed`.
#[arg(long, value_name = "DIR")]
user_data_dir: Option<String>,
/// Instructs zed to run as a dev server on this machine. (not implemented)
#[arg(long)]
dev_server_token: Option<String>,

View File

@@ -151,7 +151,7 @@ pub fn listen_for_cli_connections(opener: OpenListener) -> Result<()> {
use release_channel::RELEASE_CHANNEL_NAME;
use std::os::unix::net::UnixDatagram;
let sock_path = paths::support_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL_NAME));
let sock_path = paths::data_dir().join(format!("zed-{}.sock", *RELEASE_CHANNEL_NAME));
// remove the socket if the process listening on it has died
if let Err(e) = UnixDatagram::unbound()?.connect(&sock_path) {
if e.kind() == std::io::ErrorKind::ConnectionRefused {
@@ -261,6 +261,7 @@ pub async fn handle_cli_connection(
wait,
open_new_workspace,
env,
user_data_dir: _, // Ignore user_data_dir
} => {
if !urls.is_empty() {
cx.update(|cx| {

View File

@@ -130,6 +130,7 @@ fn send_args_to_instance(args: &Args) -> anyhow::Result<()> {
wait: false,
open_new_workspace: None,
env: None,
user_data_dir: args.user_data_dir.clone(),
}
};

View File

@@ -498,10 +498,12 @@ impl RateCompletionModal {
cx
))
.on_click(cx.listener(move |this, _, window, cx| {
this.thumbs_down_active(
&ThumbsDownActiveCompletion,
window, cx,
);
if this.active_completion.is_some() {
this.thumbs_down_active(
&ThumbsDownActiveCompletion,
window, cx,
);
}
})),
)
.child(
@@ -517,7 +519,9 @@ impl RateCompletionModal {
cx
))
.on_click(cx.listener(move |this, _, window, cx| {
this.thumbs_up_active(&ThumbsUpActiveCompletion, window, cx);
if this.active_completion.is_some() {
this.thumbs_up_active(&ThumbsUpActiveCompletion, window, cx);
}
})),
),
),

View File

@@ -316,8 +316,8 @@ Where `some-provider` can be any of the following values: `anthropic`, `google`,
### Configuring Models {#default-model}
The default model can be set via the model dropdown in the assistant panel's top-right corner. Selecting a model saves it as the default.
You can also manually edit the `default_model` object in your settings:
Zed's hosted LLM service sets `claude-3-7-sonnet-latest` as the default model.
However, you can change it either via the model dropdown in the Assistant Panel's bottom-left corner or by manually editing the `default_model` object in your settings:
```json
{
@@ -325,7 +325,7 @@ You can also manually edit the `default_model` object in your settings:
"version": "2",
"default_model": {
"provider": "zed.dev",
"model": "claude-3-5-sonnet"
"model": "gpt-4o"
}
}
}

View File

@@ -2998,11 +2998,11 @@ Run the `theme selector: toggle` action in the command palette to see a current
"default_height": 320,
"default_model": {
"provider": "zed.dev",
"model": "claude-3-5-sonnet-latest"
"model": "claude-3-7-sonnet-latest"
},
"editor_model": {
"provider": "zed.dev",
"model": "claude-3-5-sonnet-latest"
"model": "claude-3-7-sonnet-latest"
}
}
```