Compare commits

..

7 Commits

Author SHA1 Message Date
Nate Butler
7f0ffa0109 Add some random static text to test the thread 2025-04-29 17:59:42 -04:00
Nate Butler
95d8409900 make it gooo 2025-04-29 15:41:27 -04:00
Nate Butler
d27fdd96f2 WIP 2025-04-29 15:19:51 -04:00
Nate Butler
c8e909afc6 wip 2025-04-29 14:56:39 -04:00
Nate Butler
abad6d9be9 Use the same property order for ThreadStore::new and load 2025-04-29 13:58:13 -04:00
Nate Butler
9c50d19841 wip 2025-04-29 13:06:19 -04:00
Nate Butler
17b98d068a Update usage banner scope 2025-04-29 11:53:23 -04:00
209 changed files with 2985 additions and 56746 deletions

View File

@@ -69,7 +69,7 @@ jobs:
run: cargo build --package=eval
- name: Run eval
run: cargo run --package=eval -- --repetitions=8 --concurrency=1
run: cargo run --package=eval -- --repetitions=3 --concurrency=1
# Even the Linux runner is not stateful, in theory there is no need to do this cleanup.
# But, to avoid potential issues in the future if we choose to use a stateful Linux runner and forget to add code

View File

@@ -170,6 +170,55 @@ jobs:
- name: Upload Zed Nightly
run: script/upload-nightly linux-targz
bundle-nix:
timeout-minutes: 60
name: (${{ matrix.system.os }}) Nix Build
continue-on-error: true
strategy:
fail-fast: false
matrix:
system:
- os: x86 Linux
runner: buildjet-16vcpu-ubuntu-2204
install_nix: true
- os: arm Mac
runner: [macOS, ARM64, test]
install_nix: false
if: github.repository_owner == 'zed-industries'
runs-on: ${{ matrix.system.runner }}
needs: tests
env:
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: ${{ secrets.ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON }}
GIT_LFS_SKIP_SMUDGE: 1 # breaks the livekit rust sdk examples which we don't actually depend on
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
# on our macs we manually install nix. for some reason the cachix action is running
# under a non-login /bin/bash shell which doesn't source the proper script to add the
# nix profile to PATH, so we manually add them here
- name: Set path
if: ${{ ! matrix.system.install_nix }}
run: |
echo "/nix/var/nix/profiles/default/bin" >> $GITHUB_PATH
echo "/Users/administrator/.nix-profile/bin" >> $GITHUB_PATH
- uses: cachix/install-nix-action@d1ca217b388ee87b2507a9a93bf01368bde7cec2 # v31
if: ${{ matrix.system.install_nix }}
with:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
- uses: cachix/cachix-action@0fc020193b5a1fa3ac4575aa3a7d3aa6a35435ad # v16
with:
name: zed-industries
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- run: nix build
- name: Limit /nix/store to 50GB
run: '[ $(du -sm /nix/store | cut -f1) -gt 50000 ] && nix-collect-garbage -d'
update-nightly-tag:
name: Update nightly tag
if: github.repository_owner == 'zed-industries'

View File

@@ -46,17 +46,5 @@
"formatter": "auto",
"remove_trailing_whitespace_on_save": true,
"ensure_final_newline_on_save": true,
"file_scan_exclusions": [
"crates/eval/worktrees/",
"crates/eval/repos/",
"**/.git",
"**/.svn",
"**/.hg",
"**/.jj",
"**/CVS",
"**/.DS_Store",
"**/Thumbs.db",
"**/.classpath",
"**/.settings"
]
"file_scan_exclusions": ["crates/eval/worktrees/", "crates/eval/repos/"]
}

36
Cargo.lock generated
View File

@@ -690,7 +690,6 @@ dependencies = [
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"serde",
"serde_json",
"settings",
@@ -704,9 +703,7 @@ dependencies = [
name = "assistant_tools"
version = "0.1.0"
dependencies = [
"aho-corasick",
"anyhow",
"assistant_settings",
"assistant_tool",
"buffer_diff",
"chrono",
@@ -714,40 +711,25 @@ dependencies = [
"clock",
"collections",
"component",
"derive_more",
"editor",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"handlebars 4.5.0",
"html_to_markdown",
"http_client",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"language_models",
"linkme",
"open",
"pretty_assertions",
"project",
"rand 0.8.5",
"regex",
"reqwest_client",
"rust-embed",
"schemars",
"serde",
"serde_json",
"settings",
"smallvec",
"streaming_diff",
"strsim",
"task",
"tempfile",
"terminal",
"terminal_view",
"tree-sitter-rust",
"ui",
"unindent",
@@ -3221,7 +3203,9 @@ dependencies = [
name = "component_preview"
version = "0.1.0"
dependencies = [
"agent",
"anyhow",
"assistant_tool",
"client",
"collections",
"component",
@@ -3231,9 +3215,11 @@ dependencies = [
"log",
"notifications",
"project",
"prompt_store",
"serde",
"ui",
"ui_input",
"util",
"workspace",
"workspace-hack",
]
@@ -4378,17 +4364,14 @@ name = "diagnostics"
version = "0.1.0"
dependencies = [
"anyhow",
"cargo_metadata",
"client",
"collections",
"component",
"ctor",
"editor",
"env_logger 0.11.8",
"futures 0.3.31",
"gpui",
"indoc",
"itertools 0.14.0",
"language",
"linkme",
"log",
@@ -4400,7 +4383,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
"smol",
"text",
"theme",
"ui",
@@ -5012,11 +4994,9 @@ dependencies = [
"language_model",
"language_models",
"languages",
"markdown",
"node_runtime",
"pathdiff",
"paths",
"pretty_assertions",
"project",
"prompt_store",
"regex",
@@ -6381,7 +6361,6 @@ dependencies = [
"log",
"pest",
"pest_derive",
"rust-embed",
"serde",
"serde_json",
"thiserror 1.0.69",
@@ -16899,22 +16878,18 @@ version = "0.1.0"
dependencies = [
"anyhow",
"client",
"component",
"db",
"documented",
"editor",
"fuzzy",
"gpui",
"install_cli",
"language",
"linkme",
"picker",
"project",
"schemars",
"serde",
"settings",
"telemetry",
"theme",
"ui",
"util",
"vim_mode_setting",
@@ -18047,7 +18022,6 @@ dependencies = [
"getrandom 0.2.15",
"getrandom 0.3.2",
"gimli",
"handlebars 4.5.0",
"hashbrown 0.14.5",
"hashbrown 0.15.2",
"heck 0.4.1",
@@ -18491,7 +18465,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.186.0"
version = "0.185.0"
dependencies = [
"activity_indicator",
"agent",

View File

@@ -435,7 +435,6 @@ dap-types = { git = "https://github.com/zed-industries/dap-types", rev = "be69a0
dashmap = "6.0"
derive_more = "0.99.17"
dirs = "4.0"
documented = "0.9.1"
dotenv = "0.15.0"
ec4rs = "1.1"
emojis = "0.6.1"
@@ -798,6 +797,5 @@ ignored = [
"serde",
"component",
"linkme",
"documented",
"workspace-hack",
]

View File

@@ -1,5 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 12H16" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M4 6H20" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M4 18H12" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 402 B

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-user-round-check-icon lucide-user-round-check"><path d="M2 21a8 8 0 0 1 13.292-6"/><circle cx="10" cy="8" r="5"/><path d="m16 19 2 2 4-4"/></svg>

Before

Width:  |  Height:  |  Size: 348 B

View File

@@ -1,14 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_2489_484)">
<path d="M11 8.9V11C8.51716 11 7.48284 11 5 11V10.4L11 5.6V5H5V7.1" stroke="black" stroke-width="1.5"/>
<path d="M1.5 5.5V1.5H5" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
<path d="M14.5 5.5V1.5H11" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
<path d="M1.5 10.5V14.5H5" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
<path d="M14.5 10.5V14.5H11" stroke="black" stroke-opacity="0.5" stroke-width="1.5"/>
</g>
<defs>
<clipPath id="clip0_2489_484">
<rect width="16" height="16" fill="white"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 687 B

View File

@@ -245,19 +245,11 @@
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "assistant::ToggleModelSelector",
"ctrl-shift-a": "agent::ToggleContextPicker",
"ctrl-shift-o": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"shift-escape": "agent::ExpandMessageEditor",
"ctrl-e": "agent::ChatMode",
"ctrl-alt-e": "agent::RemoveAllContext"
}
},
{
"context": "AgentPanel > NavigationMenu",
"bindings": {
"shift-backspace": "agent::DeleteRecentlyOpenThread"
}
},
{
"context": "AgentPanel > Markdown",
"bindings": {
@@ -528,7 +520,6 @@
"shift-new": "workspace::NewWindow",
"ctrl-shift-n": "workspace::NewWindow",
"ctrl-`": "terminal_panel::ToggleFocus",
"f10": ["app_menu::OpenApplicationMenu", "Zed"],
"alt-1": ["workspace::ActivatePane", 0],
"alt-2": ["workspace::ActivatePane", 1],
"alt-3": ["workspace::ActivatePane", 2],
@@ -587,7 +578,6 @@
{
"context": "ApplicationMenu",
"bindings": {
"f10": "menu::Cancel",
"left": "app_menu::ActivateMenuLeft",
"right": "app_menu::ActivateMenuRight"
}
@@ -962,12 +952,5 @@
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
"bindings": {
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
}
}
]

View File

@@ -290,19 +290,11 @@
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "assistant::ToggleModelSelector",
"cmd-shift-a": "agent::ToggleContextPicker",
"cmd-shift-o": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"shift-escape": "agent::ExpandMessageEditor",
"cmd-e": "agent::ChatMode",
"cmd-alt-e": "agent::RemoveAllContext"
}
},
{
"context": "AgentPanel > NavigationMenu",
"bindings": {
"shift-backspace": "agent::DeleteRecentlyOpenThread"
}
},
{
"context": "AgentPanel > Markdown",
"use_key_equivalents": true,
@@ -1068,12 +1060,5 @@
"bindings": {
"escape": "menu::Cancel"
}
},
{
"context": "Diagnostics",
"use_key_equivalents": true,
"bindings": {
"ctrl-r": "diagnostics::ToggleDiagnosticsRefresh"
}
}
]

View File

@@ -516,14 +516,12 @@
"'": "vim::Quotes",
"`": "vim::BackQuotes",
"\"": "vim::DoubleQuotes",
// "q": "vim::AnyQuotes",
"q": "vim::MiniQuotes",
"q": "vim::AnyQuotes",
"|": "vim::VerticalBars",
"(": "vim::Parentheses",
")": "vim::Parentheses",
"b": "vim::Parentheses",
// "b": "vim::AnyBrackets",
// "b": "vim::MiniBrackets",
"[": "vim::SquareBrackets",
"]": "vim::SquareBrackets",
"r": "vim::SquareBrackets",

View File

@@ -15,7 +15,6 @@ You are a highly skilled software engineer with extensive knowledge in many prog
3. DO NOT use tools to access items that are already available in the context section.
4. Use only the tools that are currently available.
5. DO NOT use a tool that is not available just because it appears in the conversation. This means the user turned it off.
6. NEVER run commands that don't terminate on their own such as web servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers.
## Searching and Reading
@@ -39,78 +38,18 @@ If appropriate, use tool calls to explore the current project, which contains th
## Code Block Formatting
Whenever you mention a code block, you MUST use ONLY use the following format:
Whenever you mention a code block, you MUST use ONLY use the following format when the code in the block comes from a file
in the project:
```path/to/Something.blah#L123-456
(code goes here)
```
The `#L123-456` means the line number range 123 through 456, and the path/to/Something.blah
is a path in the project. (If there is no valid path in the project, then you can use
/dev/null/path.extension for its path.) This is the ONLY valid way to format code blocks, because the Markdown parser
does not understand the more common ```language syntax, or bare ``` blocks. It only
understands this path-based syntax, and if the path is missing, then it will error and you will have to do it over again.
Just to be really clear about this, if you ever find yourself writing three backticks followed by a language name, STOP!
You have made a mistake. You can only ever put paths after triple backticks!
<example>
Based on all the information I've gathered, here's a summary of how this system works:
1. The README file is loaded into the system.
2. The system finds the first two headers, including everything in between. In this case, that would be:
```path/to/README.md#L8-12
# First Header
This is the info under the first header.
## Sub-header
```
3. Then the system finds the last header in the README:
```path/to/README.md#L27-29
## Last Header
This is the last header in the README.
```
4. Finally, it passes this information on to the next process.
</example>
<example>
In Markdown, hash marks signify headings. For example:
```/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</example>
Here are examples of ways you must never render code blocks:
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it does not include the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because it has the language instead of the path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
# Level 1 heading
## Level 2 heading
### Level 3 heading
</bad_example_do_not_do_this>
This example is unacceptable because it uses indentation to mark the code block
instead of backticks with a path.
<bad_example_do_not_do_this>
In Markdown, hash marks signify headings. For example:
```markdown
/dev/null/example.md#L1-3
# Level 1 heading
## Level 2 heading
### Level 3 heading
```
</bad_example_do_not_do_this>
This example is unacceptable because the path is in the wrong place. The path must be directly after the opening backticks.
is a path in the project. (If this code block does not come from a file in the project, then you may instead use
the normal markdown style of three backticks followed by language name. However, you MUST use this format if
the code in the block comes from a file in the project.)
## Fixing Diagnostics
1. Make 1-2 attempts at fixing diagnostics, then defer to the user.

View File

@@ -657,8 +657,6 @@
},
// When enabled, the agent can run potentially destructive actions without asking for your confirmation.
"always_allow_tool_actions": false,
// When enabled, the agent will stream edits.
"stream_edits": false,
"default_profile": "write",
"profiles": {
"ask": {
@@ -673,7 +671,6 @@
"now": true,
"find_path": true,
"read_file": true,
"open": true,
"grep": true,
"thinking": true,
"web_search": true
@@ -705,11 +702,6 @@
"thinking": true,
"web_search": true
}
},
"manual": {
"name": "Manual",
"enable_all_context_servers": false,
"tools": {}
}
},
// Where to show notifications when an agent has either completed
@@ -837,20 +829,7 @@
// "modal_max_width": "full"
//
// Default: small
"modal_max_width": "small",
// Determines whether the file finder should skip focus for the active file in search results.
// There are 2 possible values:
//
// 1. true: When searching for files, if the currently active file appears as the first result,
// auto-focus will skip it and focus the second result instead.
// "skip_focus_for_active_in_search": true
//
// 2. false: When searching for files, the first result will always receive focus,
// even if it's the currently active file.
// "skip_focus_for_active_in_search": false
//
// Default: true
"skip_focus_for_active_in_search": true
"modal_max_width": "small"
},
// Whether or not to remove any trailing whitespace from lines of a buffer
// before saving it.
@@ -933,24 +912,6 @@
// The minimum severity of the diagnostics to show inline.
// Shows all diagnostics when not specified.
"max_severity": null
},
"rust": {
// When enabled, Zed runs `cargo check --message-format=json`-based commands and
// collect cargo diagnostics instead of rust-analyzer.
"fetch_cargo_diagnostics": false,
// A command override for fetching the cargo diagnostics.
// First argument is the command, followed by the arguments.
"diagnostics_fetch_command": [
"cargo",
"check",
"--quiet",
"--workspace",
"--message-format=json",
"--all-targets",
"--keep-going"
],
// Extra environment variables to pass to the diagnostics fetch command.
"env": {}
}
},
// Files or globs of files that will be excluded by Zed entirely. They will be skipped during file

File diff suppressed because it is too large Load Diff

View File

@@ -973,8 +973,8 @@ mod tests {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
prompt_store,
Arc::new(PromptBuilder::new(None).unwrap()),
prompt_store,
cx,
)
})

View File

@@ -34,15 +34,15 @@ use prompt_store::PromptBuilder;
use schemars::JsonSchema;
use serde::Deserialize;
use settings::Settings as _;
use thread::ThreadId;
pub use thread::{MessageSegment, ThreadId};
pub use crate::active_thread::ActiveThread;
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::context::{ContextLoadResult, LoadedContext};
pub use crate::inline_assistant::InlineAssistant;
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
pub use crate::thread_store::ThreadStore;
pub use crate::thread::{Message, Thread, ThreadEvent};
pub use crate::thread_store::{SharedProjectContext, ThreadStore};
pub use agent_diff::{AgentDiff, AgentDiffToolbar};
actions!(
@@ -50,9 +50,6 @@ actions!(
[
NewTextThread,
ToggleContextPicker,
ToggleNavigationMenu,
ToggleOptionsMenu,
DeleteRecentlyOpenThread,
ToggleProfileSelector,
RemoveAllContext,
ExpandMessageEditor,

View File

@@ -1,5 +1,5 @@
use std::ops::Range;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -18,8 +18,8 @@ use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
Corner, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext,
Pixels, Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
Corner, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels,
Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
};
use language::LanguageRegistry;
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
@@ -41,16 +41,15 @@ use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
use crate::history_store::{HistoryEntry, HistoryStore, RecentEntry};
use crate::history_store::{HistoryEntry, HistoryStore};
use crate::message_editor::{MessageEditor, MessageEditorEvent};
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
use crate::thread_history::{PastContext, PastThread, ThreadHistory};
use crate::thread_store::ThreadStore;
use crate::ui::UsageBanner;
use crate::{
AddContextServer, AgentDiff, DeleteRecentlyOpenThread, ExpandMessageEditor, InlineAssistant,
NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent,
ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu,
AddContextServer, AgentDiff, ExpandMessageEditor, InlineAssistant, NewTextThread, NewThread,
OpenActiveThreadAsMarkdown, OpenAgentDiff, OpenHistory, ThreadEvent, ToggleContextPicker,
};
pub fn init(cx: &mut App) {
@@ -105,22 +104,6 @@ pub fn init(cx: &mut App) {
});
});
}
})
.register_action(|workspace, _: &ToggleNavigationMenu, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.toggle_navigation_menu(&ToggleNavigationMenu, window, cx);
});
}
})
.register_action(|workspace, _: &ToggleOptionsMenu, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.toggle_options_menu(&ToggleOptionsMenu, window, cx);
});
}
});
},
)
@@ -130,7 +113,6 @@ pub fn init(cx: &mut App) {
enum ActiveView {
Thread {
change_title_editor: Entity<Editor>,
thread: WeakEntity<Thread>,
_subscriptions: Vec<gpui::Subscription>,
},
PromptEditor {
@@ -148,7 +130,7 @@ impl ActiveView {
let editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_text(summary.clone(), window, cx);
editor.set_text(summary, window, cx);
editor
});
@@ -194,7 +176,6 @@ impl ActiveView {
Self::Thread {
change_title_editor: editor,
thread: thread.downgrade(),
_subscriptions: subscriptions,
}
}
@@ -298,8 +279,6 @@ pub struct AssistantPanel {
history_store: Entity<HistoryStore>,
history: Entity<ThreadHistory>,
assistant_dropdown_menu_handle: PopoverMenuHandle<ContextMenu>,
assistant_navigation_menu_handle: PopoverMenuHandle<ContextMenu>,
assistant_navigation_menu: Option<Entity<ContextMenu>>,
width: Option<Pixels>,
height: Option<Pixels>,
}
@@ -323,8 +302,8 @@ impl AssistantPanel {
ThreadStore::load(
project,
tools.clone(),
prompt_store.clone(),
prompt_builder.clone(),
prompt_store.clone(),
cx,
)
})?
@@ -401,15 +380,8 @@ impl AssistantPanel {
}
});
let thread_id = thread.read(cx).id().clone();
let history_store = cx.new(|cx| {
HistoryStore::new(
thread_store.clone(),
context_store.clone(),
[RecentEntry::Thread(thread_id, thread.clone())],
cx,
)
});
let history_store =
cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx));
cx.observe(&history_store, |_, _, cx| cx.notify()).detach();
@@ -420,11 +392,10 @@ impl AssistantPanel {
cx.notify();
}
});
let active_thread = cx.new(|cx| {
let thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
message_editor_context_store.clone(),
language_registry.clone(),
workspace.clone(),
window,
@@ -432,112 +403,10 @@ impl AssistantPanel {
)
});
let active_thread_subscription =
cx.subscribe(&active_thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
let weak_panel = weak_self.clone();
window.defer(cx, move |window, cx| {
let panel = weak_panel.clone();
let assistant_navigation_menu =
ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| {
let recently_opened = panel
.update(cx, |this, cx| {
this.history_store.update(cx, |history_store, cx| {
history_store.recently_opened_entries(cx)
})
})
.unwrap_or_default();
if !recently_opened.is_empty() {
menu = menu.header("Recently Opened");
for entry in recently_opened.iter() {
let summary = entry.summary(cx);
menu = menu.entry_with_end_slot_on_hover(
summary,
None,
{
let panel = panel.clone();
let entry = entry.clone();
move |window, cx| {
panel
.update(cx, {
let entry = entry.clone();
move |this, cx| match entry {
RecentEntry::Thread(_, thread) => {
this.open_thread(thread, window, cx)
}
RecentEntry::Context(context) => {
let Some(path) = context.read(cx).path()
else {
return;
};
this.open_saved_prompt_editor(
path.clone(),
window,
cx,
)
.detach_and_log_err(cx)
}
}
})
.ok();
}
},
IconName::Close,
"Close Entry".into(),
{
let panel = panel.clone();
let entry = entry.clone();
move |_window, cx| {
panel
.update(cx, |this, cx| {
this.history_store.update(
cx,
|history_store, cx| {
history_store.remove_recently_opened_entry(
&entry, cx,
);
},
);
})
.ok();
}
},
);
}
menu = menu.separator();
}
menu.action("View All", Box::new(OpenHistory))
.end_slot_action(DeleteRecentlyOpenThread.boxed_clone())
.fixed_width(px(320.).into())
.keep_open_on_confirm(false)
.key_context("NavigationMenu")
});
weak_panel
.update(cx, |panel, cx| {
cx.subscribe_in(
&assistant_navigation_menu,
window,
|_, menu, _: &DismissEvent, window, cx| {
menu.update(cx, |menu, _| {
menu.clear_selected();
});
cx.focus_self(window);
},
)
.detach();
panel.assistant_navigation_menu = Some(assistant_navigation_menu);
})
.ok();
let active_thread_subscription = cx.subscribe(&thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
let _default_model_subscription = cx.subscribe(
@@ -562,7 +431,7 @@ impl AssistantPanel {
fs: fs.clone(),
language_registry,
thread_store: thread_store.clone(),
thread: active_thread,
thread,
message_editor,
_active_thread_subscriptions: vec![
thread_subscription,
@@ -582,8 +451,6 @@ impl AssistantPanel {
history_store: history_store.clone(),
history: cx.new(|cx| ThreadHistory::new(weak_self, history_store, window, cx)),
assistant_dropdown_menu_handle: PopoverMenuHandle::default(),
assistant_navigation_menu_handle: PopoverMenuHandle::default(),
assistant_navigation_menu: None,
width: None,
height: None,
}
@@ -628,7 +495,7 @@ impl AssistantPanel {
let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx);
let context_store = cx.new(|_cx| {
let message_editor_context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new(
self.project.downgrade(),
Some(self.thread_store.downgrade()),
@@ -641,7 +508,7 @@ impl AssistantPanel {
.update(cx, |this, cx| this.open_thread(&other_thread_id, cx));
cx.spawn({
let context_store = context_store.clone();
let context_store = message_editor_context_store.clone();
async move |_panel, cx| {
let other_thread = other_thread_task.await?;
@@ -666,7 +533,6 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(),
self.workspace.clone(),
window,
@@ -685,7 +551,7 @@ impl AssistantPanel {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
context_store,
message_editor_context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
@@ -779,13 +645,13 @@ impl AssistantPanel {
pub(crate) fn open_saved_prompt_editor(
&mut self,
path: Arc<Path>,
path: PathBuf,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let context = self
.context_store
.update(cx, |store, cx| store.open_local_context(path, cx));
.update(cx, |store, cx| store.open_local_context(path.clone(), cx));
let fs = self.fs.clone();
let project = self.project.clone();
let workspace = self.workspace.clone();
@@ -819,7 +685,7 @@ impl AssistantPanel {
})
}
pub(crate) fn open_thread_by_id(
pub(crate) fn open_thread(
&mut self,
thread_id: &ThreadId,
window: &mut Window,
@@ -828,84 +694,73 @@ impl AssistantPanel {
let open_thread_task = self
.thread_store
.update(cx, |this, cx| this.open_thread(thread_id, cx));
cx.spawn_in(window, async move |this, cx| {
let thread = open_thread_task.await?;
this.update_in(cx, |this, window, cx| {
this.open_thread(thread, window, cx);
anyhow::Ok(())
})??;
Ok(())
let thread_view = ActiveView::thread(thread.clone(), window, cx);
this.set_active_view(thread_view, window, cx);
let message_editor_context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new(
this.project.downgrade(),
Some(this.thread_store.downgrade()),
)
});
let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
this.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
this.thread_store.clone(),
this.language_registry.clone(),
this.workspace.clone(),
window,
cx,
)
});
let active_thread_subscription =
cx.subscribe(&this.thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
this.message_editor = cx.new(|cx| {
MessageEditor::new(
this.fs.clone(),
this.workspace.clone(),
message_editor_context_store,
this.prompt_store.clone(),
this.thread_store.downgrade(),
thread,
window,
cx,
)
});
this.message_editor.focus_handle(cx).focus(window);
let message_editor_subscription =
cx.subscribe(&this.message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
this._active_thread_subscriptions = vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
];
})
})
}
pub(crate) fn open_thread(
&mut self,
thread: Entity<Thread>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let thread_view = ActiveView::thread(thread.clone(), window, cx);
self.set_active_view(thread_view, window, cx);
let context_store = cx.new(|_cx| {
crate::context_store::ContextStore::new(
self.project.downgrade(),
Some(self.thread_store.downgrade()),
)
});
let thread_subscription = cx.subscribe(&thread, |_, _, event, cx| {
if let ThreadEvent::MessageAdded(_) = &event {
// needed to leave empty state
cx.notify();
}
});
self.thread = cx.new(|cx| {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
context_store.clone(),
self.language_registry.clone(),
self.workspace.clone(),
window,
cx,
)
});
let active_thread_subscription =
cx.subscribe(&self.thread, |_, _, event, cx| match &event {
ActiveThreadEvent::EditingMessageTokenCountChanged => {
cx.notify();
}
});
self.message_editor = cx.new(|cx| {
MessageEditor::new(
self.fs.clone(),
self.workspace.clone(),
context_store,
self.prompt_store.clone(),
self.thread_store.downgrade(),
thread,
window,
cx,
)
});
self.message_editor.focus_handle(cx).focus(window);
let message_editor_subscription =
cx.subscribe(&self.message_editor, |_, _, event, cx| match event {
MessageEditorEvent::Changed | MessageEditorEvent::EstimatedTokenCount => {
cx.notify();
}
});
self._active_thread_subscriptions = vec![
thread_subscription,
active_thread_subscription,
message_editor_subscription,
];
}
pub fn go_back(&mut self, _: &workspace::GoBack, window: &mut Window, cx: &mut Context<Self>) {
match self.active_view {
ActiveView::Configuration | ActiveView::History => {
@@ -918,24 +773,6 @@ impl AssistantPanel {
}
}
pub fn toggle_navigation_menu(
&mut self,
_: &ToggleNavigationMenu,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.assistant_navigation_menu_handle.toggle(window, cx);
}
pub fn toggle_options_menu(
&mut self,
_: &ToggleOptionsMenu,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.assistant_dropdown_menu_handle.toggle(window, cx);
}
pub fn open_agent_diff(
&mut self,
_: &OpenAgentDiff,
@@ -1084,7 +921,7 @@ impl AssistantPanel {
pub(crate) fn delete_context(
&mut self,
path: Arc<Path>,
path: PathBuf,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
self.context_store
@@ -1100,34 +937,6 @@ impl AssistantPanel {
let current_is_history = matches!(self.active_view, ActiveView::History);
let new_is_history = matches!(new_view, ActiveView::History);
match &self.active_view {
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
if let Some(thread) = thread.upgrade() {
if thread.read(cx).is_empty() {
let id = thread.read(cx).id().clone();
store.remove_recently_opened_thread(id, cx);
}
}
}),
_ => {}
}
match &new_view {
ActiveView::Thread { thread, .. } => self.history_store.update(cx, |store, cx| {
if let Some(thread) = thread.upgrade() {
let id = thread.read(cx).id().clone();
store.push_recently_opened_entry(RecentEntry::Thread(id, thread), cx);
}
}),
ActiveView::PromptEditor { context_editor, .. } => {
self.history_store.update(cx, |store, cx| {
let context = context_editor.read(cx).context().clone();
store.push_recently_opened_entry(RecentEntry::Context(context), cx)
})
}
_ => {}
}
if current_is_history && !new_is_history {
self.active_view = new_view;
} else if !current_is_history && new_is_history {
@@ -1257,13 +1066,16 @@ impl AssistantPanel {
if is_empty {
Label::new(Thread::DEFAULT_SUMMARY.clone())
.truncate()
.ml_2()
.into_any_element()
} else if summary.is_none() {
Label::new(LOADING_SUMMARY_PLACEHOLDER)
.ml_2()
.truncate()
.into_any_element()
} else {
div()
.ml_2()
.w_full()
.child(change_title_editor.clone())
.into_any_element()
@@ -1280,15 +1092,18 @@ impl AssistantPanel {
match summary {
None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone())
.truncate()
.ml_2()
.into_any_element(),
Some(summary) => {
if summary.done {
div()
.ml_2()
.w_full()
.child(title_editor.clone())
.into_any_element()
} else {
Label::new(LOADING_SUMMARY_PLACEHOLDER)
.ml_2()
.truncate()
.into_any_element()
}
@@ -1315,6 +1130,7 @@ impl AssistantPanel {
let thread = active_thread.thread().read(cx);
let thread_id = thread.id().clone();
let is_empty = active_thread.is_empty();
let is_history = matches!(self.active_view, ActiveView::History);
let show_token_count = match &self.active_view {
ActiveView::Thread { .. } => !is_empty,
@@ -1324,108 +1140,30 @@ impl AssistantPanel {
let focus_handle = self.focus_handle(cx);
let go_back_button = div().child(
IconButton::new("go-back", IconName::ArrowLeft)
.icon_size(IconSize::Small)
.on_click(cx.listener(|this, _, window, cx| {
this.go_back(&workspace::GoBack, window, cx);
}))
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Go Back",
&workspace::GoBack,
&focus_handle,
window,
cx,
)
}
}),
);
let recent_entries_menu = div().child(
PopoverMenu::new("agent-nav-menu")
.trigger_with_tooltip(
IconButton::new("agent-nav-menu", IconName::MenuAlt)
let go_back_button = match &self.active_view {
ActiveView::History | ActiveView::Configuration => Some(
div().pl_1().child(
IconButton::new("go-back", IconName::ArrowLeft)
.icon_size(IconSize::Small)
.style(ui::ButtonStyle::Subtle),
{
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Toggle Panel Menu",
&ToggleNavigationMenu,
&focus_handle,
window,
cx,
)
}
},
)
.anchor(Corner::TopLeft)
.with_handle(self.assistant_navigation_menu_handle.clone())
.menu({
let menu = self.assistant_navigation_menu.clone();
move |window, cx| {
if let Some(menu) = menu.as_ref() {
menu.update(cx, |_, cx| {
cx.defer_in(window, |menu, window, cx| {
menu.rebuild(window, cx);
});
})
}
menu.clone()
}
}),
);
let agent_extra_menu = PopoverMenu::new("agent-options-menu")
.trigger_with_tooltip(
IconButton::new("agent-options-menu", IconName::Ellipsis)
.icon_size(IconSize::Small),
{
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Toggle Agent Menu",
&ToggleOptionsMenu,
&focus_handle,
window,
cx,
)
}
},
)
.anchor(Corner::TopRight)
.with_handle(self.assistant_dropdown_menu_handle.clone())
.menu(move |window, cx| {
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
menu.when(!is_empty, |menu| {
menu.action(
"Start New From Summary",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
)
.separator()
})
.action("New Text Thread", NewTextThread.boxed_clone())
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenConfiguration))
.separator()
.header("MCPs")
.action(
"View Server Extensions",
Box::new(zed_actions::Extensions {
category_filter: Some(
zed_actions::ExtensionCategoryFilter::ContextServers,
),
.on_click(cx.listener(|this, _, window, cx| {
this.go_back(&workspace::GoBack, window, cx);
}))
.tooltip({
let focus_handle = focus_handle.clone();
move |window, cx| {
Tooltip::for_action_in(
"Go Back",
&workspace::GoBack,
&focus_handle,
window,
cx,
)
}
}),
)
.action("Add Custom Server", Box::new(AddContextServer))
}))
});
),
),
_ => None,
};
h_flex()
.id("assistant-toolbar")
@@ -1439,22 +1177,18 @@ impl AssistantPanel {
.border_color(cx.theme().colors().border)
.child(
h_flex()
.size_full()
.pl_1()
.w_full()
.gap_1()
.child(match &self.active_view {
ActiveView::History | ActiveView::Configuration => go_back_button,
_ => recent_entries_menu,
})
.children(go_back_button)
.child(self.render_title_view(window, cx)),
)
.child(
h_flex()
.h_full()
.gap_2()
.when(show_token_count, |parent| {
.when(show_token_count, |parent|
parent.children(self.render_token_count(&thread, cx))
})
)
.child(
h_flex()
.h_full()
@@ -1482,7 +1216,72 @@ impl AssistantPanel {
);
}),
)
.child(agent_extra_menu),
.child(
IconButton::new("open-history", IconName::HistoryRerun)
.icon_size(IconSize::Small)
.toggle_state(is_history)
.selected_icon_color(Color::Accent)
.tooltip({
let focus_handle = self.focus_handle(cx);
move |window, cx| {
Tooltip::for_action_in(
"History",
&OpenHistory,
&focus_handle,
window,
cx,
)
}
})
.on_click(move |_event, window, cx| {
window.dispatch_action(OpenHistory.boxed_clone(), cx);
}),
)
.child(
PopoverMenu::new("assistant-menu")
.trigger_with_tooltip(
IconButton::new("new", IconName::Ellipsis)
.icon_size(IconSize::Small)
.style(ButtonStyle::Subtle),
Tooltip::text("Toggle Agent Menu"),
)
.anchor(Corner::TopRight)
.with_handle(self.assistant_dropdown_menu_handle.clone())
.menu(move |window, cx| {
Some(ContextMenu::build(
window,
cx,
|menu, _window, _cx| {
menu
.when(!is_empty, |menu| {
menu.action(
"Start New From Summary",
Box::new(NewThread {
from_thread_id: Some(thread_id.clone()),
}),
).separator()
})
.action(
"New Text Thread",
NewTextThread.boxed_clone(),
)
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenConfiguration))
.separator()
.header("MCPs")
.action(
"View Server Extensions",
Box::new(zed_actions::Extensions {
category_filter: Some(
zed_actions::ExtensionCategoryFilter::ContextServers,
),
}),
)
.action("Add Custom Server", Box::new(AddContextServer))
},
))
}),
),
),
)
}
@@ -2183,8 +1982,6 @@ impl Render for AssistantPanel {
.on_action(cx.listener(Self::deploy_rules_library))
.on_action(cx.listener(Self::open_agent_diff))
.on_action(cx.listener(Self::go_back))
.on_action(cx.listener(Self::toggle_navigation_menu))
.on_action(cx.listener(Self::toggle_options_menu))
.child(self.render_toolbar(window, cx))
.map(|parent| match &self.active_view {
ActiveView::Thread { .. } => parent
@@ -2269,7 +2066,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
fn open_saved_context(
&self,
workspace: &mut Workspace,
path: Arc<Path>,
path: std::path::PathBuf,
window: &mut Window,
cx: &mut Context<Workspace>,
) -> Task<Result<()>> {

View File

@@ -3,12 +3,11 @@ use std::hash::{Hash, Hasher};
use std::path::PathBuf;
use std::{ops::Range, path::Path, sync::Arc};
use assistant_tool::outline;
use collections::HashSet;
use futures::future;
use futures::{FutureExt, future::Shared};
use gpui::{App, AppContext as _, Entity, SharedString, Task};
use language::{Buffer, ParseStatus};
use language::Buffer;
use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
use project::{Project, ProjectEntryId, ProjectPath, Worktree};
use prompt_store::{PromptStore, UserPromptId};
@@ -153,7 +152,6 @@ pub struct FileContext {
pub handle: FileContextHandle,
pub full_path: Arc<Path>,
pub text: SharedString,
pub is_outline: bool,
}
impl FileContextHandle {
@@ -179,51 +177,14 @@ impl FileContextHandle {
log::error!("file context missing path");
return Task::ready(None);
};
let full_path: Arc<Path> = file.full_path(cx).into();
let full_path = file.full_path(cx);
let rope = buffer_ref.as_rope().clone();
let buffer = self.buffer.clone();
cx.spawn(async move |cx| {
// For large files, use outline instead of full content
if rope.len() > outline::AUTO_OUTLINE_SIZE {
// Wait until the buffer has been fully parsed, so we can read its outline
if let Ok(mut parse_status) =
buffer.read_with(cx, |buffer, _| buffer.parse_status())
{
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await.log_err();
}
if let Ok(snapshot) = buffer.read_with(cx, |buffer, _| buffer.snapshot()) {
if let Some(outline) = snapshot.outline(None) {
let items = outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot));
if let Ok(outline_text) =
outline::render_outline(items, None, 0, usize::MAX).await
{
let context = AgentContext::File(FileContext {
handle: self,
full_path,
text: outline_text.into(),
is_outline: true,
});
return Some((context, vec![buffer]));
}
}
}
}
}
// Fallback to full content if we couldn't build an outline
// (or didn't need to because the file was small enough)
cx.background_spawn(async move {
let context = AgentContext::File(FileContext {
handle: self,
full_path,
full_path: full_path.into(),
text: rope.to_string().into(),
is_outline: false,
});
Some((context, vec![buffer]))
})
@@ -1035,115 +996,3 @@ impl Hash for AgentContextKey {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
fn init_test_settings(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
// Helper to create a test project with test files
async fn create_test_project(
cx: &mut TestAppContext,
files: serde_json::Value,
) -> Entity<Project> {
let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(path!("/test"), files).await;
Project::test(fs, [path!("/test").as_ref()], cx).await
}
#[gpui::test]
async fn test_large_file_uses_outline(cx: &mut TestAppContext) {
init_test_settings(cx);
// Create a large file that exceeds AUTO_OUTLINE_SIZE
const LINE: &str = "Line with some text\n";
let large_content = LINE.repeat(2 * (outline::AUTO_OUTLINE_SIZE / LINE.len()));
let content_len = large_content.len();
assert!(content_len > outline::AUTO_OUTLINE_SIZE);
let file_context = file_context_for(large_content, cx).await;
assert!(
file_context.is_outline,
"Large file should use outline format"
);
assert!(
file_context.text.len() < content_len,
"Outline should be smaller than original content"
);
}
#[gpui::test]
async fn test_small_file_uses_full_content(cx: &mut TestAppContext) {
init_test_settings(cx);
let small_content = "This is a small file.\n";
let content_len = small_content.len();
assert!(content_len < outline::AUTO_OUTLINE_SIZE);
let file_context = file_context_for(small_content.to_string(), cx).await;
assert!(
!file_context.is_outline,
"Small files should not get an outline"
);
assert_eq!(file_context.text, small_content);
}
async fn file_context_for(content: String, cx: &mut TestAppContext) -> FileContext {
// Create a test project with the file
let project = create_test_project(
cx,
json!({
"file.txt": content,
}),
)
.await;
// Open the buffer
let buffer_path = project
.read_with(cx, |project, cx| project.find_project_path("file.txt", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
.await
.unwrap();
let context_handle = AgentContextHandle::File(FileContextHandle {
buffer: buffer.clone(),
context_id: ContextId::zero(),
});
cx.update(|cx| load_context(vec![context_handle], &project, &None, cx))
.await
.loaded_context
.contexts
.into_iter()
.find_map(|ctx| {
if let AgentContext::File(file_ctx) = ctx {
Some(file_ctx)
} else {
None
}
})
.expect("Should have found a file context")
}
}

View File

@@ -267,7 +267,7 @@ impl ContextPicker {
context_picker.update(cx, |this, cx| this.select_entry(entry, window, cx))
})
}))
.keep_open_on_confirm(true)
.keep_open_on_confirm()
});
cx.subscribe(&menu, move |_, _, _: &DismissEvent, cx| {

View File

@@ -21,7 +21,7 @@ use crate::context::{
SymbolContextHandle, ThreadContextHandle,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{MessageId, Thread, ThreadId};
use crate::thread::{Thread, ThreadId};
pub struct ContextStore {
project: WeakEntity<Project>,
@@ -54,14 +54,9 @@ impl ContextStore {
self.context_thread_ids.clear();
}
pub fn new_context_for_thread(
&self,
thread: &Thread,
exclude_messages_from_id: Option<MessageId>,
) -> Vec<AgentContextHandle> {
pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
let existing_context = thread
.messages()
.take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
.flat_map(|message| {
message
.loaded_context

View File

@@ -1,27 +1,10 @@
use std::{collections::VecDeque, path::Path};
use anyhow::{Context as _, anyhow};
use assistant_context_editor::{AssistantContext, SavedContextMetadata};
use assistant_context_editor::SavedContextMetadata;
use chrono::{DateTime, Utc};
use futures::future::{TryFutureExt as _, join_all};
use gpui::{Entity, Task, prelude::*};
use serde::{Deserialize, Serialize};
use smol::future::FutureExt;
use std::time::Duration;
use ui::{App, SharedString};
use util::ResultExt as _;
use gpui::{Entity, prelude::*};
use crate::{
Thread,
thread::ThreadId,
thread_store::{SerializedThreadMetadata, ThreadStore},
};
use crate::thread_store::{SerializedThreadMetadata, ThreadStore};
const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;
const NAVIGATION_HISTORY_PATH: &str = "agent-navigation-history.json";
const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50);
#[derive(Clone, Debug)]
#[derive(Debug)]
pub enum HistoryEntry {
Thread(SerializedThreadMetadata),
Context(SavedContextMetadata),
@@ -36,52 +19,16 @@ impl HistoryEntry {
}
}
#[derive(Clone, Debug)]
pub(crate) enum RecentEntry {
Thread(ThreadId, Entity<Thread>),
Context(Entity<AssistantContext>),
}
impl PartialEq for RecentEntry {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Thread(l0, _), Self::Thread(r0, _)) => l0 == r0,
(Self::Context(l0), Self::Context(r0)) => l0 == r0,
_ => false,
}
}
}
impl Eq for RecentEntry {}
impl RecentEntry {
pub(crate) fn summary(&self, cx: &App) -> SharedString {
match self {
RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(),
RecentEntry::Context(context) => context.read(cx).summary_or_default(),
}
}
}
#[derive(Serialize, Deserialize)]
enum SerializedRecentEntry {
Thread(String),
Context(String),
}
pub struct HistoryStore {
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
recently_opened_entries: VecDeque<RecentEntry>,
_subscriptions: Vec<gpui::Subscription>,
_save_recently_opened_entries_task: Task<()>,
}
impl HistoryStore {
pub fn new(
thread_store: Entity<ThreadStore>,
context_store: Entity<assistant_context_editor::ContextStore>,
initial_recent_entries: impl IntoIterator<Item = RecentEntry>,
cx: &mut Context<Self>,
) -> Self {
let subscriptions = vec![
@@ -89,62 +36,10 @@ impl HistoryStore {
cx.observe(&context_store, |_, _, cx| cx.notify()),
];
cx.spawn({
let thread_store = thread_store.downgrade();
let context_store = context_store.downgrade();
async move |this, cx| {
let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
let contents = cx
.background_spawn(async move { std::fs::read_to_string(path) })
.await
.context("reading persisted agent panel navigation history")?;
let entries = serde_json::from_str::<Vec<SerializedRecentEntry>>(&contents)
.context("deserializing persisted agent panel navigation history")?
.into_iter()
.take(MAX_RECENTLY_OPENED_ENTRIES)
.map(|serialized| match serialized {
SerializedRecentEntry::Thread(id) => thread_store
.update(cx, |thread_store, cx| {
let thread_id = ThreadId::from(id.as_str());
thread_store
.open_thread(&thread_id, cx)
.map_ok(|thread| RecentEntry::Thread(thread_id, thread))
.boxed()
})
.unwrap_or_else(|_| async { Err(anyhow!("no thread store")) }.boxed()),
SerializedRecentEntry::Context(id) => context_store
.update(cx, |context_store, cx| {
context_store
.open_local_context(Path::new(&id).into(), cx)
.map_ok(RecentEntry::Context)
.boxed()
})
.unwrap_or_else(|_| async { Err(anyhow!("no context store")) }.boxed()),
});
let entries = join_all(entries)
.await
.into_iter()
.filter_map(|result| result.log_err())
.collect::<VecDeque<_>>();
this.update(cx, |this, _| {
this.recently_opened_entries.extend(entries);
this.recently_opened_entries
.truncate(MAX_RECENTLY_OPENED_ENTRIES);
})
.ok();
anyhow::Ok(())
}
})
.detach_and_log_err(cx);
Self {
thread_store,
context_store,
recently_opened_entries: initial_recent_entries.into_iter().collect(),
_subscriptions: subscriptions,
_save_recently_opened_entries_task: Task::ready(()),
}
}
@@ -174,63 +69,4 @@ impl HistoryStore {
pub fn recent_entries(&self, limit: usize, cx: &mut Context<Self>) -> Vec<HistoryEntry> {
self.entries(cx).into_iter().take(limit).collect()
}
fn save_recently_opened_entries(&mut self, cx: &mut Context<Self>) {
let serialized_entries = self
.recently_opened_entries
.iter()
.filter_map(|entry| match entry {
RecentEntry::Context(context) => Some(SerializedRecentEntry::Context(
context.read(cx).path()?.to_str()?.to_owned(),
)),
RecentEntry::Thread(id, _) => Some(SerializedRecentEntry::Thread(id.to_string())),
})
.collect::<Vec<_>>();
self._save_recently_opened_entries_task = cx.spawn(async move |_, cx| {
cx.background_executor()
.timer(SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE)
.await;
cx.background_spawn(async move {
let path = paths::data_dir().join(NAVIGATION_HISTORY_PATH);
let content = serde_json::to_string(&serialized_entries)?;
std::fs::write(path, content)?;
anyhow::Ok(())
})
.await
.log_err();
});
}
pub fn push_recently_opened_entry(&mut self, entry: RecentEntry, cx: &mut Context<Self>) {
self.recently_opened_entries
.retain(|old_entry| old_entry != &entry);
self.recently_opened_entries.push_front(entry);
self.recently_opened_entries
.truncate(MAX_RECENTLY_OPENED_ENTRIES);
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_thread(&mut self, id: ThreadId, cx: &mut Context<Self>) {
self.recently_opened_entries.retain(|entry| match entry {
RecentEntry::Thread(thread_id, _) if thread_id == &id => false,
_ => true,
});
self.save_recently_opened_entries(cx);
}
pub fn remove_recently_opened_entry(&mut self, entry: &RecentEntry, cx: &mut Context<Self>) {
self.recently_opened_entries
.retain(|old_entry| old_entry != entry);
self.save_recently_opened_entries(cx);
}
pub fn recently_opened_entries(&self, _cx: &mut Context<Self>) -> VecDeque<RecentEntry> {
#[cfg(debug_assertions)]
if std::env::var("ZED_SIMULATE_NO_THREAD_HISTORY").is_ok() {
return VecDeque::new();
}
self.recently_opened_entries.clone()
}
}

View File

@@ -63,62 +63,12 @@ pub struct MessageEditor {
edits_expanded: bool,
editor_is_expanded: bool,
last_estimated_token_count: Option<usize>,
update_token_count_task: Option<Task<()>>,
update_token_count_task: Option<Task<anyhow::Result<()>>>,
_subscriptions: Vec<Subscription>,
}
const MAX_EDITOR_LINES: usize = 8;
pub(crate) fn create_editor(
workspace: WeakEntity<Workspace>,
context_store: WeakEntity<ContextStore>,
thread_store: WeakEntity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Entity<Editor> {
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace,
context_store,
Some(thread_store),
editor_entity,
))));
});
editor
}
impl MessageEditor {
pub fn new(
fs: Arc<dyn Fs>,
@@ -133,14 +83,47 @@ impl MessageEditor {
let context_picker_menu_handle = PopoverMenuHandle::default();
let model_selector_menu_handle = PopoverMenuHandle::default();
let editor = create_editor(
workspace.clone(),
context_store.downgrade(),
thread_store.clone(),
window,
cx,
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(
editor::EditorMode::AutoHeight {
max_lines: MAX_EDITOR_LINES,
},
buffer,
None,
window,
cx,
);
editor.set_placeholder_text("Ask anything, @ to mention, ↑ to select", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
let editor_entity = editor.downgrade();
editor.update(cx, |editor, _| {
editor.set_completion_provider(Some(Box::new(ContextPickerCompletionProvider::new(
workspace.clone(),
context_store.downgrade(),
Some(thread_store.clone()),
editor_entity,
))));
});
let context_strip = cx.new(|cx| {
ContextStrip::new(
context_store.clone(),
@@ -430,7 +413,7 @@ impl MessageEditor {
let active_completion_mode = thread.completion_mode();
Some(
IconButton::new("max-mode", IconName::ZedMaxMode)
IconButton::new("max-mode", IconName::SquarePlus)
.icon_size(IconSize::Small)
.toggle_state(active_completion_mode == Some(CompletionMode::Max))
.on_click(cx.listener(move |this, _event, _window, cx| {
@@ -441,7 +424,7 @@ impl MessageEditor {
});
});
}))
.tooltip(Tooltip::text("Toggle Max Mode"))
.tooltip(Tooltip::text("Max Mode"))
.into_any_element(),
)
}
@@ -1058,7 +1041,7 @@ impl MessageEditor {
let load_task = cx.spawn(async move |this, cx| {
let Ok(load_task) = this.update(cx, |this, cx| {
let new_context = this.context_store.read_with(cx, |context_store, cx| {
context_store.new_context_for_thread(this.thread.read(cx), None)
context_store.new_context_for_thread(this.thread.read(cx))
});
load_context(new_context, &this.project, &this.prompt_store, cx)
}) else {
@@ -1105,64 +1088,57 @@ impl MessageEditor {
.await;
}
let token_count = if let Some(task) = this
.update(cx, |this, cx| {
let loaded_context = this
.last_loaded_context
.as_ref()
.map(|context_load_result| &context_load_result.loaded_context);
let message_text = editor.read(cx).text(cx);
let token_count = if let Some(task) = this.update(cx, |this, cx| {
let loaded_context = this
.last_loaded_context
.as_ref()
.map(|context_load_result| &context_load_result.loaded_context);
let message_text = editor.read(cx).text(cx);
if message_text.is_empty()
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
{
return None;
}
if message_text.is_empty()
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
{
return None;
}
let mut request_message = LanguageModelRequestMessage {
role: language_model::Role::User,
content: Vec::new(),
cache: false,
};
let mut request_message = LanguageModelRequestMessage {
role: language_model::Role::User,
content: Vec::new(),
cache: false,
};
if let Some(loaded_context) = loaded_context {
loaded_context.add_to_request_message(&mut request_message);
}
if let Some(loaded_context) = loaded_context {
loaded_context.add_to_request_message(&mut request_message);
}
if !message_text.is_empty() {
request_message
.content
.push(MessageContent::Text(message_text));
}
if !message_text.is_empty() {
request_message
.content
.push(MessageContent::Text(message_text));
}
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
mode: None,
messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,
};
let request = language_model::LanguageModelRequest {
thread_id: None,
prompt_id: None,
mode: None,
messages: vec![request_message],
tools: vec![],
stop: vec![],
temperature: None,
};
Some(model.model.count_tokens(request, cx))
})
.ok()
.flatten()
{
task.await.log_err()
Some(model.model.count_tokens(request, cx))
})? {
task.await?
} else {
Some(0)
0
};
this.update(cx, |this, cx| {
if let Some(token_count) = token_count {
this.last_estimated_token_count = Some(token_count);
cx.emit(MessageEditorEvent::EstimatedTokenCount);
}
this.last_estimated_token_count = Some(token_count);
cx.emit(MessageEditorEvent::EstimatedTokenCount);
this.update_token_count_task.take();
})
.ok();
}));
}
}

View File

@@ -68,41 +68,30 @@ impl ProfileSelector {
menu = menu.header("Profiles");
for (profile_id, profile) in self.profiles.clone() {
let documentation = match profile.name.to_lowercase().as_str() {
"write" => Some("Get help to write anything."),
"ask" => Some("Chat about your codebase."),
"manual" => Some("Chat about anything; no tools."),
_ => None,
};
menu = menu.toggleable_entry(
profile.name.clone(),
profile_id == settings.default_profile,
icon_position,
None,
{
let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
move |_window, cx| {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let profile_id = profile_id.clone();
move |settings, _cx| {
settings.set_profile(profile_id.clone());
}
});
let entry = ContextMenuEntry::new(profile.name.clone())
.toggleable(icon_position, profile_id == settings.default_profile);
let entry = if let Some(doc_text) = documentation {
entry.documentation_aside(move |_| Label::new(doc_text).into_any_element())
} else {
entry
};
menu = menu.item(entry.handler({
let fs = self.fs.clone();
let thread_store = self.thread_store.clone();
let profile_id = profile_id.clone();
move |_window, cx| {
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
let profile_id = profile_id.clone();
move |settings, _cx| {
settings.set_profile(profile_id.clone());
}
});
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(profile_id.clone(), cx);
})
.log_err();
}
}));
thread_store
.update(cx, |this, cx| {
this.load_profile_by_id(profile_id.clone(), cx);
})
.log_err();
}
},
);
}
menu = menu.separator();
@@ -152,7 +141,6 @@ impl Render for ProfileSelector {
let this = cx.entity().clone();
let focus_handle = self.focus_handle.clone();
PopoverMenu::new("profile-selector")
.menu(move |window, cx| {
Some(this.update(cx, |this, cx| this.build_context_menu(window, cx)))
@@ -195,7 +183,7 @@ impl Render for ProfileSelector {
)
.tooltip(Tooltip::text("The current model does not support tools."))
})
.anchor(gpui::Corner::BottomRight)
.anchor(gpui::Corner::BottomLeft)
.with_handle(self.menu_handle.clone())
}
}

View File

@@ -879,7 +879,6 @@ impl Thread {
id: MessageId,
new_role: Role,
new_segments: Vec<MessageSegment>,
loaded_context: Option<LoadedContext>,
cx: &mut Context<Self>,
) -> bool {
let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
@@ -887,9 +886,6 @@ impl Thread {
};
message.role = new_role;
message.segments = new_segments;
if let Some(context) = loaded_context {
message.loaded_context = context;
}
self.touch_updated_at();
cx.emit(ThreadEvent::MessageEdited(id));
true
@@ -2550,7 +2546,6 @@ fn main() {{
"file1.rs": "fn function1() {}\n",
"file2.rs": "fn function2() {}\n",
"file3.rs": "fn function3() {}\n",
"file4.rs": "fn function4() {}\n",
}),
)
.await;
@@ -2563,7 +2558,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), None)
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@@ -2578,7 +2573,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), None)
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@@ -2594,7 +2589,7 @@ fn main() {{
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), None)
store.new_context_for_thread(thread.read(cx))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
@@ -2645,55 +2640,6 @@ fn main() {{
assert!(!request.messages[3].string_contents().contains("file1.rs"));
assert!(!request.messages[3].string_contents().contains("file2.rs"));
assert!(request.messages[3].string_contents().contains("file3.rs"));
add_file_to_context(&project, &context_store, "test/file4.rs", cx)
.await
.unwrap();
let new_contexts = context_store.update(cx, |store, cx| {
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 3);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file4.rs
store.remove_context(&loaded_context.contexts[2].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 2);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
let new_contexts = context_store.update(cx, |store, cx| {
// Remove file3.rs
store.remove_context(&loaded_context.contexts[1].handle(), cx);
store.new_context_for_thread(thread.read(cx), Some(message2_id))
});
assert_eq!(new_contexts.len(), 1);
let loaded_context = cx
.update(|cx| load_context(new_contexts, &project, &None, cx))
.await
.loaded_context;
assert!(!loaded_context.text.contains("file1.rs"));
assert!(loaded_context.text.contains("file2.rs"));
assert!(!loaded_context.text.contains("file3.rs"));
assert!(!loaded_context.text.contains("file4.rs"));
}
#[gpui::test]
@@ -2904,8 +2850,8 @@ fn main() {{
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
None,
Arc::new(PromptBuilder::new(None).unwrap()),
None,
cx,
)
})

View File

@@ -270,9 +270,9 @@ impl ThreadHistory {
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
if let Some(entry) = self.get_match(self.selected_index) {
let task_result = match entry {
HistoryEntry::Thread(thread) => self.assistant_panel.update(cx, move |this, cx| {
this.open_thread_by_id(&thread.id, window, cx)
}),
HistoryEntry::Thread(thread) => self
.assistant_panel
.update(cx, move |this, cx| this.open_thread(&thread.id, window, cx)),
HistoryEntry::Context(context) => {
self.assistant_panel.update(cx, move |this, cx| {
this.open_saved_prompt_editor(context.path.clone(), window, cx)
@@ -525,8 +525,7 @@ impl RenderOnce for PastThread {
move |_event, window, cx| {
assistant_panel
.update(cx, |this, cx| {
this.open_thread_by_id(&id, window, cx)
.detach_and_log_err(cx);
this.open_thread(&id, window, cx).detach_and_log_err(cx);
})
.ok();
}

View File

@@ -81,8 +81,8 @@ impl ThreadStore {
pub fn load(
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_store: Option<Entity<PromptStore>>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
cx: &mut App,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {

View File

@@ -98,6 +98,10 @@ impl RenderOnce for UsageBanner {
}
impl Component for UsageBanner {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn sort_name() -> &'static str {
"AgentUsageBanner"
}

View File

@@ -33,7 +33,6 @@ use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
use std::ops::Range;
use std::path::Path;
use std::{ops::ControlFlow, path::PathBuf, sync::Arc};
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
use ui::{ContextMenu, PopoverMenu, Tooltip, prelude::*};
@@ -1081,7 +1080,7 @@ impl AssistantPanel {
pub fn open_saved_context(
&mut self,
path: Arc<Path>,
path: PathBuf,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
@@ -1392,7 +1391,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
fn open_saved_context(
&self,
workspace: &mut Workspace,
path: Arc<Path>,
path: PathBuf,
window: &mut Window,
cx: &mut Context<Workspace>,
) -> Task<Result<()>> {

View File

@@ -35,7 +35,7 @@ use std::{
fmt::Debug,
iter, mem,
ops::Range,
path::Path,
path::{Path, PathBuf},
str::FromStr as _,
sync::Arc,
time::{Duration, Instant},
@@ -46,7 +46,7 @@ use ui::IconName;
use util::{ResultExt, TryFutureExt, post_inc};
use uuid::Uuid;
#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ContextId(String);
impl ContextId {
@@ -648,7 +648,7 @@ pub struct AssistantContext {
pending_token_count: Task<Option<()>>,
pending_save: Task<Result<()>>,
pending_cache_warming_task: Task<Option<()>>,
path: Option<Arc<Path>>,
path: Option<PathBuf>,
_subscriptions: Vec<Subscription>,
telemetry: Option<Arc<Telemetry>>,
language_registry: Arc<LanguageRegistry>,
@@ -839,7 +839,7 @@ impl AssistantContext {
pub fn deserialize(
saved_context: SavedContext,
path: Arc<Path>,
path: PathBuf,
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
@@ -1147,8 +1147,8 @@ impl AssistantContext {
self.prompt_builder.clone()
}
pub fn path(&self) -> Option<&Arc<Path>> {
self.path.as_ref()
pub fn path(&self) -> Option<&Path> {
self.path.as_deref()
}
pub fn summary(&self) -> Option<&ContextSummary> {
@@ -3181,7 +3181,7 @@ impl AssistantContext {
fs.atomic_write(new_path.clone(), serde_json::to_string(&context).unwrap())
.await?;
if let Some(old_path) = old_path {
if new_path.as_path() != old_path.as_ref() {
if new_path != old_path {
fs.remove_file(
&old_path,
RemoveOptions {
@@ -3193,7 +3193,7 @@ impl AssistantContext {
}
}
this.update(cx, |this, _| this.path = Some(new_path.into()))?;
this.update(cx, |this, _| this.path = Some(new_path))?;
}
Ok(())
@@ -3589,6 +3589,6 @@ impl SavedContextV0_1_0 {
#[derive(Debug, Clone)]
pub struct SavedContextMetadata {
pub title: String,
pub path: Arc<Path>,
pub path: PathBuf,
pub mtime: chrono::DateTime<chrono::Local>,
}

View File

@@ -959,7 +959,7 @@ async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
let deserialized_context = cx.new(|cx| {
AssistantContext::deserialize(
serialized_context,
Path::new("").into(),
Default::default(),
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),
@@ -1120,7 +1120,7 @@ async fn test_serialization(cx: &mut TestAppContext) {
let deserialized_context = cx.new(|cx| {
AssistantContext::deserialize(
serialized_context,
Path::new("").into(),
Default::default(),
registry.clone(),
prompt_builder.clone(),
Arc::new(SlashCommandWorkingSet::default()),

View File

@@ -48,14 +48,7 @@ use project::{Project, Worktree};
use rope::Point;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore, update_settings_file};
use std::{
any::TypeId,
cmp,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use std::{any::TypeId, cmp, ops::Range, path::PathBuf, sync::Arc, time::Duration};
use text::SelectionGoal;
use ui::{
ButtonLike, Disclosure, ElevationIndex, KeyBinding, PopoverMenuHandle, TintColor, Tooltip,
@@ -146,7 +139,7 @@ pub trait AssistantPanelDelegate {
fn open_saved_context(
&self,
workspace: &mut Workspace,
path: Arc<Path>,
path: PathBuf,
window: &mut Window,
cx: &mut Context<Workspace>,
) -> Task<Result<()>>;

View File

@@ -20,7 +20,14 @@ use prompt_store::PromptBuilder;
use regex::Regex;
use rpc::AnyProtoClient;
use std::sync::LazyLock;
use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
use std::{
cmp::Reverse,
ffi::OsStr,
mem,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use util::{ResultExt, TryFutureExt};
pub(crate) fn init(client: &AnyProtoClient) {
@@ -423,7 +430,7 @@ impl ContextStore {
pub fn open_local_context(
&mut self,
path: Arc<Path>,
path: PathBuf,
cx: &Context<Self>,
) -> Task<Result<Entity<AssistantContext>>> {
if let Some(existing_context) = self.loaded_context_for_path(&path, cx) {
@@ -471,7 +478,7 @@ impl ContextStore {
pub fn delete_local_context(
&mut self,
path: Arc<Path>,
path: PathBuf,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let fs = self.fs.clone();
@@ -494,7 +501,7 @@ impl ContextStore {
!= Some(&path)
});
this.contexts_metadata
.retain(|context| context.path.as_ref() != path.as_ref());
.retain(|context| context.path != path);
})?;
Ok(())
@@ -504,7 +511,7 @@ impl ContextStore {
fn loaded_context_for_path(&self, path: &Path, cx: &App) -> Option<Entity<AssistantContext>> {
self.contexts.iter().find_map(|context| {
let context = context.upgrade()?;
if context.read(cx).path().map(Arc::as_ref) == Some(path) {
if context.read(cx).path() == Some(path) {
Some(context)
} else {
None
@@ -787,7 +794,7 @@ impl ContextStore {
{
contexts.push(SavedContextMetadata {
title: title.to_string(),
path: path.into(),
path,
mtime: metadata.mtime.timestamp_for_user().into(),
});
}

View File

@@ -6,7 +6,7 @@ use ::open_ai::Model as OpenAiModel;
use anthropic::Model as AnthropicModel;
use anyhow::{Result, bail};
use deepseek::Model as DeepseekModel;
use feature_flags::{AgentStreamEditsFeatureFlag, Assistant2FeatureFlag, FeatureFlagAppExt};
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
use gpui::{App, Pixels};
use indexmap::IndexMap;
use language_model::{CloudModel, LanguageModel};
@@ -87,14 +87,9 @@ pub struct AssistantSettings {
pub profiles: IndexMap<AgentProfileId, AgentProfile>,
pub always_allow_tool_actions: bool,
pub notify_when_agent_waiting: NotifyWhenAgentWaiting,
pub stream_edits: bool,
}
impl AssistantSettings {
pub fn stream_edits(&self, cx: &App) -> bool {
cx.has_flag::<AgentStreamEditsFeatureFlag>() || self.stream_edits
}
pub fn are_live_diffs_enabled(&self, cx: &App) -> bool {
if cx.has_flag::<Assistant2FeatureFlag>() {
return false;
@@ -223,7 +218,6 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
@@ -251,7 +245,6 @@ impl AssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
None => AssistantSettingsContentV2::default(),
}
@@ -502,7 +495,6 @@ impl Default for VersionedAssistantSettingsContent {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
})
}
}
@@ -558,10 +550,6 @@ pub struct AssistantSettingsContentV2 {
///
/// Default: "primary_screen"
notify_when_agent_waiting: Option<NotifyWhenAgentWaiting>,
/// Whether to stream edits from the agent as they are received.
///
/// Default: false
stream_edits: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -724,7 +712,6 @@ impl Settings for AssistantSettings {
&mut settings.notify_when_agent_waiting,
value.notify_when_agent_waiting,
);
merge(&mut settings.stream_edits, value.stream_edits);
merge(&mut settings.default_profile, value.default_profile);
if let Some(profiles) = value.profiles {
@@ -856,7 +843,6 @@ mod tests {
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
stream_edits: None,
},
)),
}

View File

@@ -24,7 +24,6 @@ language.workspace = true
language_model.workspace = true
parking_lot.workspace = true
project.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
text.workspace = true

View File

@@ -1,5 +1,4 @@
mod action_log;
pub mod outline;
mod tool_registry;
mod tool_schema;
mod tool_working_set;
@@ -52,13 +51,6 @@ impl ToolUseStatus {
ToolUseStatus::Error(out) => out.clone(),
}
}
pub fn error(&self) -> Option<SharedString> {
match self {
ToolUseStatus::Error(out) => Some(out.clone()),
_ => None,
}
}
}
/// The result of running a tool, containing both the asynchronous output

View File

@@ -1,132 +0,0 @@
use crate::ActionLog;
use anyhow::{Result, anyhow};
use gpui::{AsyncApp, Entity};
use language::{OutlineItem, ParseStatus};
use project::Project;
use regex::Regex;
use std::fmt::Write;
use text::Point;
/// For files over this size, instead of reading them (or including them in context),
/// we automatically provide the file's symbol outline instead, with line numbers.
pub const AUTO_OUTLINE_SIZE: usize = 16384;
pub async fn file_outline(
project: Entity<Project>,
path: String,
action_log: Entity<ActionLog>,
regex: Option<Regex>,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&path, cx)
.ok_or_else(|| anyhow!("Path {path} not found in project"))
})??;
project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
})?;
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {
return Err(anyhow!("No outline information available for this file."));
};
render_outline(
outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot)),
regex,
0,
usize::MAX,
)
.await
}
pub async fn render_outline(
items: impl IntoIterator<Item = OutlineItem<Point>>,
regex: Option<Regex>,
offset: usize,
results_per_page: usize,
) -> Result<String> {
let mut items = items.into_iter().skip(offset);
let entries = items
.by_ref()
.filter(|item| {
regex
.as_ref()
.is_none_or(|regex| regex.is_match(&item.text))
})
.take(results_per_page)
.collect::<Vec<_>>();
let has_more = items.next().is_some();
let mut output = String::new();
let entries_rendered = render_entries(&mut output, entries);
// Calculate pagination information
let page_start = offset + 1;
let page_end = offset + entries_rendered;
let total_symbols = if has_more {
format!("more than {}", page_end)
} else {
page_end.to_string()
};
// Add pagination information
if has_more {
writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
)
} else {
writeln!(
&mut output,
"\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
)
}
.ok();
Ok(output)
}
fn render_entries(
output: &mut String,
items: impl IntoIterator<Item = OutlineItem<Point>>,
) -> usize {
let mut entries_rendered = 0;
for item in items {
// Indent based on depth ("" for level 0, " " for level 1, etc.)
for _ in 0..item.depth {
output.push(' ');
}
output.push_str(&item.text);
// Add position information - convert to 1-based line numbers for display
let start_line = item.range.start.row + 1;
let end_line = item.range.end.row + 1;
if start_line == end_line {
writeln!(output, " [L{}]", start_line).ok();
} else {
writeln!(output, " [L{}-{}]", start_line, end_line).ok();
}
entries_rendered += 1;
}
entries_rendered
}

View File

@@ -11,24 +11,16 @@ workspace = true
[lib]
path = "src/assistant_tools.rs"
[features]
eval = []
[dependencies]
aho-corasick.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_settings.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
editor.workspace = true
derive_more.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
handlebars = { workspace = true, features = ["rust-embed"] }
html_to_markdown.workspace = true
http_client.workspace = true
indoc.workspace = true
@@ -39,17 +31,9 @@ linkme.workspace = true
open.workspace = true
project.workspace = true
regex.workspace = true
rust-embed.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smallvec.workspace = true
streaming_diff.workspace = true
strsim.workspace = true
task.workspace = true
terminal.workspace = true
terminal_view.workspace = true
ui.workspace = true
util.workspace = true
web_search.workspace = true
@@ -62,18 +46,11 @@ client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
gpui_tokio.workspace = true
fs = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
language_models.workspace = true
project = { workspace = true, features = ["test-support"] }
rand.workspace = true
pretty_assertions.workspace = true
reqwest_client.workspace = true
settings = { workspace = true, features = ["test-support"] }
task = { workspace = true, features = ["test-support"]}
tempfile.workspace = true
tree-sitter-rust.workspace = true
workspace = { workspace = true, features = ["test-support"] }
unindent.workspace = true

View File

@@ -7,7 +7,6 @@ mod create_directory_tool;
mod create_file_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_agent;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
@@ -20,9 +19,7 @@ mod read_file_tool;
mod rename_tool;
mod replace;
mod schema;
mod streaming_edit_file_tool;
mod symbol_info_tool;
mod templates;
mod terminal_tool;
mod thinking_tool;
mod ui;
@@ -30,19 +27,14 @@ mod web_search_tool;
use std::sync::Arc;
use assistant_settings::AssistantSettings;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
use gpui::App;
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
use settings::{Settings, SettingsStore};
use web_search_tool::WebSearchTool;
pub(crate) use templates::*;
use crate::batch_tool::BatchTool;
use crate::code_action_tool::CodeActionTool;
use crate::code_symbols_tool::CodeSymbolsTool;
@@ -60,7 +52,6 @@ use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::streaming_edit_file_tool::StreamingEditFileTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
@@ -80,6 +71,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(CreateFileTool);
registry.register_tool(CopyPathTool);
registry.register_tool(DeletePathTool);
registry.register_tool(EditFileTool);
registry.register_tool(SymbolInfoTool);
registry.register_tool(CodeActionTool);
registry.register_tool(MovePathTool);
@@ -96,12 +88,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
register_edit_file_tool(cx);
cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
.detach();
cx.observe_global::<SettingsStore>(register_edit_file_tool)
.detach();
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
@@ -122,19 +108,6 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
.detach();
}
fn register_edit_file_tool(cx: &mut App) {
let registry = ToolRegistry::global(cx);
registry.unregister_tool(EditFileTool);
registry.unregister_tool(StreamingEditFileTool);
if AssistantSettings::get_global(cx).stream_edits(cx) {
registry.register_tool(StreamingEditFileTool);
} else {
registry.register_tool(EditFileTool);
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -173,7 +146,6 @@ mod tests {
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
settings::init(cx);
AssistantSettings::register(cx);
let client = Client::new(
Arc::new(FakeSystemClock::new()),

View File

@@ -4,10 +4,10 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::outline;
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use language::{OutlineItem, ParseStatus, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, Symbol};
use regex::{Regex, RegexBuilder};
@@ -148,13 +148,59 @@ impl Tool for CodeSymbolsTool {
};
cx.spawn(async move |cx| match input.path {
Some(path) => outline::file_outline(project, path, action_log, regex, cx).await,
Some(path) => file_outline(project, path, action_log, regex, cx).await,
None => project_symbols(project, regex, input.offset, cx).await,
})
.into()
}
}
pub async fn file_outline(
project: Entity<Project>,
path: String,
action_log: Entity<ActionLog>,
regex: Option<Regex>,
cx: &mut AsyncApp,
) -> anyhow::Result<String> {
let buffer = {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&path, cx)
.ok_or_else(|| anyhow!("Path {path} not found in project"))
})??;
project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
})?;
// Wait until the buffer has been fully parsed, so that we can read its outline.
let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
let Some(outline) = snapshot.outline(None) else {
return Err(anyhow!("No outline information available for this file."));
};
render_outline(
outline
.items
.into_iter()
.map(|item| item.to_point(&snapshot)),
regex,
0,
usize::MAX,
)
.await
}
async fn project_symbols(
project: Entity<Project>,
regex: Option<Regex>,
@@ -245,3 +291,77 @@ async fn project_symbols(
output
})
}
async fn render_outline(
items: impl IntoIterator<Item = OutlineItem<Point>>,
regex: Option<Regex>,
offset: usize,
results_per_page: usize,
) -> Result<String> {
let mut items = items.into_iter().skip(offset);
let entries = items
.by_ref()
.filter(|item| {
regex
.as_ref()
.is_none_or(|regex| regex.is_match(&item.text))
})
.take(results_per_page)
.collect::<Vec<_>>();
let has_more = items.next().is_some();
let mut output = String::new();
let entries_rendered = render_entries(&mut output, entries);
// Calculate pagination information
let page_start = offset + 1;
let page_end = offset + entries_rendered;
let total_symbols = if has_more {
format!("more than {}", page_end)
} else {
page_end.to_string()
};
// Add pagination information
if has_more {
writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
)
} else {
writeln!(
&mut output,
"\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
)
}
.ok();
Ok(output)
}
fn render_entries(
output: &mut String,
items: impl IntoIterator<Item = OutlineItem<Point>>,
) -> usize {
let mut entries_rendered = 0;
for item in items {
// Indent based on depth ("" for level 0, " " for level 1, etc.)
for _ in 0..item.depth {
output.push(' ');
}
output.push_str(&item.text);
// Add position information - convert to 1-based line numbers for display
let start_line = item.range.start.row + 1;
let end_line = item.range.end.row + 1;
if start_line == end_line {
writeln!(output, " [L{}]", start_line).ok();
} else {
writeln!(output, " [L{}-{}]", start_line, end_line).ok();
}
entries_rendered += 1;
}
entries_rendered
}

View File

@@ -1,8 +1,8 @@
use std::sync::Arc;
use crate::schema::json_schema_for;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult, outline};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -14,6 +14,10 @@ use ui::IconName;
use util::markdown::MarkdownInlineCode;
/// If the model requests to read a file whose size exceeds this, then
/// the tool will return the file's symbol outline instead of its contents,
/// and suggest trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
/// If the model requests to list the entries in a directory with more
/// entries than this, then the tool will return a subset of the entries
/// and suggest trying again.
@@ -214,7 +218,7 @@ impl Tool for ContentsTool {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= outline::AUTO_OUTLINE_SIZE {
if file_size <= MAX_FILE_SIZE_TO_READ {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
@@ -225,7 +229,7 @@ impl Tool for ContentsTool {
} else {
// File is too big, so return its outline and a suggestion to
// read again with a line number range specified.
let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
let outline = file_outline(project, file_path, action_log, None, cx).await?;
Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline."))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,408 +0,0 @@
use derive_more::{Add, AddAssign};
use smallvec::SmallVec;
use std::{cmp, mem, ops::Range};
const OLD_TEXT_END_TAG: &str = "</old_text>";
const NEW_TEXT_END_TAG: &str = "</new_text>";
const END_TAG_LEN: usize = OLD_TEXT_END_TAG.len();
const _: () = debug_assert!(OLD_TEXT_END_TAG.len() == NEW_TEXT_END_TAG.len());
#[derive(Debug)]
pub enum EditParserEvent {
OldText(String),
NewTextChunk { chunk: String, done: bool },
}
#[derive(Clone, Debug, Default, PartialEq, Eq, Add, AddAssign)]
pub struct EditParserMetrics {
pub tags: usize,
pub mismatched_tags: usize,
}
#[derive(Debug)]
pub struct EditParser {
state: EditParserState,
buffer: String,
metrics: EditParserMetrics,
}
#[derive(Debug, PartialEq)]
enum EditParserState {
Pending,
WithinOldText,
AfterOldText,
WithinNewText { start: bool },
}
impl EditParser {
pub fn new() -> Self {
EditParser {
state: EditParserState::Pending,
buffer: String::new(),
metrics: EditParserMetrics::default(),
}
}
pub fn push(&mut self, chunk: &str) -> SmallVec<[EditParserEvent; 1]> {
self.buffer.push_str(chunk);
let mut edit_events = SmallVec::new();
loop {
match &mut self.state {
EditParserState::Pending => {
if let Some(start) = self.buffer.find("<old_text>") {
self.buffer.drain(..start + "<old_text>".len());
self.state = EditParserState::WithinOldText;
} else {
break;
}
}
EditParserState::WithinOldText => {
if let Some(tag_range) = self.find_end_tag() {
let mut start = 0;
if self.buffer.starts_with('\n') {
start = 1;
}
let mut old_text = self.buffer[start..tag_range.start].to_string();
if old_text.ends_with('\n') {
old_text.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != OLD_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::AfterOldText;
edit_events.push(EditParserEvent::OldText(old_text));
} else {
break;
}
}
EditParserState::AfterOldText => {
if let Some(start) = self.buffer.find("<new_text>") {
self.buffer.drain(..start + "<new_text>".len());
self.state = EditParserState::WithinNewText { start: true };
} else {
break;
}
}
EditParserState::WithinNewText { start } => {
if !self.buffer.is_empty() {
if *start && self.buffer.starts_with('\n') {
self.buffer.remove(0);
}
*start = false;
}
if let Some(tag_range) = self.find_end_tag() {
let mut chunk = self.buffer[..tag_range.start].to_string();
if chunk.ends_with('\n') {
chunk.pop();
}
self.metrics.tags += 1;
if &self.buffer[tag_range.clone()] != NEW_TEXT_END_TAG {
self.metrics.mismatched_tags += 1;
}
self.buffer.drain(..tag_range.end);
self.state = EditParserState::Pending;
edit_events.push(EditParserEvent::NewTextChunk { chunk, done: true });
} else {
let mut end_prefixes = (1..END_TAG_LEN)
.flat_map(|i| [&NEW_TEXT_END_TAG[..i], &OLD_TEXT_END_TAG[..i]])
.chain(["\n"]);
if end_prefixes.all(|prefix| !self.buffer.ends_with(&prefix)) {
edit_events.push(EditParserEvent::NewTextChunk {
chunk: mem::take(&mut self.buffer),
done: false,
});
}
break;
}
}
}
}
edit_events
}
fn find_end_tag(&self) -> Option<Range<usize>> {
let old_text_end_tag_ix = self.buffer.find(OLD_TEXT_END_TAG);
let new_text_end_tag_ix = self.buffer.find(NEW_TEXT_END_TAG);
let start_ix = if let Some((old_text_ix, new_text_ix)) =
old_text_end_tag_ix.zip(new_text_end_tag_ix)
{
cmp::min(old_text_ix, new_text_ix)
} else {
old_text_end_tag_ix.or(new_text_end_tag_ix)?
};
Some(start_ix..start_ix + END_TAG_LEN)
}
pub fn finish(self) -> EditParserMetrics {
self.metrics
}
}
#[cfg(test)]
mod tests {
use super::*;
use indoc::indoc;
use rand::prelude::*;
use std::cmp;
#[gpui::test(iterations = 1000)]
fn test_single_edit(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>original</old_text><new_text>updated</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "original".to_string(),
new_text: "updated".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_multiple_edits(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
<old_text>
first old
</old_text><new_text>first new</new_text>
<old_text>second old</old_text><new_text>
second new
</new_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "first old".to_string(),
new_text: "first new".to_string(),
},
Edit {
old_text: "second old".to_string(),
new_text: "second new".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_edits_with_extra_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
indoc! {"
ignore this <old_text>
content</old_text>extra stuff<new_text>updated content</new_text>trailing data
more text <old_text>second item
</old_text>middle text<new_text>modified second item</new_text>end
<old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "content".to_string(),
new_text: "updated content".to_string(),
},
Edit {
old_text: "second item".to_string(),
new_text: "modified second item".to_string(),
},
Edit {
old_text: "third case".to_string(),
new_text: "improved third case".to_string(),
},
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 6,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_nested_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "code with <tag>nested</tag> elements".to_string(),
new_text: "new <code>content</code>".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_empty_old_and_new_text(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text></old_text><new_text></new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "".to_string(),
new_text: "".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 100)]
fn test_multiline_content(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
"<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
&mut parser,
&mut rng
),
vec![Edit {
old_text: "line1\nline2\nline3".to_string(),
new_text: "line1\nmodified line2\nline3".to_string(),
}]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 2,
mismatched_tags: 0
}
);
}
#[gpui::test(iterations = 1000)]
fn test_mismatched_tags(mut rng: StdRng) {
let mut parser = EditParser::new();
assert_eq!(
parse_random_chunks(
// Reduced from an actual Sonnet 3.7 output
indoc! {"
<old_text>
a
b
c
</new_text>
<new_text>
a
B
c
</old_text>
<old_text>
d
e
f
</new_text>
<new_text>
D
e
F
</old_text>
"},
&mut parser,
&mut rng
),
vec![
Edit {
old_text: "a\nb\nc".to_string(),
new_text: "a\nB\nc".to_string(),
},
Edit {
old_text: "d\ne\nf".to_string(),
new_text: "D\ne\nF".to_string(),
}
]
);
assert_eq!(
parser.finish(),
EditParserMetrics {
tags: 4,
mismatched_tags: 4
}
);
}
#[derive(Default, Debug, PartialEq, Eq)]
struct Edit {
old_text: String,
new_text: String,
}
fn parse_random_chunks(input: &str, parser: &mut EditParser, rng: &mut StdRng) -> Vec<Edit> {
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
chunk_indices.sort();
chunk_indices.push(input.len());
let mut pending_edit = Edit::default();
let mut edits = Vec::new();
let mut last_ix = 0;
for chunk_ix in chunk_indices {
for event in parser.push(&input[last_ix..chunk_ix]) {
match event {
EditParserEvent::OldText(old_text) => {
pending_edit.old_text = old_text;
}
EditParserEvent::NewTextChunk { chunk, done } => {
pending_edit.new_text.push_str(&chunk);
if done {
edits.push(pending_edit);
pending_edit = Edit::default();
}
}
}
}
last_ix = chunk_ix;
}
edits
}
}

View File

@@ -1,889 +0,0 @@
use super::*;
use crate::{
ReadFileToolInput, grep_tool::GrepToolInput,
streaming_edit_file_tool::StreamingEditFileToolInput,
};
use Role::*;
use anyhow::{Context, anyhow};
use client::{Client, UserStore};
use collections::HashMap;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::{
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
};
use project::Project;
use rand::prelude::*;
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{
cmp::Reverse,
fmt::{self, Display},
io::Write as _,
sync::mpsc,
};
use util::path;
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_extract_handle_command_output() {
let input_file_path = "root/blame.rs";
let input_file_content = include_str!("evals/fixtures/extract_handle_command_output/before.rs");
let output_file_content = include_str!("evals/fixtures/extract_handle_command_output/after.rs");
let edit_description = "Extract `handle_command_output` method from `run_git_blame`.";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Read the `{input_file_path}` file and extract a method in
the final stanza of `run_git_blame` to deal with command failures,
call it `handle_command_output` and take the std::process::Output as the only parameter.
Add it right next to `run_git_blame` and copy it verbatim from `run_git_blame`.
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: None,
end_line: None,
},
)],
),
message(
User,
[tool_result("tool_1", "read_file", input_file_content)],
),
message(
Assistant,
[tool_use(
"tool_2",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_delete_run_git_blame() {
let input_file_path = "root/blame.rs";
let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs");
let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs");
let edit_description = "Delete the `run_git_blame` function.";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Read the `{input_file_path}` file and delete `run_git_blame`. Just that
one function, not its usages.
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: None,
end_line: None,
},
)],
),
message(
User,
[tool_result("tool_1", "read_file", input_file_content)],
),
message(
Assistant,
[tool_use(
"tool_2",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
let input_file_path = "root/lib.rs";
let input_file_content =
include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs");
let edit_description = "Update compile_parser_to_wasm to use wasi-sdk instead of emscripten";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
Use `ureq` to download the SDK for the current platform and architecture.
Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
Compile the parser to wasm using the `bin/clang` executable (or `bin/clang.exe` on windows)
that's inside of the archive.
Don't re-download the SDK if that executable already exists.
Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}
Here are the available wasi-sdk assets:
- wasi-sdk-25.0-x86_64-macos.tar.gz
- wasi-sdk-25.0-arm64-macos.tar.gz
- wasi-sdk-25.0-x86_64-linux.tar.gz
- wasi-sdk-25.0-arm64-linux.tar.gz
- wasi-sdk-25.0-x86_64-linux.tar.gz
- wasi-sdk-25.0-arm64-linux.tar.gz
- wasi-sdk-25.0-x86_64-windows.tar.gz
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: Some(971),
end_line: Some(1050),
},
)],
),
message(
User,
[tool_result(
"tool_1",
"read_file",
lines(input_file_content, 971..1050),
)],
),
message(
Assistant,
[tool_use(
"tool_2",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: Some(1050),
end_line: Some(1100),
},
)],
),
message(
User,
[tool_result(
"tool_2",
"read_file",
lines(input_file_content, 1050..1100),
)],
),
message(
Assistant,
[tool_use(
"tool_3",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: Some(1100),
end_line: Some(1150),
},
)],
),
message(
User,
[tool_result(
"tool_3",
"read_file",
lines(input_file_content, 1100..1150),
)],
),
message(
Assistant,
[tool_use(
"tool_4",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
- The compile_parser_to_wasm method has been changed to use wasi-sdk
- ureq is used to download the SDK for current platform and architecture
"}),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_disable_cursor_blinking() {
let input_file_path = "root/editor.rs";
let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
let edit_description = "Comment out the call to `BlinkManager::enable`";
eval(
100,
0.6, // TODO: make this eval better
EvalInput {
conversation: vec![
message(User, [text("Let's research how to cursor blinking works.")]),
message(
Assistant,
[tool_use(
"tool_1",
"grep",
GrepToolInput {
regex: "blink".into(),
include_pattern: None,
offset: 0,
case_sensitive: false,
},
)],
),
message(
User,
[tool_result(
"tool_1",
"grep",
[
lines(input_file_content, 100..400),
lines(input_file_content, 800..1300),
lines(input_file_content, 1600..2000),
lines(input_file_content, 5000..5500),
lines(input_file_content, 8000..9000),
lines(input_file_content, 18455..18470),
lines(input_file_content, 20000..20500),
lines(input_file_content, 21000..21300),
]
.join("Match found:\n\n"),
)],
),
message(
User,
[text(indoc! {"
Comment out the lines that interact with the BlinkManager.
Keep the outer `update` blocks, but comments everything that's inside (including if statements).
Don't add additional comments.
"})],
),
message(
Assistant,
[tool_use(
"tool_4",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::AssertEqual(output_file_content.into()),
},
);
}
#[test]
#[cfg_attr(not(feature = "eval"), ignore)]
fn eval_from_pixels_constructor() {
let input_file_path = "root/canvas.rs";
let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs");
let edit_description = "Implement from_pixels constructor and add tests.";
eval(
100,
0.95,
EvalInput {
conversation: vec![
message(
User,
[text(indoc! {"
Introduce a new `from_pixels` constructor in Canvas and
also add tests for it in the same file.
"})],
),
message(
Assistant,
[tool_use(
"tool_1",
"read_file",
ReadFileToolInput {
path: input_file_path.into(),
start_line: None,
end_line: None,
},
)],
),
message(
User,
[tool_result("tool_1", "read_file", input_file_content)],
),
message(
Assistant,
[tool_use(
"tool_2",
"grep",
GrepToolInput {
regex: "mod\\s+tests".into(),
include_pattern: Some("font-kit/src/canvas.rs".into()),
offset: 0,
case_sensitive: false,
},
)],
),
message(User, [tool_result("tool_2", "grep", "No matches found")]),
message(
Assistant,
[tool_use(
"tool_3",
"grep",
GrepToolInput {
regex: "mod\\s+tests".into(),
include_pattern: Some("font-kit/src/**/*.rs".into()),
offset: 0,
case_sensitive: false,
},
)],
),
message(User, [tool_result("tool_3", "grep", "No matches found")]),
message(
Assistant,
[tool_use(
"tool_4",
"grep",
GrepToolInput {
regex: "#\\[test\\]".into(),
include_pattern: Some("font-kit/src/**/*.rs".into()),
offset: 0,
case_sensitive: false,
},
)],
),
message(
User,
[tool_result(
"tool_4",
"grep",
indoc! {"
Found 6 matches:
## Matches in font-kit/src/loaders/core_text.rs
### mod test L926-936
```
mod test {
use super::Font;
use crate::properties::{Stretch, Weight};
#[cfg(feature = \"source\")]
use crate::source::SystemSource;
static TEST_FONT_POSTSCRIPT_NAME: &'static str = \"ArialMT\";
#[cfg(feature = \"source\")]
#[test]
```
55 lines remaining in ancestor node. Read the file to see all.
### mod test L947-951
```
}
#[test]
fn test_core_text_to_css_font_weight() {
// Exact matches
```
### mod test L959-963
```
}
#[test]
fn test_core_text_to_css_font_stretch() {
// Exact matches
```
## Matches in font-kit/src/loaders/freetype.rs
### mod test L1238-1248
```
mod test {
use crate::loaders::freetype::Font;
static PCF_FONT_PATH: &str = \"resources/tests/times-roman-pcf/timR12.pcf\";
static PCF_FONT_POSTSCRIPT_NAME: &str = \"Times-Roman\";
#[test]
fn get_pcf_postscript_name() {
let font = Font::from_path(PCF_FONT_PATH, 0).unwrap();
assert_eq!(font.postscript_name().unwrap(), PCF_FONT_POSTSCRIPT_NAME);
}
```
1 lines remaining in ancestor node. Read the file to see all.
## Matches in font-kit/src/sources/core_text.rs
### mod test L265-275
```
mod test {
use crate::properties::{Stretch, Weight};
#[test]
fn test_css_to_core_text_font_weight() {
// Exact matches
assert_eq!(super::css_to_core_text_font_weight(Weight(100.0)), -0.7);
assert_eq!(super::css_to_core_text_font_weight(Weight(400.0)), 0.0);
assert_eq!(super::css_to_core_text_font_weight(Weight(700.0)), 0.4);
assert_eq!(super::css_to_core_text_font_weight(Weight(900.0)), 0.8);
```
27 lines remaining in ancestor node. Read the file to see all.
### mod test L278-282
```
}
#[test]
fn test_css_to_core_text_font_stretch() {
// Exact matches
```
"},
)],
),
message(
Assistant,
[tool_use(
"tool_5",
"edit_file",
StreamingEditFileToolInput {
display_description: edit_description.into(),
path: input_file_path.into(),
},
)],
),
],
input_path: input_file_path.into(),
input_content: input_file_content.into(),
edit_description: edit_description.into(),
assertion: EvalAssertion::JudgeDiff(indoc! {"
- The diff contains a new `from_pixels` constructor
- The diff contains new tests for the `from_pixels` constructor
"}),
},
);
}
fn message(
role: Role,
contents: impl IntoIterator<Item = MessageContent>,
) -> LanguageModelRequestMessage {
LanguageModelRequestMessage {
role,
content: contents.into_iter().collect(),
cache: false,
}
}
fn text(text: impl Into<String>) -> MessageContent {
MessageContent::Text(text.into())
}
fn lines(input: &str, range: Range<usize>) -> String {
input
.lines()
.skip(range.start)
.take(range.len())
.collect::<Vec<_>>()
.join("\n")
}
fn tool_use(
id: impl Into<Arc<str>>,
name: impl Into<Arc<str>>,
input: impl Serialize,
) -> MessageContent {
MessageContent::ToolUse(LanguageModelToolUse {
id: LanguageModelToolUseId::from(id.into()),
name: name.into(),
raw_input: serde_json::to_string_pretty(&input).unwrap(),
input: serde_json::to_value(input).unwrap(),
is_input_complete: true,
})
}
fn tool_result(
id: impl Into<Arc<str>>,
name: impl Into<Arc<str>>,
result: impl Into<Arc<str>>,
) -> MessageContent {
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: LanguageModelToolUseId::from(id.into()),
tool_name: name.into(),
is_error: false,
content: result.into(),
})
}
#[derive(Clone)]
struct EvalInput {
conversation: Vec<LanguageModelRequestMessage>,
input_path: PathBuf,
input_content: String,
edit_description: String,
assertion: EvalAssertion,
}
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
let mut evaluated_count = 0;
report_progress(evaluated_count, iterations);
let (tx, rx) = mpsc::channel();
// Cache the last message in the conversation, and run one instance of the eval so that
// all the next ones are cached.
eval.conversation.last_mut().unwrap().cache = true;
run_eval(eval.clone(), tx.clone());
let executor = gpui::background_executor();
for _ in 1..iterations {
let eval = eval.clone();
let tx = tx.clone();
executor.spawn(async move { run_eval(eval, tx) }).detach();
}
drop(tx);
let mut failed_count = 0;
let mut failed_evals = HashMap::default();
let mut errored_evals = HashMap::default();
let mut eval_outputs = Vec::new();
let mut cumulative_parser_metrics = EditParserMetrics::default();
while let Ok(output) = rx.recv() {
match output {
Ok(output) => {
cumulative_parser_metrics += output.edit_output._parser_metrics.clone();
eval_outputs.push(output.clone());
if output.assertion.score < 80 {
failed_count += 1;
failed_evals
.entry(output.buffer_text.clone())
.or_insert(Vec::new())
.push(output);
}
}
Err(error) => {
failed_count += 1;
*errored_evals.entry(format!("{:?}", error)).or_insert(0) += 1;
}
}
evaluated_count += 1;
report_progress(evaluated_count, iterations);
}
let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
println!("Actual pass ratio: {}\n", actual_pass_ratio);
if actual_pass_ratio < expected_pass_ratio {
let mut errored_evals = errored_evals.into_iter().collect::<Vec<_>>();
errored_evals.sort_by_key(|(_, count)| Reverse(*count));
for (error, count) in errored_evals {
println!("Eval errored {} times. Error: {}", count, error);
}
let mut failed_evals = failed_evals.into_iter().collect::<Vec<_>>();
failed_evals.sort_by_key(|(_, evals)| Reverse(evals.len()));
for (_buffer_output, failed_evals) in failed_evals {
let eval_output = failed_evals.first().unwrap();
println!("Eval failed {} times", failed_evals.len());
println!("{}", eval_output);
}
panic!(
"Actual pass ratio: {}\nExpected pass ratio: {}",
actual_pass_ratio, expected_pass_ratio
);
}
let mismatched_tag_ratio =
cumulative_parser_metrics.mismatched_tags as f32 / cumulative_parser_metrics.tags as f32;
if mismatched_tag_ratio > 0.02 {
for eval_output in eval_outputs {
println!("{}", eval_output);
}
panic!("Too many mismatched tags: {:?}", cumulative_parser_metrics);
}
}
fn run_eval(eval: EvalInput, tx: mpsc::Sender<Result<EvalOutput>>) {
let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
let mut cx = TestAppContext::build(dispatcher, None);
let output = cx.executor().block_test(async {
let test = EditAgentTest::new(&mut cx).await;
test.eval(eval, &mut cx).await
});
tx.send(output).unwrap();
}
#[derive(Clone)]
struct EvalOutput {
assertion: EvalAssertionResult,
buffer_text: String,
edit_output: EditAgentOutput,
diff: String,
}
impl Display for EvalOutput {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "Score: {:?}", self.assertion.score)?;
if let Some(message) = self.assertion.message.as_ref() {
writeln!(f, "Message: {}", message)?;
}
writeln!(f, "Diff:\n{}", self.diff)?;
writeln!(
f,
"Parser Metrics:\n{:#?}",
self.edit_output._parser_metrics
)?;
writeln!(f, "Raw Edits:\n{}", self.edit_output._raw_edits)?;
Ok(())
}
}
fn report_progress(evaluated_count: usize, iterations: usize) {
print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
std::io::stdout().flush().unwrap();
}
struct EditAgentTest {
agent: EditAgent,
project: Entity<Project>,
judge_model: Arc<dyn LanguageModel>,
}
impl EditAgentTest {
async fn new(cx: &mut TestAppContext) -> Self {
cx.executor().allow_parking();
cx.update(settings::init);
cx.update(Project::init_settings);
cx.update(language::init);
cx.update(gpui_tokio::init);
cx.update(client::init_settings);
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let (agent_model, judge_model) = cx
.update(|cx| {
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
cx.spawn(async move |cx| {
let agent_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
let judge_model =
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
(agent_model.unwrap(), judge_model.unwrap())
})
})
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
Self {
agent: EditAgent::new(agent_model, action_log, Templates::new()),
project,
judge_model,
}
}
async fn load_model(
provider: &str,
id: &str,
cx: &mut AsyncApp,
) -> Result<Arc<dyn LanguageModel>> {
let (provider, model) = cx.update(|cx| {
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
.find(|model| model.provider_id().0 == provider && model.id().0 == id)
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
(provider, model)
})?;
cx.update(|cx| provider.authenticate(cx))?.await?;
Ok(model)
}
async fn eval(&self, eval: EvalInput, cx: &mut TestAppContext) -> Result<EvalOutput> {
let path = self
.project
.read_with(cx, |project, cx| {
project.find_project_path(eval.input_path, cx)
})
.unwrap();
let buffer = self
.project
.update(cx, |project, cx| project.open_buffer(path, cx))
.await
.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text(eval.input_content.clone(), cx)
});
let (edit_output, _events) = self.agent.edit(
buffer.clone(),
eval.edit_description,
eval.conversation,
&mut cx.to_async(),
);
let edit_output = edit_output.await?;
let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
let actual_diff = language::unified_diff(&eval.input_content, &buffer_text);
let assertion = match eval.assertion {
EvalAssertion::AssertEqual(expected_output) => EvalAssertionResult {
score: if strip_empty_lines(&buffer_text) == strip_empty_lines(&expected_output) {
100
} else {
0
},
message: None,
},
EvalAssertion::JudgeDiff(assertions) => self
.judge_diff(&actual_diff, assertions, &cx.to_async())
.await
.context("failed comparing diffs")?,
};
Ok(EvalOutput {
assertion,
diff: actual_diff,
buffer_text,
edit_output,
})
}
async fn judge_diff(
&self,
diff: &str,
assertions: &'static str,
cx: &AsyncApp,
) -> Result<EvalAssertionResult> {
let prompt = DiffJudgeTemplate {
diff: diff.to_string(),
assertions,
}
.render(&self.agent.templates)
.unwrap();
let request = LanguageModelRequest {
messages: vec![LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
cache: false,
}],
..Default::default()
};
let mut response = self.judge_model.stream_completion_text(request, cx).await?;
let mut output = String::new();
while let Some(chunk) = response.stream.next().await {
let chunk = chunk?;
output.push_str(&chunk);
}
// Parse the score from the response
let re = regex::Regex::new(r"<score>(\d+)</score>").unwrap();
if let Some(captures) = re.captures(&output) {
if let Some(score_match) = captures.get(1) {
let score = score_match.as_str().parse().unwrap_or(0);
return Ok(EvalAssertionResult {
score,
message: Some(output),
});
}
}
Err(anyhow!(
"No score found in response. Raw output: {}",
output
))
}
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum EvalAssertion {
AssertEqual(String),
JudgeDiff(&'static str),
}
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct EvalAssertionResult {
score: usize,
message: Option<String>,
}
#[derive(Serialize)]
pub struct DiffJudgeTemplate {
diff: String,
assertions: &'static str,
}
impl Template for DiffJudgeTemplate {
const TEMPLATE_NAME: &'static str = "diff_judge.hbs";
}
fn strip_empty_lines(text: &str) -> String {
text.lines()
.filter(|line| !line.trim().is_empty())
.collect::<Vec<_>>()
.join("\n")
}

View File

@@ -1,328 +0,0 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -1,374 +0,0 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -1,378 +0,0 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
handle_command_output(output)
}
fn handle_command_output(output: std::process::Output) -> Result<String> {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -1,374 +0,0 @@
use crate::commit::get_messages;
use crate::{GitRemote, Oid};
use anyhow::{Context as _, Result, anyhow};
use collections::{HashMap, HashSet};
use futures::AsyncWriteExt;
use gpui::SharedString;
use serde::{Deserialize, Serialize};
use std::process::Stdio;
use std::{ops::Range, path::Path};
use text::Rope;
use time::OffsetDateTime;
use time::UtcOffset;
use time::macros::format_description;
pub use git2 as libgit;
#[derive(Debug, Clone, Default)]
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
pub struct ParsedCommitMessage {
pub message: SharedString,
pub permalink: Option<url::Url>,
pub pull_request: Option<crate::hosting_provider::PullRequest>,
pub remote: Option<GitRemote>,
}
impl Blame {
pub async fn for_path(
git_binary: &Path,
working_directory: &Path,
path: &Path,
content: &Rope,
remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
let mut unique_shas = HashSet::default();
for entry in entries.iter_mut() {
unique_shas.insert(entry.sha);
}
let shas = unique_shas.into_iter().collect::<Vec<_>>();
let messages = get_messages(working_directory, &shas)
.await
.context("failed to get commit messages")?;
Ok(Self {
entries,
messages,
remote_url,
})
}
}
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
async fn run_git_blame(
git_binary: &Path,
working_directory: &Path,
path: &Path,
contents: &Rope,
) -> Result<String> {
let mut child = util::command::new_smol_command(git_binary)
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
let stdin = child
.stdin
.as_mut()
.context("failed to get pipe to stdin of git blame command")?;
for chunk in contents.chunks() {
stdin.write_all(chunk.as_bytes()).await?;
}
stdin.flush().await?;
let output = child
.output()
.await
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
let trimmed = stderr.trim();
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
return Ok(String::new());
}
return Err(anyhow!("git blame process failed: {}", stderr));
}
Ok(String::from_utf8(output.stdout)?)
}
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
pub struct BlameEntry {
pub sha: Oid,
pub range: Range<u32>,
pub original_line_number: u32,
pub author: Option<String>,
pub author_mail: Option<String>,
pub author_time: Option<i64>,
pub author_tz: Option<String>,
pub committer_name: Option<String>,
pub committer_email: Option<String>,
pub committer_time: Option<i64>,
pub committer_tz: Option<String>,
pub summary: Option<String>,
pub previous: Option<String>,
pub filename: String,
}
impl BlameEntry {
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
// entry. The line MUST have this format:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
let mut parts = line.split_whitespace();
let sha = parts
.next()
.and_then(|line| line.parse::<Oid>().ok())
.ok_or_else(|| anyhow!("failed to parse sha"))?;
let original_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
let final_line_number = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let line_count = parts
.next()
.and_then(|line| line.parse::<u32>().ok())
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
let start_line = final_line_number.saturating_sub(1);
let end_line = start_line + line_count;
let range = start_line..end_line;
Ok(Self {
sha,
range,
original_line_number,
..Default::default()
})
}
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
let format = format_description!("[offset_hour][offset_minute]");
let offset = UtcOffset::parse(author_tz, &format)?;
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
Ok(date_time_utc.to_offset(offset))
} else {
// Directly return current time in UTC if there's no committer time or timezone
Ok(time::OffsetDateTime::now_utc())
}
}
}
// parse_git_blame parses the output of `git blame --incremental`, which returns
// all the blame-entries for a given path incrementally, as it finds them.
//
// Each entry *always* starts with:
//
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
//
// Each entry *always* ends with:
//
// filename <whitespace-quoted-filename-goes-here>
//
// Line numbers are 1-indexed.
//
// A `git blame --incremental` entry looks like this:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
// author Joe Schmoe
// author-mail <joe.schmoe@example.com>
// author-time 1709741400
// author-tz +0100
// committer Joe Schmoe
// committer-mail <joe.schmoe@example.com>
// committer-time 1709741400
// committer-tz +0100
// summary Joe's cool commit
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// If the entry has the same SHA as an entry that was already printed then no
// signature information is printed:
//
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
// previous 486c2409237a2c627230589e567024a96751d475 index.js
// filename index.js
//
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
let mut entries: Vec<BlameEntry> = Vec::new();
let mut index: HashMap<Oid, usize> = HashMap::default();
let mut current_entry: Option<BlameEntry> = None;
for line in output.lines() {
let mut done = false;
match &mut current_entry {
None => {
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
if let Some(existing_entry) = index
.get(&new_entry.sha)
.and_then(|slot| entries.get(*slot))
{
new_entry.author.clone_from(&existing_entry.author);
new_entry
.author_mail
.clone_from(&existing_entry.author_mail);
new_entry.author_time = existing_entry.author_time;
new_entry.author_tz.clone_from(&existing_entry.author_tz);
new_entry
.committer_name
.clone_from(&existing_entry.committer_name);
new_entry
.committer_email
.clone_from(&existing_entry.committer_email);
new_entry.committer_time = existing_entry.committer_time;
new_entry
.committer_tz
.clone_from(&existing_entry.committer_tz);
new_entry.summary.clone_from(&existing_entry.summary);
}
current_entry.replace(new_entry);
}
Some(entry) => {
let Some((key, value)) = line.split_once(' ') else {
continue;
};
let is_committed = !entry.sha.is_zero();
match key {
"filename" => {
entry.filename = value.into();
done = true;
}
"previous" => entry.previous = Some(value.into()),
"summary" if is_committed => entry.summary = Some(value.into()),
"author" if is_committed => entry.author = Some(value.into()),
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
"author-time" if is_committed => {
entry.author_time = Some(value.parse::<i64>()?)
}
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
"committer" if is_committed => entry.committer_name = Some(value.into()),
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
"committer-time" if is_committed => {
entry.committer_time = Some(value.parse::<i64>()?)
}
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
_ => {}
}
}
};
if done {
if let Some(entry) = current_entry.take() {
index.insert(entry.sha, entries.len());
// We only want annotations that have a commit.
if !entry.sha.is_zero() {
entries.push(entry);
}
}
}
}
Ok(entries)
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use super::BlameEntry;
use super::parse_git_blame;
fn read_test_data(filename: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push(filename);
std::fs::read_to_string(&path)
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
}
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("test_data");
path.push("golden");
path.push(format!("{}.json", golden_filename));
let mut have_json =
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
// We always want to save with a trailing newline.
have_json.push('\n');
let update = std::env::var("UPDATE_GOLDEN")
.map(|val| val.eq_ignore_ascii_case("true"))
.unwrap_or(false);
if update {
std::fs::create_dir_all(path.parent().unwrap())
.expect("could not create golden test data directory");
std::fs::write(&path, have_json).expect("could not write out golden data");
} else {
let want_json =
std::fs::read_to_string(&path).unwrap_or_else(|_| {
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
}).replace("\r\n", "\n");
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
}
}
#[test]
fn test_parse_git_blame_not_committed() {
let output = read_test_data("blame_incremental_not_committed");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_not_committed");
}
#[test]
fn test_parse_git_blame_simple() {
let output = read_test_data("blame_incremental_simple");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_simple");
}
#[test]
fn test_parse_git_blame_complex() {
let output = read_test_data("blame_incremental_complex");
let entries = parse_git_blame(&output).unwrap();
assert_eq_golden(&entries, "blame_incremental_complex");
}
}

View File

@@ -1,339 +0,0 @@
// font-kit/src/canvas.rs
//
// Copyright © 2018 The Pathfinder Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! An in-memory bitmap surface for glyph rasterization.
use lazy_static::lazy_static;
use pathfinder_geometry::rect::RectI;
use pathfinder_geometry::vector::Vector2I;
use std::cmp;
use std::fmt;
use crate::utils;
lazy_static! {
static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
let mut lut = [[0; 8]; 256];
for byte in 0..0x100 {
let mut value = [0; 8];
for bit in 0..8 {
if (byte & (0x80 >> bit)) != 0 {
value[bit] = 0xff;
}
}
lut[byte] = value
}
lut
};
}
/// An in-memory bitmap surface for glyph rasterization.
pub struct Canvas {
/// The raw pixel data.
pub pixels: Vec<u8>,
/// The size of the buffer, in pixels.
pub size: Vector2I,
/// The number of *bytes* between successive rows.
pub stride: usize,
/// The image format of the canvas.
pub format: Format,
}
impl Canvas {
/// Creates a new blank canvas with the given pixel size and format.
///
/// Stride is automatically calculated from width.
///
/// The canvas is initialized with transparent black (all values 0).
#[inline]
pub fn new(size: Vector2I, format: Format) -> Canvas {
Canvas::with_stride(
size,
size.x() as usize * format.bytes_per_pixel() as usize,
format,
)
}
/// Creates a new blank canvas with the given pixel size, stride (number of bytes between
/// successive rows), and format.
///
/// The canvas is initialized with transparent black (all values 0).
pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
Canvas {
pixels: vec![0; stride * size.y() as usize],
size,
stride,
format,
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
self.blit_from(
Vector2I::default(),
&src.pixels,
src.size,
src.stride,
src.format,
)
}
/// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
/// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
/// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
/// `src_stride` must be equal or larger than the actual data length.
#[allow(dead_code)]
pub(crate) fn blit_from(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
src_format: Format,
) {
assert_eq!(
src_stride * src_size.y() as usize,
src_bytes.len(),
"Number of pixels in src_bytes does not match stride and size."
);
assert!(
src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
"src_stride must be >= than src_size.x()"
);
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
match (self.format, src_format) {
(Format::A8, Format::A8)
| (Format::Rgb24, Format::Rgb24)
| (Format::Rgba32, Format::Rgba32) => {
self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::A8, Format::Rgb24) => {
self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::A8) => {
self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
}
(Format::Rgb24, Format::Rgba32) => self
.blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::Rgb24) => self
.blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
(Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
}
}
#[allow(dead_code)]
pub(crate) fn blit_from_bitmap_1bpp(
&mut self,
dst_point: Vector2I,
src_bytes: &[u8],
src_size: Vector2I,
src_stride: usize,
) {
if self.format != Format::A8 {
unimplemented!()
}
let dst_rect = RectI::new(dst_point, src_size);
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
let dst_rect = match dst_rect {
Some(dst_rect) => dst_rect,
None => return,
};
let size = dst_rect.size();
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
let src_row_stride = utils::div_round_up(size.x() as usize, 8);
for y in 0..size.y() {
let (dest_row_start, src_row_start) = (
(y + dst_rect.origin_y()) as usize * self.stride
+ dst_rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + dest_row_stride;
let src_row_end = src_row_start + src_row_stride;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
for x in 0..src_row_stride {
let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
let dest_start = x * 8;
let dest_end = cmp::min(dest_start + 8, dest_row_stride);
let src = &pattern[0..(dest_end - dest_start)];
dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
}
}
}
/// Blits to area `rect` using the data given in the buffer `src_bytes`.
/// `src_stride` must be specified in bytes.
/// The dimensions of `rect` must be in pixels.
fn blit_from_with<B: Blit>(
&mut self,
rect: RectI,
src_bytes: &[u8],
src_stride: usize,
src_format: Format,
) {
let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
for y in 0..rect.height() {
let (dest_row_start, src_row_start) = (
(y + rect.origin_y()) as usize * self.stride
+ rect.origin_x() as usize * dest_bytes_per_pixel,
y as usize * src_stride,
);
let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
B::blit(dest_row_pixels, src_row_pixels)
}
}
}
impl fmt::Debug for Canvas {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Canvas")
.field("pixels", &self.pixels.len()) // Do not dump a vector content.
.field("size", &self.size)
.field("stride", &self.stride)
.field("format", &self.format)
.finish()
}
}
/// The image format for the canvas.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Premultiplied R8G8B8A8, little-endian.
Rgba32,
/// R8G8B8, little-endian.
Rgb24,
/// A8.
A8,
}
impl Format {
/// Returns the number of bits per pixel that this image format corresponds to.
#[inline]
pub fn bits_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 32,
Format::Rgb24 => 24,
Format::A8 => 8,
}
}
/// Returns the number of color channels per pixel that this image format corresponds to.
#[inline]
pub fn components_per_pixel(self) -> u8 {
match self {
Format::Rgba32 => 4,
Format::Rgb24 => 3,
Format::A8 => 1,
}
}
/// Returns the number of bits per color channel that this image format contains.
#[inline]
pub fn bits_per_component(self) -> u8 {
self.bits_per_pixel() / self.components_per_pixel()
}
/// Returns the number of bytes per pixel that this image format corresponds to.
#[inline]
pub fn bytes_per_pixel(self) -> u8 {
self.bits_per_pixel() / 8
}
}
/// The antialiasing strategy that should be used when rasterizing glyphs.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RasterizationOptions {
/// "Black-and-white" rendering. Each pixel is either entirely on or off.
Bilevel,
/// Grayscale antialiasing. Only one channel is used.
GrayscaleAa,
/// Subpixel RGB antialiasing, for LCD screens.
SubpixelAa,
}
trait Blit {
fn blit(dest: &mut [u8], src: &[u8]);
}
struct BlitMemcpy;
impl Blit for BlitMemcpy {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
dest.clone_from_slice(src)
}
}
struct BlitRgb24ToA8;
impl Blit for BlitRgb24ToA8 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.iter_mut().zip(src.chunks(3)) {
*dest = src[1]
}
}
}
struct BlitA8ToRgb24;
impl Blit for BlitA8ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(3).zip(src.iter()) {
dest[0] = *src;
dest[1] = *src;
dest[2] = *src;
}
}
}
struct BlitRgba32ToRgb24;
impl Blit for BlitRgba32ToRgb24 {
#[inline]
fn blit(dest: &mut [u8], src: &[u8]) {
// TODO(pcwalton): SIMD.
for (dest, src) in dest.chunks_mut(3).zip(src.chunks(4)) {
dest.copy_from_slice(&src[0..3])
}
}
}
struct BlitRgb24ToRgba32;
impl Blit for BlitRgb24ToRgba32 {
fn blit(dest: &mut [u8], src: &[u8]) {
for (dest, src) in dest.chunks_mut(4).zip(src.chunks(3)) {
dest[0] = src[0];
dest[1] = src[1];
dest[2] = src[2];
dest[3] = 255;
}
}
}

View File

@@ -11,7 +11,6 @@ use gpui::{
};
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
language_settings::SoftWrap,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -25,8 +24,6 @@ use ui::{Disclosure, Tooltip, Window, prelude::*};
use util::ResultExt;
use workspace::Workspace;
pub struct EditFileTool;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
@@ -78,6 +75,8 @@ struct PartialInput {
new_string: String,
}
pub struct EditFileTool;
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool {
@@ -275,14 +274,12 @@ pub struct EditFileToolCard {
project: Entity<Project>,
diff_task: Option<Task<Result<()>>>,
preview_expanded: bool,
error_expanded: bool,
full_height_expanded: bool,
total_lines: Option<u32>,
editor_unique_id: EntityId,
}
impl EditFileToolCard {
pub fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor = Editor::new(
@@ -296,13 +293,11 @@ impl EditFileToolCard {
window,
cx,
);
editor.set_show_scrollbars(false, cx);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_scrolling(cx);
editor.disable_expand_excerpt_buttons(cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_scrollbars(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
@@ -317,13 +312,11 @@ impl EditFileToolCard {
multibuffer,
diff_task: None,
preview_expanded: true,
error_expanded: false,
full_height_expanded: false,
total_lines: None,
}
}
pub fn set_diff(
fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
@@ -336,14 +329,13 @@ impl EditFileToolCard {
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
this.update(cx, |this, cx| {
this.total_lines = this.multibuffer.update(cx, |multibuffer, cx| {
this.multibuffer.update(cx, |multibuffer, cx| {
let snapshot = buffer.read(cx).snapshot();
let diff = buffer_diff.read(cx);
let diff_hunk_ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
multibuffer.clear(cx);
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
@@ -353,10 +345,7 @@ impl EditFileToolCard {
);
debug_assert!(is_newly_added);
multibuffer.add_diff(buffer_diff, cx);
let end = multibuffer.len(cx);
Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
});
cx.notify();
})
}));
@@ -371,10 +360,7 @@ impl ToolCard for EditFileToolCard {
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let (failed, error_message) = match status {
ToolUseStatus::Error(err) => (true, Some(err.to_string())),
_ => (false, None),
};
let failed = matches!(status, ToolUseStatus::Error(_));
let path_label_button = h_flex()
.id(("edit-tool-path-label-button", self.editor_unique_id))
@@ -466,26 +452,9 @@ impl ToolCard for EditFileToolCard {
.map(|container| {
if failed {
container.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Close)
.size(IconSize::Small)
.color(Color::Error),
)
.child(
Disclosure::new(
("edit-file-error-disclosure", self.editor_unique_id),
self.error_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.error_expanded = !this.error_expanded;
},
)),
),
Icon::new(IconName::Close)
.size(IconSize::Small)
.color(Color::Error),
)
} else {
container.child(
@@ -504,14 +473,8 @@ impl ToolCard for EditFileToolCard {
}
});
let (editor, editor_line_height) = self.editor.update(cx, |editor, cx| {
let line_height = editor
.style()
.map(|style| style.text.line_height_in_pixels(window.rem_size()))
.unwrap_or_default();
let element = editor.render(window, cx);
(element.into_any_element(), line_height)
let editor = self.editor.update(cx, |editor, cx| {
editor.render(window, cx).into_any_element()
});
let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
@@ -535,9 +498,6 @@ impl ToolCard for EditFileToolCard {
let border_color = cx.theme().colors().border.opacity(0.6);
const DEFAULT_COLLAPSED_LINES: u32 = 10;
let is_collapsible = self.total_lines.unwrap_or(0) > DEFAULT_COLLAPSED_LINES;
v_flex()
.mb_2()
.border_1()
@@ -546,79 +506,50 @@ impl ToolCard for EditFileToolCard {
.rounded_lg()
.overflow_hidden()
.child(codeblock_header)
.when(failed && self.error_expanded, |card| {
card.child(
v_flex()
.p_2()
.gap_1()
.border_t_1()
.border_dashed()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.rounded_b_md()
.child(
Label::new("Error")
.size(LabelSize::XSmall)
.color(Color::Error),
)
.child(
div()
.rounded_md()
.text_ui_sm(cx)
.bg(cx.theme().colors().editor_background)
.children(
error_message
.map(|error| div().child(error).into_any_element()),
),
),
)
})
.when(!failed && self.preview_expanded, |card| {
card.child(
v_flex()
.relative()
.map(|editor_container| {
if self.full_height_expanded {
editor_container.h_full()
} else {
editor_container
.h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
}
})
.overflow_hidden()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.map(|editor_container| {
if self.full_height_expanded {
editor_container.h_full()
} else {
editor_container.max_h_64()
}
})
.child(div().pl_1().child(editor))
.when(
!self.full_height_expanded && is_collapsible,
|editor_container| editor_container.child(gradient_overlay),
),
.when(!self.full_height_expanded, |editor_container| {
editor_container.child(gradient_overlay)
}),
)
})
.when(!failed && self.preview_expanded, |card| {
card.child(
h_flex()
.id(("edit-tool-card-inner-hflex", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
.rounded_b_md()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
)
.when(is_collapsible, |editor_container| {
editor_container.child(
h_flex()
.id(("expand-button", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
.rounded_b_md()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
)
})
})
}
}
@@ -696,6 +627,7 @@ mod tests {
#[test]
fn still_streaming_ui_text_with_path() {
let tool = EditFileTool;
let input = json!({
"path": "src/main.rs",
"display_description": "",
@@ -703,11 +635,12 @@ mod tests {
"new_string": "new code"
});
assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let tool = EditFileTool;
let input = json!({
"path": "",
"display_description": "Fix error handling",
@@ -715,14 +648,12 @@ mod tests {
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let tool = EditFileTool;
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
@@ -730,14 +661,12 @@ mod tests {
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let tool = EditFileTool;
let input = json!({
"path": "",
"display_description": "",
@@ -745,19 +674,14 @@ mod tests {
"new_string": "new code"
});
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
#[test]
fn still_streaming_ui_text_with_null() {
let tool = EditFileTool;
let input = serde_json::Value::Null;
assert_eq!(
EditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
}

View File

@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, Entity, Task};
use language::{OffsetRangeExt, ParseStatus, Point};
use language::OffsetRangeExt;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
Project,
@@ -13,7 +13,6 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, fmt::Write, sync::Arc};
use ui::IconName;
use util::RangeExt;
use util::markdown::MarkdownInlineCode;
use util::paths::PathMatcher;
@@ -103,7 +102,6 @@ impl Tool for GrepTool {
cx: &mut App,
) -> ToolResult {
const CONTEXT_LINES: u32 = 2;
const MAX_ANCESTOR_LINES: u32 = 10;
let input = match serde_json::from_value::<GrepToolInput>(input) {
Ok(input) => input,
@@ -142,7 +140,7 @@ impl Tool for GrepTool {
let results = project.update(cx, |project, cx| project.search(query, cx));
cx.spawn(async move |cx| {
cx.spawn(async move|cx| {
futures::pin_mut!(results);
let mut output = String::new();
@@ -150,113 +148,68 @@ impl Tool for GrepTool {
let mut matches_found = 0;
let mut has_more_matches = false;
'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
if ranges.is_empty() {
continue;
}
let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
(buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
})? else {
continue;
};
buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
let mut file_header_written = false;
let mut ranges = ranges
.into_iter()
.map(|range| {
let mut point_range = range.to_point(buffer);
point_range.start.row =
point_range.start.row.saturating_sub(CONTEXT_LINES);
point_range.start.column = 0;
point_range.end.row = cmp::min(
buffer.max_point().row,
point_range.end.row + CONTEXT_LINES,
);
point_range.end.column = buffer.line_len(point_range.end.row);
point_range
})
.peekable();
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
}
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let mut ranges = ranges
.into_iter()
.map(|range| {
let matched = range.to_point(&snapshot);
let matched_end_line_len = snapshot.line_len(matched.end.row);
let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
let symbols = snapshot.symbols_containing(matched.start, None);
if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
let end_col = snapshot.line_len(end_row);
let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
if capped_ancestor_range.contains_inclusive(&full_lines) {
return (capped_ancestor_range, Some(full_ancestor_range), symbols)
while let Some(mut range) = ranges.next() {
if skips_remaining > 0 {
skips_remaining -= 1;
continue;
}
}
let mut matched = matched;
matched.start.column = 0;
matched.start.row =
matched.start.row.saturating_sub(CONTEXT_LINES);
matched.end.row = cmp::min(
snapshot.max_point().row,
matched.end.row + CONTEXT_LINES,
);
matched.end.column = snapshot.line_len(matched.end.row);
// We'd already found a full page of matches, and we just found one more.
if matches_found >= RESULTS_PER_PAGE {
has_more_matches = true;
return Ok(());
}
(matched, None, symbols)
})
.peekable();
while let Some(next_range) = ranges.peek() {
if range.end.row >= next_range.start.row {
range.end = next_range.end;
ranges.next();
} else {
break;
}
}
let mut file_header_written = false;
if !file_header_written {
writeln!(output, "\n## Matches in {}", path.display())?;
file_header_written = true;
}
while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
if skips_remaining > 0 {
skips_remaining -= 1;
continue;
}
let start_line = range.start.row + 1;
let end_line = range.end.row + 1;
writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
output.extend(buffer.text_for_range(range));
output.push_str("\n```\n");
// We'd already found a full page of matches, and we just found one more.
if matches_found >= RESULTS_PER_PAGE {
has_more_matches = true;
break 'outer;
}
while let Some((next_range, _, _)) = ranges.peek() {
if range.end.row >= next_range.start.row {
range.end = next_range.end;
ranges.next();
} else {
break;
matches_found += 1;
}
}
if !file_header_written {
writeln!(output, "\n## Matches in {}", path.display())?;
file_header_written = true;
}
let end_row = range.end.row;
output.push_str("\n### ");
if let Some(parent_symbols) = &parent_symbols {
for symbol in parent_symbols {
write!(output, "{} ", symbol.text)?;
}
}
if range.start.row == end_row {
writeln!(output, "L{}", range.start.row + 1)?;
} else {
writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
}
output.push_str("```\n");
output.extend(snapshot.text_for_range(range));
output.push_str("\n```\n");
if let Some(ancestor_range) = ancestor_range {
if end_row < ancestor_range.end.row {
let remaining_lines = ancestor_range.end.row - end_row;
writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
}
}
matches_found += 1;
}
Ok(())
})??;
}
if matches_found == 0 {
@@ -280,16 +233,13 @@ mod tests {
use super::*;
use assistant_tool::Tool;
use gpui::{AppContext, TestAppContext};
use language::{Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use settings::SettingsStore;
use unindent::Unindent;
use util::path;
#[gpui::test]
async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
@@ -377,7 +327,6 @@ mod tests {
#[gpui::test]
async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
init_test(cx);
cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor().clone());
fs.insert_tree(
@@ -452,290 +401,6 @@ mod tests {
);
}
/// Helper function to set up a syntax test environment
async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
use unindent::Unindent;
init_test(cx);
cx.executor().allow_parking();
let fs = FakeFs::new(cx.executor().clone());
// Create test file with syntax structures
fs.insert_tree(
"/root",
serde_json::json!({
"test_syntax.rs": r#"
fn top_level_function() {
println!("This is at the top level");
}
mod feature_module {
pub mod nested_module {
pub fn nested_function(
first_arg: String,
second_arg: i32,
) {
println!("Function in nested module");
println!("{first_arg}");
println!("{second_arg}");
}
}
}
struct MyStruct {
field1: String,
field2: i32,
}
impl MyStruct {
fn method_with_block() {
let condition = true;
if condition {
println!("Inside if block");
}
}
fn long_function() {
println!("Line 1");
println!("Line 2");
println!("Line 3");
println!("Line 4");
println!("Line 5");
println!("Line 6");
println!("Line 7");
println!("Line 8");
println!("Line 9");
println!("Line 10");
println!("Line 11");
println!("Line 12");
}
}
trait Processor {
fn process(&self, input: &str) -> String;
}
impl Processor for MyStruct {
fn process(&self, input: &str) -> String {
format!("Processed: {}", input)
}
}
"#.unindent().trim(),
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
project.update(cx, |project, _cx| {
project.languages().add(rust_lang().into())
});
project
}
#[gpui::test]
async fn test_grep_top_level_function(cx: &mut TestAppContext) {
let project = setup_syntax_test(cx).await;
// Test: Line at the top level of the file
let input = serde_json::to_value(GrepToolInput {
regex: "This is at the top level".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:
## Matches in root/test_syntax.rs
### fn top_level_function L1-3
```
fn top_level_function() {
println!("This is at the top level");
}
```
"#
.unindent();
assert_eq!(result, expected);
}
#[gpui::test]
async fn test_grep_function_body(cx: &mut TestAppContext) {
let project = setup_syntax_test(cx).await;
// Test: Line inside a function body
let input = serde_json::to_value(GrepToolInput {
regex: "Function in nested module".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:
## Matches in root/test_syntax.rs
### mod feature_module pub mod nested_module pub fn nested_function L10-14
```
) {
println!("Function in nested module");
println!("{first_arg}");
println!("{second_arg}");
}
```
"#
.unindent();
assert_eq!(result, expected);
}
#[gpui::test]
async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
let project = setup_syntax_test(cx).await;
// Test: Line with a function argument
let input = serde_json::to_value(GrepToolInput {
regex: "second_arg".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:
## Matches in root/test_syntax.rs
### mod feature_module pub mod nested_module pub fn nested_function L7-14
```
pub fn nested_function(
first_arg: String,
second_arg: i32,
) {
println!("Function in nested module");
println!("{first_arg}");
println!("{second_arg}");
}
```
"#
.unindent();
assert_eq!(result, expected);
}
#[gpui::test]
async fn test_grep_if_block(cx: &mut TestAppContext) {
use unindent::Unindent;
let project = setup_syntax_test(cx).await;
// Test: Line inside an if block
let input = serde_json::to_value(GrepToolInput {
regex: "Inside if block".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:
## Matches in root/test_syntax.rs
### impl MyStruct fn method_with_block L26-28
```
if condition {
println!("Inside if block");
}
```
"#
.unindent();
assert_eq!(result, expected);
}
#[gpui::test]
async fn test_grep_long_function_top(cx: &mut TestAppContext) {
use unindent::Unindent;
let project = setup_syntax_test(cx).await;
// Test: Line in the middle of a long function - should show message about remaining lines
let input = serde_json::to_value(GrepToolInput {
regex: "Line 5".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:
## Matches in root/test_syntax.rs
### impl MyStruct fn long_function L31-41
```
fn long_function() {
println!("Line 1");
println!("Line 2");
println!("Line 3");
println!("Line 4");
println!("Line 5");
println!("Line 6");
println!("Line 7");
println!("Line 8");
println!("Line 9");
println!("Line 10");
```
3 lines remaining in ancestor node. Read the file to see all.
"#
.unindent();
assert_eq!(result, expected);
}
#[gpui::test]
async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
use unindent::Unindent;
let project = setup_syntax_test(cx).await;
// Test: Line in the long function
let input = serde_json::to_value(GrepToolInput {
regex: "Line 12".to_string(),
include_pattern: Some("**/*.rs".to_string()),
offset: 0,
case_sensitive: false,
})
.unwrap();
let result = run_grep_tool(input, project.clone(), cx).await;
let expected = r#"
Found 1 matches:
## Matches in root/test_syntax.rs
### impl MyStruct fn long_function L41-45
```
println!("Line 10");
println!("Line 11");
println!("Line 12");
}
}
```
"#
.unindent();
assert_eq!(result, expected);
}
async fn run_grep_tool(
input: serde_json::Value,
project: Entity<Project>,
@@ -746,13 +411,7 @@ mod tests {
let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));
match task.output.await {
Ok(result) => {
if cfg!(windows) {
result.replace("root\\", "root/")
} else {
result
}
}
Ok(result) => result,
Err(e) => panic!("Failed to run grep tool: {}", e),
}
}
@@ -765,20 +424,4 @@ mod tests {
Project::init_settings(cx);
});
}
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
name: "Rust".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
)
.with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
.unwrap()
}
}

View File

@@ -6,7 +6,7 @@ use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use std::sync::Arc;
use ui::IconName;
use util::markdown::MarkdownEscaped;
@@ -50,7 +50,7 @@ impl Tool for OpenTool {
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
@@ -60,107 +60,11 @@ impl Tool for OpenTool {
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, project, cx);
cx.background_spawn(async move {
match abs_path {
Some(path) => open::that(path),
None => open::that(&input.path_or_url),
}
.context("Failed to open URL or file path")?;
open::that(&input.path_or_url).context("Failed to open URL or file path")?;
Ok(format!("Successfully opened {}", input.path_or_url))
})
.into()
}
}
fn to_absolute_path(
potential_path: &str,
project: Entity<Project>,
cx: &mut App,
) -> Option<PathBuf> {
let project = project.read(cx);
project
.find_project_path(PathBuf::from(potential_path), cx)
.and_then(|project_path| project.absolute_path(&project_path, cx))
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use project::{FakeFs, Project};
use settings::SettingsStore;
use std::path::Path;
use tempfile::TempDir;
#[gpui::test]
async fn test_to_absolute_path(cx: &mut TestAppContext) {
init_test(cx);
let temp_dir = TempDir::new().expect("Failed to create temp directory");
let temp_path = temp_dir.path().to_string_lossy().to_string();
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
&temp_path,
serde_json::json!({
"src": {
"main.rs": "fn main() {}",
"lib.rs": "pub fn lib_fn() {}"
},
"docs": {
"readme.md": "# Project Documentation"
}
}),
)
.await;
// Use the temp_path as the root directory, not just its filename
let project = Project::test(fs.clone(), [temp_dir.path()], cx).await;
// Test cases where the function should return Some
cx.update(|cx| {
// Project-relative paths should return Some
// Create paths using the last segment of the temp path to simulate a project-relative path
let root_dir_name = Path::new(&temp_path)
.file_name()
.unwrap_or_else(|| std::ffi::OsStr::new("temp"))
.to_string_lossy();
assert!(
to_absolute_path(&format!("{root_dir_name}/src/main.rs"), project.clone(), cx)
.is_some(),
"Failed to resolve main.rs path"
);
assert!(
to_absolute_path(
&format!("{root_dir_name}/docs/readme.md",),
project.clone(),
cx,
)
.is_some(),
"Failed to resolve readme.md path"
);
// External URL should return None
let result = to_absolute_path("https://example.com", project.clone(), cx);
assert_eq!(result, None, "External URLs should return None");
// Path outside project
let result = to_absolute_path("../invalid/path", project.clone(), cx);
assert_eq!(result, None, "Paths outside the project should return None");
});
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
});
}
}

View File

@@ -1,6 +1,5 @@
use crate::schema::json_schema_for;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::outline;
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
@@ -15,6 +14,10 @@ use ui::IconName;
use util::markdown::MarkdownInlineCode;
/// If the model requests to read a file whose size exceeds this, then
/// the tool will return an error along with the model's symbol outline,
/// and suggest trying again using line ranges from the outline.
const MAX_FILE_SIZE_TO_READ: usize = 16384;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ReadFileToolInput {
/// The relative path of the file to read.
@@ -122,8 +125,7 @@ impl Tool for ReadFileTool {
if input.start_line.is_some() || input.end_line.is_some() {
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
let start = input.start_line.unwrap_or(1).max(1);
let start = input.start_line.unwrap_or(1);
let lines = text.split('\n').skip(start - 1);
if let Some(end) = input.end_line {
let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line
@@ -142,7 +144,7 @@ impl Tool for ReadFileTool {
// No line ranges specified, so check file size to see if it's too big.
let file_size = buffer.read_with(cx, |buffer, _cx| buffer.text().len())?;
if file_size <= outline::AUTO_OUTLINE_SIZE {
if file_size <= MAX_FILE_SIZE_TO_READ {
// File is small enough, so return its contents.
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
@@ -152,9 +154,9 @@ impl Tool for ReadFileTool {
Ok(result)
} else {
// File is too big, so return the outline
// File is too big, so return an error with the outline
// and a suggestion to read again with line numbers.
let outline = outline::file_outline(project, file_path, action_log, None, cx).await?;
let outline = file_outline(project, file_path, action_log, None, cx).await?;
Ok(formatdoc! {"
This file was too big to read all at once. Here is an outline of its symbols:
@@ -330,67 +332,6 @@ mod test {
assert_eq!(result.unwrap(), "Line 2\nLine 3\nLine 4");
}
#[gpui::test]
async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/root",
json!({
"multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
// start_line of 0 should be treated as 1
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 0,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log.clone(), None, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 1\nLine 2");
// end_line of 0 should result in at least 1 line
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 1,
"end_line": 0
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log.clone(), None, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 1");
// when start_line > end_line, should still return at least 1 line
let result = cx
.update(|cx| {
let input = json!({
"path": "root/multiline.txt",
"start_line": 3,
"end_line": 2
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, None, cx)
.output
})
.await;
assert_eq!(result.unwrap(), "Line 3");
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);

View File

@@ -1,339 +0,0 @@
use crate::{
Templates,
edit_agent::{EditAgent, EditAgentOutputEvent},
edit_file_tool::EditFileToolCard,
schema::json_schema_for,
};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language_model::{
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::prelude::*;
use util::ResultExt;
pub struct StreamingEditFileTool;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
/// A one-line, user-friendly markdown description of the edit. This will be
/// shown in the UI and also passed to another model to perform the edit.
///
/// Be terse, but also descriptive in what you want to achieve with this
/// edit. Avoid generic instructions.
///
/// NEVER mention the file path in this description.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
///
/// Make sure to include this field before all the others in the input object
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
display_description: String,
}
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for StreamingEditFileTool {
fn name(&self) -> String {
"edit_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("streaming_edit_file_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<StreamingEditFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<StreamingEditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<StreamingEditFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
return Task::ready(Err(anyhow!(
"Path {} not found in project",
input.path.display()
)))
.into();
};
let Some(worktree) = project
.read(cx)
.worktree_for_id(project_path.worktree_id, cx)
else {
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
};
let exists = worktree.update(cx, |worktree, cx| {
worktree.file_exists(&project_path.path, cx)
});
let card = window.and_then(|window| {
window
.update(cx, |_, window, cx| {
cx.new(|cx| {
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
})
})
.ok()
});
let card_clone = card.clone();
let messages = messages.to_vec();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
if !exists.await? {
return Err(anyhow!("{} not found", input.path.display()));
}
let model = cx
.update(|cx| LanguageModelRegistry::read_global(cx).default_model())?
.context("default model not set")?
.model;
let edit_agent = EditAgent::new(model, action_log, Templates::new());
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?
.await?;
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx
.background_spawn({
let old_snapshot = old_snapshot.clone();
async move { old_snapshot.text() }
})
.await;
let (output, mut events) = edit_agent.edit(
buffer.clone(),
input.display_description.clone(),
messages,
cx,
);
let mut hallucinated_old_text = false;
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited => {
if let Some(card) = card_clone.as_ref() {
let new_snapshot =
buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx
.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
})
.await;
card.update(cx, |card, cx| {
card.set_diff(
project_path.path.clone(),
old_text.clone(),
new_text,
cx,
);
})
.log_err();
}
}
EditAgentOutputEvent::HallucinatedOldText(_) => hallucinated_old_text = true,
}
}
output.await?;
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let new_text = cx.background_spawn({
let new_snapshot = new_snapshot.clone();
async move { new_snapshot.text() }
});
let diff = cx.background_spawn(async move {
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
});
let (new_text, diff) = futures::join!(new_text, diff);
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
})
.log_err();
}
let input_path = input.path.display();
if diff.is_empty() {
if hallucinated_old_text {
Err(anyhow!(formatdoc! {"
Some edits were produced but none of them could be applied.
Read the relevant sections of {input_path} again so that
I can perform the requested edits.
"}))
} else {
Ok("No edits were made.".to_string())
}
} else {
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input_path, diff))
}
});
ToolResult {
output: task,
card: card.map(AnyToolCard::from),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"src/main.rs"
);
}
#[test]
fn still_streaming_ui_text_with_description() {
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
"Fix error handling",
);
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
#[test]
fn still_streaming_ui_text_with_null() {
let input = serde_json::Value::Null;
assert_eq!(
StreamingEditFileTool.still_streaming_ui_text(&input),
DEFAULT_UI_TEXT,
);
}
}

View File

@@ -1,8 +0,0 @@
This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files.
Before using this tool:
1. Use the `read_file` tool to understand the file's contents and context
2. Verify the directory path is correct (only applicable when creating new files):
- Use the `list_directory` tool to verify the parent directory exists and is the correct location

View File

@@ -1,32 +0,0 @@
use anyhow::Result;
use handlebars::Handlebars;
use rust_embed::RustEmbed;
use serde::Serialize;
use std::sync::Arc;
#[derive(RustEmbed)]
#[folder = "src/templates"]
#[include = "*.hbs"]
struct Assets;
pub struct Templates(Handlebars<'static>);
impl Templates {
pub fn new() -> Arc<Self> {
let mut handlebars = Handlebars::new();
handlebars.register_embed_templates::<Assets>().unwrap();
handlebars.register_escape_fn(|text| text.into());
Arc::new(Self(handlebars))
}
}
pub trait Template: Sized {
const TEMPLATE_NAME: &'static str;
fn render(&self, templates: &Templates) -> Result<String>
where
Self: Serialize + Sized,
{
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
}
}

View File

@@ -1,23 +0,0 @@
You are an expert coder, and have been tasked with looking at the following diff:
<diff>
{{diff}}
</diff>
Evaluate the following assertions:
<assertions>
{{assertions}}
</assertions>
You must respond with a short analysis and a score between 0 and 100, where:
- 0 means no assertions pass
- 100 means all the assertions pass perfectly
<analysis>
- Assertion 1: one line describing why the first assertion passes or fails (even partially)
- Assertion 2: one line describing why the second assertion passes or fails (even partially)
- ...
- Assertion N: one line describing why the Nth assertion passes or fails (even partially)
</analysis>
<score>YOUR FINAL SCORE HERE</score>

View File

@@ -1,49 +0,0 @@
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
You MUST respond with a series of edits to that one file in the following format:
```
<edits>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 2 HERE
</old_text>
<new_text>
NEW TEXT 2 HERE
</new_text>
<old_text>
OLD TEXT 3 HERE
</old_text>
<new_text>
NEW TEXT 3 HERE
</new_text>
</edits>
```
Rules for editing:
- `old_text` represents lines in the input file that will be replaced with `new_text`. `old_text` MUST exactly match the existing file content, character for character, including indentation.
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
- When reporting multiple edits, each edit assumes the previous one has already been applied! Therefore, you must ensure `old_text` doesn't reference text that has already been modified by a previous edit.
- Don't explain the edits, just report them.
- Only edit the file specified in `<file_to_edit>` and NEVER include edits to other files!
- If you open an <old_text> tag, you MUST close it using </old_text>
- If you open an <new_text> tag, you MUST close it using </new_text>
<file_to_edit>
{{path}}
</file_to_edit>
<edit_description>
{{edit_description}}
</edit_description>

View File

@@ -1,32 +1,23 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use gpui::{
Animation, AnimationExt, AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task,
Transformation, WeakEntity, Window, percentage,
};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, terminals::TerminalKind};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
env,
path::{Path, PathBuf},
process::ExitStatus,
sync::Arc,
time::{Duration, Instant},
};
use terminal_view::TerminalView;
use ui::{Disclosure, IconName, Tooltip, prelude::*};
use util::{
get_system_shell, markdown::MarkdownInlineCode, size::format_file_size,
time::duration_alt_display,
};
use workspace::Workspace;
use std::future;
use util::get_system_shell;
const COMMAND_OUTPUT_LIMIT: usize = 16 * 1024;
use std::path::Path;
use std::sync::Arc;
use ui::IconName;
use util::command::new_smol_command;
use util::markdown::MarkdownInlineCode;
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct TerminalToolInput {
/// The one-liner command to execute.
command: String,
@@ -84,426 +75,308 @@ impl Tool for TerminalTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let Some(window) = window else {
return Task::ready(Err(anyhow!("no window options"))).into();
};
let input: TerminalToolInput = match serde_json::from_value(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let project = project.read(cx);
let input_path = Path::new(&input.cd);
let working_dir = match working_dir(cx, &input, &project, input_path) {
Ok(dir) => dir,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let terminal = project.update(cx, |project, cx| {
project.create_terminal(
TerminalKind::Task(task::SpawnInTerminal {
command: get_system_shell(),
args: vec!["-c".into(), input.command.clone()],
cwd: working_dir.clone(),
..Default::default()
}),
window,
cx,
)
});
let working_dir = if input.cd == "." {
// Accept "." as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
let card = cx.new(|cx| {
TerminalToolCard::new(input.command.clone(), working_dir.clone(), cx.entity_id())
});
let only_worktree = match worktrees.next() {
Some(worktree) => worktree,
None => {
return Task::ready(Err(anyhow!("No worktrees found in the project"))).into();
}
};
let output = cx.spawn({
let card = card.clone();
async move |cx| {
let terminal = terminal.await?;
let workspace = window
.downcast::<Workspace>()
.and_then(|handle| handle.entity(cx).ok())
.context("no workspace entity in root of window")?;
let terminal_view = window.update(cx, |_, window, cx| {
cx.new(|cx| {
TerminalView::new(
terminal.clone(),
workspace.downgrade(),
None,
project.downgrade(),
window,
cx,
)
})
})?;
let _ = card.update(cx, |card, _| {
card.terminal = Some(terminal_view.clone());
card.start_instant = Instant::now();
});
let exit_status = terminal
.update(cx, |terminal, cx| terminal.wait_for_completed_task(cx))?
.await;
let (content, content_line_count) = terminal.update(cx, |terminal, _| {
(terminal.get_content(), terminal.total_lines())
})?;
let previous_len = content.len();
let (processed_content, finished_with_empty_output) =
process_content(content, &input.command, exit_status);
let _ = card.update(cx, |card, _| {
card.command_finished = true;
card.exit_status = exit_status;
card.was_content_truncated = processed_content.len() < previous_len;
card.original_content_len = previous_len;
card.content_line_count = content_line_count;
card.finished_with_empty_output = finished_with_empty_output;
card.elapsed_time = Some(card.start_instant.elapsed());
});
Ok(processed_content)
if worktrees.next().is_some() {
return Task::ready(Err(anyhow!(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly."
))).into();
}
});
ToolResult {
output,
card: Some(card.into()),
}
only_worktree.read(cx).abs_path()
} else if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
return Task::ready(Err(anyhow!(
"The absolute path must be within one of the project's worktrees"
)))
.into();
}
input_path.into()
} else {
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
return Task::ready(Err(anyhow!(
"`cd` directory {} not found in the project",
&input.cd
)))
.into();
};
worktree.read(cx).abs_path()
};
cx.background_spawn(run_command_limited(working_dir, input.command))
.into()
}
}
fn process_content(
content: String,
command: &str,
exit_status: Option<ExitStatus>,
) -> (String, bool) {
let should_truncate = content.len() > COMMAND_OUTPUT_LIMIT;
const LIMIT: usize = 16 * 1024;
let content = if should_truncate {
let mut end_ix = COMMAND_OUTPUT_LIMIT.min(content.len());
while !content.is_char_boundary(end_ix) {
end_ix -= 1;
}
// Don't truncate mid-line, clear the remainder of the last line
end_ix = content[..end_ix].rfind('\n').unwrap_or(end_ix);
&content[..end_ix]
} else {
content.as_str()
};
let is_empty = content.trim().is_empty();
async fn run_command_limited(working_dir: Arc<Path>, command: String) -> Result<String> {
let shell = get_system_shell();
let content = format!(
"```\n{}{}```",
content,
if content.ends_with('\n') { "" } else { "\n" }
let mut cmd = new_smol_command(&shell)
.arg("-c")
.arg(&command)
.current_dir(working_dir)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.context("Failed to execute terminal command")?;
let mut combined_buffer = String::with_capacity(LIMIT + 1);
let mut out_reader = BufReader::new(cmd.stdout.take().context("Failed to get stdout")?);
let mut out_tmp_buffer = String::with_capacity(512);
let mut err_reader = BufReader::new(cmd.stderr.take().context("Failed to get stderr")?);
let mut err_tmp_buffer = String::with_capacity(512);
let mut out_line = Box::pin(
out_reader
.read_line(&mut out_tmp_buffer)
.left_future()
.fuse(),
);
let mut err_line = Box::pin(
err_reader
.read_line(&mut err_tmp_buffer)
.left_future()
.fuse(),
);
let content = if should_truncate {
format!(
"Command output too long. The first {} bytes:\n\n{}",
content.len(),
content,
)
} else {
content
};
let content = match exit_status {
Some(exit_status) if exit_status.success() => {
if is_empty {
"Command executed successfully.".to_string()
} else {
content.to_string()
}
}
Some(exit_status) => {
let code = exit_status.code().unwrap_or(-1);
if is_empty {
format!("Command \"{command}\" failed with exit code {code}.")
} else {
format!("Command \"{command}\" failed with exit code {code}.\n\n{content}")
}
}
None => {
format!(
"Command failed or was interrupted.\nPartial output captured:\n\n{}",
content,
)
}
};
(content, is_empty)
}
fn working_dir(
cx: &mut App,
input: &TerminalToolInput,
project: &Entity<Project>,
input_path: &Path,
) -> Result<Option<PathBuf>, &'static str> {
let project = project.read(cx);
if input.cd == "." {
// Accept "." as meaning "the one worktree" if we only have one worktree.
let mut worktrees = project.worktrees(cx);
match worktrees.next() {
Some(worktree) => {
if worktrees.next().is_some() {
return Err(
"'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.",
);
}
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
}
None => Ok(None),
}
} else if input_path.is_absolute() {
// Absolute paths are allowed, but only if they're in one of the project's worktrees.
if !project
.worktrees(cx)
.any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path()))
{
return Err("The absolute path must be within one of the project's worktrees");
}
Ok(Some(input_path.into()))
} else {
let Some(worktree) = project.worktree_for_root_name(&input.cd, cx) else {
return Err("`cd` directory {} not found in the project");
};
Ok(Some(worktree.read(cx).abs_path().to_path_buf()))
}
}
struct TerminalToolCard {
input_command: String,
working_dir: Option<PathBuf>,
entity_id: EntityId,
exit_status: Option<ExitStatus>,
terminal: Option<Entity<TerminalView>>,
command_finished: bool,
was_content_truncated: bool,
finished_with_empty_output: bool,
content_line_count: usize,
original_content_len: usize,
preview_expanded: bool,
start_instant: Instant,
elapsed_time: Option<Duration>,
}
impl TerminalToolCard {
pub fn new(input_command: String, working_dir: Option<PathBuf>, entity_id: EntityId) -> Self {
Self {
input_command,
working_dir,
entity_id,
exit_status: None,
terminal: None,
command_finished: false,
was_content_truncated: false,
finished_with_empty_output: false,
original_content_len: 0,
content_line_count: 0,
preview_expanded: true,
start_instant: Instant::now(),
elapsed_time: None,
}
}
}
impl ToolCard for TerminalToolCard {
fn render(
&mut self,
status: &ToolUseStatus,
_window: &mut Window,
_workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let Some(terminal) = self.terminal.as_ref() else {
return Empty.into_any();
};
let tool_failed = matches!(status, ToolUseStatus::Error(_));
let command_failed =
self.command_finished && self.exit_status.is_none_or(|code| !code.success());
if (tool_failed || command_failed) && self.elapsed_time.is_none() {
self.elapsed_time = Some(self.start_instant.elapsed());
}
let time_elapsed = self
.elapsed_time
.unwrap_or_else(|| self.start_instant.elapsed());
let should_hide_terminal =
tool_failed || self.finished_with_empty_output || !self.preview_expanded;
let border_color = cx.theme().colors().border.opacity(0.6);
let header_bg = cx
.theme()
.colors()
.element_background
.blend(cx.theme().colors().editor_foreground.opacity(0.025));
let header_label = h_flex()
.w_full()
.max_w_full()
.px_1()
.gap_0p5()
.opacity(0.8)
.child(
h_flex()
.child(
Icon::new(IconName::Terminal)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
div()
.id(("terminal-tool-header-input-command", self.entity_id))
.text_size(rems(0.8125))
.font_buffer(cx)
.child(self.input_command.clone())
.ml_1p5()
.mr_0p5()
.tooltip({
let path = self
.working_dir
.as_ref()
.cloned()
.or_else(|| env::current_dir().ok())
.map(|path| format!("\"{}\"", path.display()))
.unwrap_or_else(|| "current directory".to_string());
Tooltip::text(if self.command_finished {
format!("Ran in {path}")
} else {
format!("Running in {path}")
})
}),
),
)
.into_any_element();
let header = h_flex()
.flex_none()
.p_1()
.gap_1()
.justify_between()
.rounded_t_md()
.bg(header_bg)
.child(header_label)
.map(|header| {
let header = header
.when(self.was_content_truncated, |header| {
let tooltip =
if self.content_line_count + 10 > terminal::MAX_SCROLL_HISTORY_LINES {
"Output exceeded terminal max lines and was \
truncated, the model received the first 16 KB."
.to_string()
} else {
format!(
"Output is {} long, to avoid unexpected token usage, \
only 16 KB was sent back to the model.",
format_file_size(self.original_content_len as u64, true),
)
};
header.child(
div()
.id(("terminal-tool-truncated-label", self.entity_id))
.tooltip(Tooltip::text(tooltip))
.child(
Label::new("(truncated)")
.color(Color::Disabled)
.size(LabelSize::Small),
),
)
})
.when(time_elapsed > Duration::from_secs(10), |header| {
header.child(
Label::new(format!("({})", duration_alt_display(time_elapsed)))
.buffer_font(cx)
.color(Color::Disabled)
.size(LabelSize::Small),
)
});
if tool_failed || command_failed {
header.child(
div()
.id(("terminal-tool-error-code-indicator", self.entity_id))
.child(
Icon::new(IconName::Close)
.size(IconSize::Small)
.color(Color::Error),
)
.when(command_failed && self.exit_status.is_some(), |this| {
this.tooltip(Tooltip::text(format!(
"Exited with code {}",
self.exit_status
.and_then(|status| status.code())
.unwrap_or(-1),
)))
})
.when(
!command_failed && tool_failed && status.error().is_some(),
|this| {
this.tooltip(Tooltip::text(format!(
"Error: {}",
status.error().unwrap(),
)))
},
),
)
} else if self.command_finished {
header.child(
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success),
)
let mut has_stdout = true;
let mut has_stderr = true;
while (has_stdout || has_stderr) && combined_buffer.len() < LIMIT + 1 {
futures::select_biased! {
read = out_line => {
drop(out_line);
combined_buffer.extend(out_tmp_buffer.drain(..));
if read? == 0 {
out_line = Box::pin(future::pending().right_future().fuse());
has_stdout = false;
} else {
header.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Info)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(delta)))
},
),
)
out_line = Box::pin(out_reader.read_line(&mut out_tmp_buffer).left_future().fuse());
}
})
.when(!tool_failed && !self.finished_with_empty_output, |header| {
header.child(
Disclosure::new(
("terminal-tool-disclosure", self.entity_id),
self.preview_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.preview_expanded = !this.preview_expanded;
},
)),
)
});
}
read = err_line => {
drop(err_line);
combined_buffer.extend(err_tmp_buffer.drain(..));
if read? == 0 {
err_line = Box::pin(future::pending().right_future().fuse());
has_stderr = false;
} else {
err_line = Box::pin(err_reader.read_line(&mut err_tmp_buffer).left_future().fuse());
}
}
};
}
v_flex()
.mb_2()
.border_1()
.when(tool_failed || command_failed, |card| card.border_dashed())
.border_color(border_color)
.rounded_lg()
.overflow_hidden()
.child(header)
.when(!should_hide_terminal, |this| {
this.child(div().child(terminal.clone()).min_h(px(250.0)))
})
.into_any()
drop((out_line, err_line));
let truncated = combined_buffer.len() > LIMIT;
combined_buffer.truncate(LIMIT);
consume_reader(out_reader, truncated).await?;
consume_reader(err_reader, truncated).await?;
// Handle potential errors during status retrieval, including interruption.
match cmd.status().await {
Ok(status) => {
let output_string = if truncated {
// Valid to find `\n` in UTF-8 since 0-127 ASCII characters are not used in
// multi-byte characters.
let last_line_ix = combined_buffer.bytes().rposition(|b| b == b'\n');
let buffer_content =
&combined_buffer[..last_line_ix.unwrap_or(combined_buffer.len())];
format!(
"Command output too long. The first {} bytes:\n\n{}",
buffer_content.len(),
output_block(buffer_content),
)
} else {
output_block(&combined_buffer)
};
let output_with_status = if status.success() {
if output_string.is_empty() {
"Command executed successfully.".to_string()
} else {
output_string
}
} else {
format!(
"Command failed with exit code {} (shell: {}).\n\n{}",
status.code().unwrap_or(-1),
shell,
output_string,
)
};
Ok(output_with_status)
}
Err(err) => {
// Error occurred getting status (potential interruption). Include partial output.
let partial_output = output_block(&combined_buffer);
let error_message = format!(
"Command failed or was interrupted.\nPartial output captured:\n\n{}",
partial_output
);
Err(anyhow!(err).context(error_message))
}
}
}
async fn consume_reader<T: AsyncReadExt + Unpin>(
mut reader: BufReader<T>,
truncated: bool,
) -> Result<(), std::io::Error> {
loop {
let skipped_bytes = reader.fill_buf().await?;
if skipped_bytes.is_empty() {
break;
}
let skipped_bytes_len = skipped_bytes.len();
reader.consume_unpin(skipped_bytes_len);
// Should only skip if we went over the limit
debug_assert!(truncated);
}
Ok(())
}
fn output_block(output: &str) -> String {
format!(
"```\n{}{}```",
output,
if output.ends_with('\n') { "" } else { "\n" }
)
}
#[cfg(test)]
#[cfg(not(windows))]
mod tests {
use gpui::TestAppContext;
use super::*;
#[gpui::test(iterations = 10)]
async fn test_run_command_simple(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result =
run_command_limited(Path::new(".").into(), "echo 'Hello, World!'".to_string()).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\nHello, World!\n```");
}
#[gpui::test(iterations = 10)]
async fn test_interleaved_stdout_stderr(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let command = "echo 'stdout 1' && sleep 0.01 && echo 'stderr 1' >&2 && sleep 0.01 && echo 'stdout 2' && sleep 0.01 && echo 'stderr 2' >&2";
let result = run_command_limited(Path::new(".").into(), command.to_string()).await;
assert!(result.is_ok());
assert_eq!(
result.unwrap(),
"```\nstdout 1\nstderr 1\nstdout 2\nstderr 2\n```"
);
}
#[gpui::test(iterations = 10)]
async fn test_multiple_output_reads(cx: &mut TestAppContext) {
cx.executor().allow_parking();
// Command with multiple outputs that might require multiple reads
let result = run_command_limited(
Path::new(".").into(),
"echo '1'; sleep 0.01; echo '2'; sleep 0.01; echo '3'".to_string(),
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "```\n1\n2\n3\n```");
}
#[gpui::test(iterations = 10)]
async fn test_output_truncation_single_line(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; sleep 0.01;", "X".repeat(LIMIT * 2));
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
// Output should be exactly the limit
assert_eq!(content_length, LIMIT);
}
#[gpui::test(iterations = 10)]
async fn test_output_truncation_multiline(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let cmd = format!("echo '{}'; ", "X".repeat(120)).repeat(160);
let result = run_command_limited(Path::new(".").into(), cmd).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.starts_with("Command output too long. The first 16334 bytes:\n\n"));
let content_start = output.find("```\n").map(|i| i + 4).unwrap_or(0);
let content_end = output.rfind("\n```").unwrap_or(output.len());
let content_length = content_end - content_start;
assert!(content_length <= LIMIT);
}
#[gpui::test(iterations = 10)]
async fn test_command_failure(cx: &mut TestAppContext) {
cx.executor().allow_parking();
let result = run_command_limited(Path::new(".").into(), "exit 42".to_string()).await;
assert!(result.is_ok());
let output = result.unwrap();
// Extract the shell name from path for cleaner test output
let shell_path = std::env::var("SHELL").unwrap_or("bash".to_string());
let expected_output = format!(
"Command failed with exit code 42 (shell: {}).\n\n```\n\n```",
shell_path
);
assert_eq!(output, expected_output);
}
}

View File

@@ -1,11 +1,9 @@
Executes a shell one-liner and returns the combined output.
This tool spawns a process using the user's shell, reads from stdout and stderr (preserving the order of writes), and returns a string with the combined output result.
The output results will be shown to the user already, only list it again if necessary, avoid being redundant.
This tool spawns a process using the user's current shell, combines stdout and stderr into one interleaved stream as they are produced (preserving the order of writes), and captures that stream into a string which is returned.
Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error.
Do not use this tool for commands that run indefinitely, such as servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers that don't terminate on their own.
Do not use this tool for commands that run indefinitely, such as servers (e.g., `python -m http.server`) or file watchers that don't terminate on their own.
Remember that each invocation of this tool will spawn a new shell process, so you can't rely on any state from previous invocations.

View File

@@ -23,6 +23,7 @@ pub struct WebSearchToolInput {
query: String,
}
#[derive(RegisterComponent)]
pub struct WebSearchTool;
impl Tool for WebSearchTool {
@@ -83,7 +84,6 @@ impl Tool for WebSearchTool {
}
}
#[derive(RegisterComponent)]
struct WebSearchToolCard {
response: Option<Result<WebSearchResponse>>,
_task: Task<()>,
@@ -185,11 +185,15 @@ impl ToolCard for WebSearchToolCard {
}
}
impl Component for WebSearchToolCard {
impl Component for WebSearchTool {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn sort_name() -> &'static str {
"ToolWebSearch"
}
fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let in_progress_search = cx.new(|cx| WebSearchToolCard {
response: None,

View File

@@ -11,14 +11,12 @@ pub use aws_sdk_bedrockruntime::types::{
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
};
pub use aws_smithy_types::Blob as BedrockBlob;
use aws_smithy_types::{Document, Number as AwsNumber};
pub use bedrock::operation::converse_stream::ConverseStreamInput as BedrockStreamingRequest;
pub use bedrock::types::{
ContentBlock as BedrockRequestContent, ConversationRole as BedrockRole,
ConverseOutput as BedrockResponse, ConverseStreamOutput as BedrockStreamingResponse,
ImageBlock as BedrockImageBlock, Message as BedrockMessage,
ReasoningContentBlock as BedrockThinkingBlock, ReasoningTextBlock as BedrockThinkingTextBlock,
ResponseStream as BedrockResponseStream, ToolResultBlock as BedrockToolResultBlock,
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultStatus as BedrockToolResultStatus, ToolUseBlock as BedrockToolUseBlock,

View File

@@ -26,7 +26,6 @@ use crate::api::events::SnowflakeRow;
use crate::db::billing_subscription::{
StripeCancellationReason, StripeSubscriptionStatus, SubscriptionKind,
};
use crate::llm::db::subscription_usage_meter::CompletionMode;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::rpc::{ResultExt as _, Server};
use crate::{AppState, Cents, Error, Result};
@@ -362,7 +361,12 @@ async fn create_billing_subscription(
let checkout_session_url = match body.product {
Some(ProductCode::ZedPro) => {
stripe_billing
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
.checkout_with_price(
app.config.zed_pro_price_id()?,
customer_id,
&user.github_login,
&success_url,
)
.await?
}
Some(ProductCode::ZedProTrial) => {
@@ -375,13 +379,11 @@ async fn create_billing_subscription(
}
}
let feature_flags = app.db.get_user_flags(user.id).await?;
stripe_billing
.checkout_with_zed_pro_trial(
app.config.zed_pro_price_id()?,
customer_id,
&user.github_login,
feature_flags,
&success_url,
)
.await?
@@ -452,14 +454,6 @@ async fn manage_billing_subscription(
))?
};
let Some(stripe_billing) = app.stripe_billing.clone() else {
log::error!("failed to retrieve Stripe billing object");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let customer = app
.db
.get_billing_customer_by_user_id(user.id)
@@ -510,8 +504,8 @@ async fn manage_billing_subscription(
let flow = match body.intent {
ManageSubscriptionIntent::ManageSubscription => None,
ManageSubscriptionIntent::UpgradeToPro => {
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await?;
let zed_free_price_id = stripe_billing.zed_free_price_id().await?;
let zed_pro_price_id = app.config.zed_pro_price_id()?;
let zed_free_price_id = app.config.zed_free_price_id()?;
let stripe_subscription =
Subscription::retrieve(&stripe_client, &subscription_id, &[]).await?;
@@ -858,11 +852,9 @@ async fn handle_customer_subscription_event(
log::info!("handling Stripe {} event: {}", event.type_, event.id);
let subscription_kind = maybe!(async {
let stripe_billing = app.stripe_billing.clone()?;
let zed_pro_price_id = stripe_billing.zed_pro_price_id().await.ok()?;
let zed_free_price_id = stripe_billing.zed_free_price_id().await.ok()?;
let subscription_kind = maybe!({
let zed_pro_price_id = app.config.zed_pro_price_id().ok()?;
let zed_free_price_id = app.config.zed_free_price_id().ok()?;
subscription.items.data.iter().find_map(|item| {
let price = item.price.as_ref()?;
@@ -879,8 +871,7 @@ async fn handle_customer_subscription_event(
None
}
})
})
.await;
});
let billing_customer =
find_or_create_billing_customer(app, stripe_client, subscription.customer)
@@ -1095,17 +1086,9 @@ struct UsageCounts {
pub remaining: Option<i32>,
}
#[derive(Debug, Serialize)]
struct ModelRequestUsage {
pub model: String,
pub mode: CompletionMode,
pub requests: i32,
}
#[derive(Debug, Serialize)]
struct GetCurrentUsageResponse {
pub model_requests: UsageCounts,
pub model_request_usage: Vec<ModelRequestUsage>,
pub edit_predictions: UsageCounts,
}
@@ -1132,7 +1115,6 @@ async fn get_current_usage(
limit: Some(0),
remaining: Some(0),
},
model_request_usage: Vec::new(),
edit_predictions: UsageCounts {
used: 0,
limit: Some(0),
@@ -1177,30 +1159,12 @@ async fn get_current_usage(
zed_llm_client::UsageLimit::Unlimited => None,
};
let subscription_usage_meters = llm_db
.get_current_subscription_usage_meters_for_user(user.id, Utc::now())
.await?;
let model_request_usage = subscription_usage_meters
.into_iter()
.filter_map(|(usage_meter, _usage)| {
let model = llm_db.model_by_id(usage_meter.model_id).ok()?;
Some(ModelRequestUsage {
model: model.name.clone(),
mode: usage_meter.mode,
requests: usage_meter.requests,
})
})
.collect::<Vec<_>>();
Ok(Json(GetCurrentUsageResponse {
model_requests: UsageCounts {
used: usage.model_requests,
limit: model_requests_limit,
remaining: model_requests_limit.map(|limit| (limit - usage.model_requests).max(0)),
},
model_request_usage,
edit_predictions: UsageCounts {
used: usage.edit_predictions,
limit: edit_prediction_limit,
@@ -1408,9 +1372,6 @@ async fn sync_model_request_usage_with_stripe(
let claude_3_7_sonnet = stripe_billing
.find_price_by_lookup_key("claude-3-7-sonnet-requests")
.await?;
let claude_3_7_sonnet_max = stripe_billing
.find_price_by_lookup_key("claude-3-7-sonnet-requests-max")
.await?;
for (usage_meter, usage) in usage_meters {
maybe!(async {
@@ -1434,21 +1395,16 @@ async fn sync_model_request_usage_with_stripe(
let model = llm_db.model_by_id(usage_meter.model_id)?;
let (price, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => match usage_meter.mode {
CompletionMode::Normal => (&claude_3_7_sonnet, "claude_3_7_sonnet/requests"),
CompletionMode::Max => {
(&claude_3_7_sonnet_max, "claude_3_7_sonnet/requests/max")
}
},
let (price_id, meter_event_name) = match model.name.as_str() {
"claude-3-5-sonnet" => (&claude_3_5_sonnet.id, "claude_3_5_sonnet/requests"),
"claude-3-7-sonnet" => (&claude_3_7_sonnet.id, "claude_3_7_sonnet/requests"),
model_name => {
bail!("Attempted to sync usage meter for unsupported model: {model_name:?}")
}
};
stripe_billing
.subscribe_to_price(&stripe_subscription_id, price)
.subscribe_to_price(&stripe_subscription_id, price_id)
.await?;
stripe_billing
.bill_model_request_usage(

View File

@@ -180,6 +180,9 @@ pub struct Config {
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
pub stripe_api_key: Option<String>,
pub stripe_zed_pro_price_id: Option<String>,
pub stripe_zed_pro_trial_price_id: Option<String>,
pub stripe_zed_free_price_id: Option<String>,
pub supermaven_admin_api_key: Option<Arc<str>>,
pub user_backfiller_github_access_token: Option<Arc<str>>,
}
@@ -198,6 +201,22 @@ impl Config {
}
}
pub fn zed_pro_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Pro", self.stripe_zed_pro_price_id.as_deref())
}
pub fn zed_free_price_id(&self) -> anyhow::Result<stripe::PriceId> {
Self::parse_stripe_price_id("Zed Free", self.stripe_zed_pro_price_id.as_deref())
}
fn parse_stripe_price_id(name: &str, value: Option<&str>) -> anyhow::Result<stripe::PriceId> {
use std::str::FromStr as _;
let price_id = value.ok_or_else(|| anyhow!("{name} price ID not set"))?;
Ok(stripe::PriceId::from_str(price_id)?)
}
#[cfg(test)]
pub fn test() -> Self {
Self {
@@ -235,6 +254,9 @@ impl Config {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,

View File

@@ -5,8 +5,6 @@ use crate::Cents;
pub use token::*;
pub const AGENT_EXTENDED_TRIAL_FEATURE_FLAG: &str = "agent-extended-trial";
/// The maximum monthly spending an individual user can reach on the free tier
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(10);

View File

@@ -1,4 +1,3 @@
use crate::db::UserId;
use crate::llm::db::queries::subscription_usages::convert_chrono_to_time;
use super::*;
@@ -35,38 +34,4 @@ impl LlmDatabase {
})
.await
}
/// Returns all current subscription usage meters for the given user as of the given timestamp.
pub async fn get_current_subscription_usage_meters_for_user(
&self,
user_id: UserId,
now: DateTimeUtc,
) -> Result<Vec<(subscription_usage_meter::Model, subscription_usage::Model)>> {
let now = convert_chrono_to_time(now)?;
self.transaction(|tx| async move {
let result = subscription_usage_meter::Entity::find()
.inner_join(subscription_usage::Entity)
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(
subscription_usage::Column::PeriodStartAt
.lte(now)
.and(subscription_usage::Column::PeriodEndAt.gte(now)),
)
.select_also(subscription_usage::Entity)
.all(&*tx)
.await?;
let result = result
.into_iter()
.filter_map(|(meter, usage)| {
let usage = usage?;
Some((meter, usage))
})
.collect();
Ok(result)
})
.await
}
}

View File

@@ -1,5 +1,4 @@
use sea_orm::entity::prelude::*;
use serde::Serialize;
use crate::llm::db::ModelId;
@@ -10,7 +9,6 @@ pub struct Model {
pub id: i32,
pub subscription_usage_id: i32,
pub model_id: ModelId,
pub mode: CompletionMode,
pub requests: i32,
}
@@ -43,13 +41,3 @@ impl Related<super::model::Entity> for Entity {
}
impl ActiveModelBehavior for ActiveModel {}
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Hash, Serialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
#[sea_orm(string_value = "normal")]
Normal,
#[sea_orm(string_value = "max")]
Max,
}

View File

@@ -1,9 +1,7 @@
use crate::Cents;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{billing_subscription, user};
use crate::llm::{
AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT,
};
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::{Config, db::billing_preference};
use anyhow::{Result, anyhow};
use chrono::{NaiveDateTime, Utc};
@@ -32,12 +30,8 @@ pub struct LlmTokenClaims {
pub has_llm_subscription: bool,
pub max_monthly_spend_in_cents: u32,
pub custom_llm_monthly_allowance_in_cents: Option<u32>,
#[serde(default)]
pub use_new_billing: bool,
pub plan: Plan,
#[serde(default)]
pub has_extended_trial: bool,
#[serde(default)]
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
#[serde(default)]
pub enable_model_request_overages: bool,
@@ -92,7 +86,6 @@ impl LlmTokenClaims {
custom_llm_monthly_allowance_in_cents: user
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
use_new_billing: feature_flags.iter().any(|flag| flag == "new-billing"),
plan: subscription
.as_ref()
.and_then(|subscription| subscription.kind)
@@ -101,9 +94,6 @@ impl LlmTokenClaims {
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
}),
has_extended_trial: feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG),
subscription_period: maybe!({
let subscription = subscription?;
let period_start_at = subscription.current_period_start_at()?;

View File

@@ -1,7 +1,6 @@
use std::sync::Arc;
use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
use crate::{Cents, Result};
use crate::{Cents, Result, llm};
use anyhow::{Context as _, anyhow};
use chrono::{Datelike, Utc};
use collections::HashMap;
@@ -81,24 +80,6 @@ impl StripeBilling {
Ok(())
}
pub async fn zed_pro_price_id(&self) -> Result<PriceId> {
self.find_price_id_by_lookup_key("zed-pro").await
}
pub async fn zed_free_price_id(&self) -> Result<PriceId> {
self.find_price_id_by_lookup_key("zed-free").await
}
pub async fn find_price_id_by_lookup_key(&self, lookup_key: &str) -> Result<PriceId> {
self.state
.read()
.await
.prices_by_lookup_key
.get(lookup_key)
.map(|price| price.id.clone())
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn find_price_by_lookup_key(&self, lookup_key: &str) -> Result<stripe::Price> {
self.state
.read()
@@ -106,7 +87,7 @@ impl StripeBilling {
.prices_by_lookup_key
.get(lookup_key)
.cloned()
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
.ok_or_else(|| crate::Error::Internal(anyhow!("no price ID found for {lookup_key:?}")))
}
pub async fn register_model_for_token_based_usage(
@@ -248,29 +229,21 @@ impl StripeBilling {
pub async fn subscribe_to_price(
&self,
subscription_id: &stripe::SubscriptionId,
price: &stripe::Price,
price_id: &stripe::PriceId,
) -> Result<()> {
let subscription =
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
if subscription_contains_price(&subscription, &price.id) {
if subscription_contains_price(&subscription, price_id) {
return Ok(());
}
const BILLING_THRESHOLD_IN_CENTS: i64 = 20 * 100;
let price_per_unit = price.unit_amount.unwrap_or_default();
let units_for_billing_threshold = BILLING_THRESHOLD_IN_CENTS / price_per_unit;
stripe::Subscription::update(
&self.client,
subscription_id,
stripe::UpdateSubscription {
items: Some(vec![stripe::UpdateSubscriptionItems {
price: Some(price.id.to_string()),
billing_thresholds: Some(stripe::SubscriptionItemBillingThresholds {
usage_gte: Some(units_for_billing_threshold),
}),
price: Some(price_id.to_string()),
..Default::default()
}]),
trial_settings: Some(stripe::UpdateSubscriptionTrialSettings {
@@ -489,20 +462,19 @@ impl StripeBilling {
Ok(session.url.context("no checkout session URL")?)
}
pub async fn checkout_with_zed_pro(
pub async fn checkout_with_price(
&self,
price_id: PriceId,
customer_id: stripe::CustomerId,
github_login: &str,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
let mut params = stripe::CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {
price: Some(zed_pro_price_id.to_string()),
price: Some(price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
@@ -514,40 +486,19 @@ impl StripeBilling {
pub async fn checkout_with_zed_pro_trial(
&self,
zed_pro_price_id: PriceId,
customer_id: stripe::CustomerId,
github_login: &str,
feature_flags: Vec<String>,
success_url: &str,
) -> Result<String> {
let zed_pro_price_id = self.zed_pro_price_id().await?;
let eligible_for_extended_trial = feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG);
let trial_period_days = if eligible_for_extended_trial { 60 } else { 14 };
let mut subscription_metadata = std::collections::HashMap::new();
if eligible_for_extended_trial {
subscription_metadata.insert(
"promo_feature_flag".to_string(),
AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string(),
);
}
let mut params = stripe::CreateCheckoutSession::new();
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
trial_period_days: Some(trial_period_days),
trial_period_days: Some(14),
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
}
}),
metadata: if !subscription_metadata.is_empty() {
Some(subscription_metadata)
} else {
None
},
..Default::default()
});
params.mode = Some(stripe::CheckoutSessionMode::Subscription);

View File

@@ -554,6 +554,9 @@ impl TestServer {
migrations_path: None,
seed_path: None,
stripe_api_key: None,
stripe_zed_pro_price_id: None,
stripe_zed_pro_trial_price_id: None,
stripe_zed_free_price_id: None,
supermaven_admin_api_key: None,
user_backfiller_github_access_token: None,
kinesis_region: None,

View File

@@ -393,23 +393,6 @@ impl ChannelView {
buffer.acknowledge_buffer_version(cx);
});
}
fn get_channel(&self, cx: &App) -> (SharedString, Option<SharedString>) {
if let Some(channel) = self.channel(cx) {
let status = match (
self.channel_buffer.read(cx).buffer().read(cx).read_only(),
self.channel_buffer.read(cx).is_connected(),
) {
(false, true) => None,
(true, true) => Some("read-only"),
(_, false) => Some("disconnected"),
};
(channel.name.clone(), status.map(Into::into))
} else {
("<unknown>".into(), Some("disconnected".into()))
}
}
}
impl EventEmitter<EditorEvent> for ChannelView {}
@@ -457,21 +440,26 @@ impl Item for ChannelView {
Some(Icon::new(icon))
}
fn tab_content_text(&self, _detail: usize, cx: &App) -> SharedString {
let (name, status) = self.get_channel(cx);
if let Some(status) = status {
format!("{name} - {status}").into()
} else {
name
}
}
fn tab_content(&self, params: TabContentParams, _: &Window, cx: &App) -> gpui::AnyElement {
let (name, status) = self.get_channel(cx);
let (channel_name, status) = if let Some(channel) = self.channel(cx) {
let status = match (
self.channel_buffer.read(cx).buffer().read(cx).read_only(),
self.channel_buffer.read(cx).is_connected(),
) {
(false, true) => None,
(true, true) => Some("read-only"),
(_, false) => Some("disconnected"),
};
(channel.name.clone(), status)
} else {
("<unknown>".into(), Some("disconnected"))
};
h_flex()
.gap_2()
.child(
Label::new(name)
Label::new(channel_name)
.color(params.text_color())
.when(params.preview, |this| this.italic()),
)
@@ -552,6 +540,10 @@ impl Item for ChannelView {
fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) {
Editor::to_item_events(event, f)
}
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
"Channels".into()
}
}
impl FollowableItem for ChannelView {

View File

@@ -15,18 +15,24 @@ path = "src/component_preview.rs"
default = []
[dependencies]
anyhow.workspace = true
client.workspace = true
collections.workspace = true
component.workspace = true
db.workspace = true
gpui.workspace = true
languages.workspace = true
notifications.workspace = true
log.workspace = true
notifications.workspace = true
project.workspace = true
serde.workspace = true
ui.workspace = true
ui_input.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
db.workspace = true
anyhow.workspace = true
serde.workspace = true
util.workspace = true
# Dependencies for supporting specific previews
agent.workspace = true
assistant_tool.workspace = true
prompt_store.workspace = true

View File

@@ -3,10 +3,12 @@
//! A view for exploring Zed components.
mod persistence;
mod preview_support;
use std::iter::Iterator;
use std::sync::Arc;
use agent::{ActiveThread, ThreadStore};
use client::UserStore;
use component::{ComponentId, ComponentMetadata, components};
use gpui::{
@@ -19,6 +21,7 @@ use gpui::{ListState, ScrollHandle, ScrollStrategy, UniformListScrollHandle};
use languages::LanguageRegistry;
use notifications::status_toast::{StatusToast, ToastIcon};
use persistence::COMPONENT_PREVIEW_DB;
use preview_support::active_thread::{load_preview_thread_store, static_active_thread};
use project::Project;
use ui::{Divider, HighlightedLabel, ListItem, ListSubHeader, prelude::*};
@@ -33,6 +36,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
cx.observe_new(move |workspace: &mut Workspace, _window, cx| {
let app_state = app_state.clone();
let project = workspace.project().clone();
let weak_workspace = cx.entity().downgrade();
workspace.register_action(
@@ -45,6 +49,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
let component_preview = cx.new(|cx| {
ComponentPreview::new(
weak_workspace.clone(),
project.clone(),
language_registry,
user_store,
None,
@@ -52,6 +57,8 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
window,
cx,
)
// TODO: don't panic here if we fail to create
.expect("Failed to create component preview")
});
workspace.add_item_to_active_pane(
@@ -69,6 +76,7 @@ pub fn init(app_state: Arc<AppState>, cx: &mut App) {
enum PreviewEntry {
AllComponents,
ActiveThread,
Separator,
Component(ComponentMetadata, Option<Vec<usize>>),
SectionHeader(SharedString),
@@ -91,6 +99,7 @@ enum PreviewPage {
#[default]
AllComponents,
Component(ComponentId),
ActiveThread,
}
struct ComponentPreview {
@@ -105,21 +114,50 @@ struct ComponentPreview {
cursor_index: usize,
language_registry: Arc<LanguageRegistry>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
user_store: Entity<UserStore>,
filter_editor: Entity<SingleLineInput>,
filter_text: String,
// preview support
thread_store: Option<Entity<ThreadStore>>,
active_thread: Option<Entity<ActiveThread>>,
}
impl ComponentPreview {
pub fn new(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
language_registry: Arc<LanguageRegistry>,
user_store: Entity<UserStore>,
selected_index: impl Into<Option<usize>>,
active_page: Option<PreviewPage>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
) -> anyhow::Result<Self> {
let workspace_clone = workspace.clone();
let project_clone = project.clone();
let entity = cx.weak_entity();
window
.spawn(cx, async move |cx| {
let thread_store_task =
load_preview_thread_store(workspace_clone.clone(), project_clone.clone(), cx)
.await;
if let Ok(thread_store) = dbg!(thread_store_task.await) {
entity
.update_in(cx, |this, window, cx| {
this.thread_store = Some(thread_store.clone());
this.create_active_thread(window, cx);
cx.notify();
})
.ok();
}
})
.detach();
let sorted_components = components().all_sorted();
let selected_index = selected_index.into().unwrap_or(0);
let active_page = active_page.unwrap_or(PreviewPage::AllComponents);
@@ -151,6 +189,7 @@ impl ComponentPreview {
language_registry,
user_store,
workspace,
project,
active_page,
component_map: components().0,
components: sorted_components,
@@ -158,6 +197,8 @@ impl ComponentPreview {
cursor_index: selected_index,
filter_editor,
filter_text: String::new(),
thread_store: None,
active_thread: None,
};
if component_preview.cursor_index > 0 {
@@ -169,13 +210,36 @@ impl ComponentPreview {
let focus_handle = component_preview.filter_editor.read(cx).focus_handle(cx);
window.focus(&focus_handle);
component_preview
dbg!(component_preview.thread_store.clone());
Ok(component_preview)
}
pub fn create_active_thread(
&mut self,
window: &mut Window,
cx: &mut Context<Self>,
) -> &mut Self {
println!("Creating active thread");
let workspace = self.workspace.clone();
let language_registry = self.language_registry.clone();
if let Some(thread_store) = self.thread_store.clone() {
let active_thread =
static_active_thread(workspace, language_registry, thread_store, window, cx);
self.active_thread = Some(active_thread);
cx.notify();
}
dbg!(self.active_thread.clone());
self
}
pub fn active_page_id(&self, _cx: &App) -> ActivePageId {
match &self.active_page {
PreviewPage::AllComponents => ActivePageId::default(),
PreviewPage::Component(component_id) => ActivePageId(component_id.0.to_string()),
PreviewPage::ActiveThread => ActivePageId("active_thread".to_string()),
}
}
@@ -289,6 +353,7 @@ impl ComponentPreview {
// Always show all components first
entries.push(PreviewEntry::AllComponents);
entries.push(PreviewEntry::ActiveThread);
entries.push(PreviewEntry::Separator);
let mut scopes: Vec<_> = scope_groups
@@ -389,6 +454,19 @@ impl ComponentPreview {
}))
.into_any_element()
}
PreviewEntry::ActiveThread => {
let selected = self.active_page == PreviewPage::ActiveThread;
ListItem::new(ix)
.child(Label::new("Active Thread").color(Color::Default))
.selectable(true)
.toggle_state(selected)
.inset(true)
.on_click(cx.listener(move |this, _, _, cx| {
this.set_active_page(PreviewPage::ActiveThread, cx);
}))
.into_any_element()
}
PreviewEntry::Separator => ListItem::new(ix)
.child(
h_flex()
@@ -471,6 +549,7 @@ impl ComponentPreview {
.render_scope_header(ix, shared_string.clone(), window, cx)
.into_any_element(),
PreviewEntry::AllComponents => div().w_full().h_0().into_any_element(),
PreviewEntry::ActiveThread => div().w_full().h_0().into_any_element(),
PreviewEntry::Separator => div().w_full().h_0().into_any_element(),
})
.unwrap()
@@ -595,6 +674,19 @@ impl ComponentPreview {
}
}
fn render_active_thread(&self, _cx: &mut Context<Self>) -> impl IntoElement {
dbg!(&self.active_thread);
v_flex()
.id("render-active-thread")
.size_full()
.children(self.active_thread.clone().map(|thread| thread.clone()))
.when_none(&self.active_thread.clone(), |this| {
this.child("No active thread")
})
.into_any_element()
}
fn test_status_toast(&self, cx: &mut Context<Self>) {
if let Some(workspace) = self.workspace.upgrade() {
workspace.update(cx, |workspace, cx| {
@@ -704,6 +796,9 @@ impl Render for ComponentPreview {
PreviewPage::Component(id) => self
.render_component_page(&id, window, cx)
.into_any_element(),
PreviewPage::ActiveThread => {
self.render_active_thread(cx).into_any_element()
}
}),
)
}
@@ -759,12 +854,14 @@ impl Item for ComponentPreview {
let language_registry = self.language_registry.clone();
let user_store = self.user_store.clone();
let weak_workspace = self.workspace.clone();
let project = self.project.clone();
let selected_index = self.cursor_index;
let active_page = self.active_page.clone();
Some(cx.new(|cx| {
Self::new(
weak_workspace,
project,
language_registry,
user_store,
selected_index,
@@ -772,6 +869,7 @@ impl Item for ComponentPreview {
window,
cx,
)
.expect("Failed to create new component preview")
}))
}
@@ -838,10 +936,12 @@ impl SerializableItem for ComponentPreview {
let user_store = user_store.clone();
let language_registry = language_registry.clone();
let weak_workspace = workspace.clone();
let project = project.clone();
cx.update(move |window, cx| {
Ok(cx.new(|cx| {
ComponentPreview::new(
weak_workspace,
project,
language_registry,
user_store,
None,
@@ -849,6 +949,7 @@ impl SerializableItem for ComponentPreview {
window,
cx,
)
.expect("Failed to create new component preview")
}))
})?
})

View File

@@ -0,0 +1 @@
pub mod active_thread;

View File

@@ -0,0 +1,64 @@
use languages::LanguageRegistry;
use project::Project;
use std::sync::Arc;
use agent::{ActiveThread, MessageSegment, ThreadStore};
use assistant_tool::ToolWorkingSet;
use gpui::{AppContext, AsyncApp, Entity, Task, WeakEntity};
use prompt_store::PromptBuilder;
use ui::{App, Window};
use workspace::Workspace;
pub async fn load_preview_thread_store(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
cx: &mut AsyncApp,
) -> Task<anyhow::Result<Entity<ThreadStore>>> {
cx.spawn(async move |cx| {
workspace
.update(cx, |_, cx| {
ThreadStore::load(
project.clone(),
cx.new(|_| ToolWorkingSet::default()),
Arc::new(PromptBuilder::new(None).unwrap()),
None,
cx,
)
})?
.await
})
}
pub fn static_active_thread(
workspace: WeakEntity<Workspace>,
language_registry: Arc<LanguageRegistry>,
thread_store: Entity<ThreadStore>,
window: &mut Window,
cx: &mut App,
) -> Entity<ActiveThread> {
let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
thread.update(cx, |thread, cx| {
thread.insert_assistant_message(vec![
MessageSegment::Text("I'll help you fix the lifetime error in your `cx.spawn` call. When working with async operations in GPUI, there are specific patterns to follow for proper lifetime management.".to_string()),
MessageSegment::Text("\n\nLet's look at what's happening in your code:".to_string()),
MessageSegment::Text("\n\n---\n\nLet's check the current state of the active_thread.rs file to understand what might have changed:".to_string()),
MessageSegment::Text("\n\n---\n\nLooking at the implementation of `load_preview_thread_store` and understanding GPUI's async patterns, here's the issue:".to_string()),
MessageSegment::Text("\n\n1. `load_preview_thread_store` returns a `Task<anyhow::Result<Entity<ThreadStore>>>`, which means it's already a task".to_string()),
MessageSegment::Text("\n2. When you call this function inside another `spawn` call, you're nesting tasks incorrectly".to_string()),
MessageSegment::Text("\n3. The `this` parameter you're trying to use in your closure has the wrong context".to_string()),
MessageSegment::Text("\n\nHere's the correct way to implement this:".to_string()),
MessageSegment::Text("\n\n---\n\nThe problem is in how you're setting up the async closure and trying to reference variables like `window` and `language_registry` that aren't accessible in that scope.".to_string()),
MessageSegment::Text("\n\nHere's how to fix it:".to_string()),
], cx);
});
cx.new(|cx| {
ActiveThread::new(
thread,
thread_store,
language_registry,
workspace.clone(),
window,
cx,
)
})
}

View File

@@ -7,6 +7,7 @@ use crate::{
};
use crate::{new_session_modal::NewSessionModal, session::DebugSession};
use anyhow::{Result, anyhow};
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use dap::DebugRequest;
use dap::{
@@ -14,6 +15,7 @@ use dap::{
client::SessionId, debugger_settings::DebuggerSettings,
};
use dap::{StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
use futures::{SinkExt as _, channel::mpsc};
use gpui::{
Action, App, AsyncWindowContext, Context, DismissEvent, Entity, EntityId, EventEmitter,
FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task, WeakEntity,
@@ -22,11 +24,21 @@ use gpui::{
use language::Buffer;
use project::debugger::session::{Session, SessionStateEvent};
use project::{Project, debugger::session::ThreadStatus};
use project::{
Project,
debugger::{
dap_store::{self, DapStore},
session::ThreadStatus,
},
terminals::TerminalKind,
};
use rpc::proto::{self};
use settings::Settings;
use std::any::TypeId;
use task::{DebugScenario, TaskContext};
use std::path::Path;
use std::sync::Arc;
use task::{DebugScenario, HideStrategy, RevealStrategy, RevealTarget, TaskContext, TaskId};
use terminal_view::TerminalView;
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
use workspace::SplitDirection;
use workspace::{
@@ -62,21 +74,27 @@ pub struct DebugPanel {
workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle,
context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>,
_subscriptions: Vec<Subscription>,
}
impl DebugPanel {
pub fn new(
workspace: &Workspace,
_window: &mut Window,
window: &mut Window,
cx: &mut Context<Workspace>,
) -> Entity<Self> {
cx.new(|cx| {
let project = workspace.project().clone();
let dap_store = project.read(cx).dap_store();
let _subscriptions =
vec![cx.subscribe_in(&dap_store, window, Self::handle_dap_store_event)];
let debug_panel = Self {
size: px(300.),
sessions: vec![],
active_session: None,
_subscriptions,
past_debug_definition: None,
focus_handle: cx.focus_handle(),
project,
@@ -92,14 +110,20 @@ impl DebugPanel {
let (has_active_session, supports_restart, support_step_back, status) = self
.active_session()
.map(|item| {
let running = item.read(cx).running_state().clone();
let caps = running.read(cx).capabilities(cx);
(
!running.read(cx).session().read(cx).is_terminated(),
caps.supports_restart_request.unwrap_or_default(),
caps.supports_step_back.unwrap_or_default(),
running.read(cx).thread_status(cx),
)
let running = item.read(cx).mode().as_running().cloned();
match running {
Some(running) => {
let caps = running.read(cx).capabilities(cx);
(
!running.read(cx).session().read(cx).is_terminated(),
caps.supports_restart_request.unwrap_or_default(),
caps.supports_step_back.unwrap_or_default(),
running.read(cx).thread_status(cx),
)
}
None => (false, false, false, None),
}
})
.unwrap_or((false, false, false, None));
@@ -264,7 +288,7 @@ impl DebugPanel {
cx.subscribe_in(
&session,
window,
move |this, session, event: &SessionStateEvent, window, cx| match event {
move |_, session, event: &SessionStateEvent, window, cx| match event {
SessionStateEvent::Restart => {
let mut curr_session = session.clone();
while let Some(parent_session) = curr_session
@@ -286,9 +310,6 @@ impl DebugPanel {
})
.detach_and_log_err(cx);
}
SessionStateEvent::SpawnChildSession { request } => {
this.handle_start_debugging_request(request, session.clone(), window, cx);
}
_ => {}
},
)
@@ -302,11 +323,11 @@ impl DebugPanel {
this.sessions.retain(|session| {
session
.read(cx)
.running_state()
.read(cx)
.session()
.read(cx)
.is_terminated()
.mode()
.as_running()
.map_or(false, |running_state| {
!running_state.read(cx).session().read(cx).is_terminated()
})
});
let session_item = DebugSession::running(
@@ -319,13 +340,11 @@ impl DebugPanel {
cx,
);
// We might want to make this an event subscription and only notify when a new thread is selected
// This is used to filter the command menu correctly
cx.observe(
&session_item.read(cx).running_state().clone(),
|_, _, cx| cx.notify(),
)
.detach();
if let Some(running) = session_item.read(cx).mode().as_running().cloned() {
// We might want to make this an event subscription and only notify when a new thread is selected
// This is used to filter the command menu correctly
cx.observe(&running, |_, _, cx| cx.notify()).detach();
}
this.sessions.push(session_item.clone());
this.activate_session(session_item, window, cx);
@@ -338,7 +357,7 @@ impl DebugPanel {
Ok(())
}
pub fn handle_start_debugging_request(
pub fn start_child_session(
&mut self,
request: &StartDebuggingRequestArguments,
parent_session: Entity<Session>,
@@ -400,6 +419,47 @@ impl DebugPanel {
self.active_session.clone()
}
fn handle_dap_store_event(
&mut self,
_dap_store: &Entity<DapStore>,
event: &dap_store::DapStoreEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
match event {
dap_store::DapStoreEvent::RunInTerminal {
session_id,
title,
cwd,
command,
args,
envs,
sender,
..
} => {
self.handle_run_in_terminal_request(
*session_id,
title.clone(),
cwd.clone(),
command.clone(),
args.clone(),
envs.clone(),
sender.clone(),
window,
cx,
)
.detach_and_log_err(cx);
}
dap_store::DapStoreEvent::SpawnChildSession {
request,
parent_session,
} => {
self.start_child_session(request, parent_session.clone(), window, cx);
}
_ => {}
}
}
pub fn resolve_scenario(
&self,
scenario: DebugScenario,
@@ -469,6 +529,101 @@ impl DebugPanel {
})
}
fn handle_run_in_terminal_request(
&self,
session_id: SessionId,
title: Option<String>,
cwd: Option<Arc<Path>>,
command: Option<String>,
args: Vec<String>,
envs: HashMap<String, String>,
mut sender: mpsc::Sender<Result<u32>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let Some(session) = self
.sessions
.iter()
.find(|s| s.read(cx).session_id(cx) == session_id)
else {
return Task::ready(Err(anyhow!("no session {:?} found", session_id)));
};
let running = session.read(cx).running_state();
let cwd = cwd.map(|p| p.to_path_buf());
let shell = self
.project
.read(cx)
.terminal_settings(&cwd, cx)
.shell
.clone();
let kind = if let Some(command) = command {
let title = title.clone().unwrap_or(command.clone());
TerminalKind::Task(task::SpawnInTerminal {
id: TaskId("debug".to_string()),
full_label: title.clone(),
label: title.clone(),
command: command.clone(),
args,
command_label: title.clone(),
cwd,
env: envs,
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: RevealStrategy::NoFocus,
reveal_target: RevealTarget::Dock,
hide: HideStrategy::Never,
shell,
show_summary: false,
show_command: false,
show_rerun: false,
})
} else {
TerminalKind::Shell(cwd.map(|c| c.to_path_buf()))
};
let workspace = self.workspace.clone();
let project = self.project.downgrade();
let terminal_task = self.project.update(cx, |project, cx| {
project.create_terminal(kind, window.window_handle(), cx)
});
let terminal_task = cx.spawn_in(window, async move |_, cx| {
let terminal = terminal_task.await?;
let terminal_view = cx.new_window_entity(|window, cx| {
TerminalView::new(terminal.clone(), workspace, None, project, window, cx)
})?;
running.update_in(cx, |running, window, cx| {
running.ensure_pane_item(DebuggerPaneItem::Terminal, window, cx);
running.debug_terminal.update(cx, |debug_terminal, cx| {
debug_terminal.terminal = Some(terminal_view);
cx.notify();
});
})?;
anyhow::Ok(terminal.read_with(cx, |terminal, _| terminal.pty_info.pid())?)
});
cx.background_spawn(async move {
match terminal_task.await {
Ok(pid_task) => match pid_task {
Some(pid) => sender.send(Ok(pid.as_u32())).await?,
None => {
sender
.send(Err(anyhow!(
"Terminal was spawned but PID was not available"
)))
.await?
}
},
Err(error) => sender.send(Err(anyhow!(error))).await?,
};
Ok(())
})
}
fn close_session(&mut self, entity_id: EntityId, window: &mut Window, cx: &mut Context<Self>) {
let Some(session) = self
.sessions
@@ -479,9 +634,11 @@ impl DebugPanel {
return;
};
session.update(cx, |this, cx| {
this.running_state().update(cx, |this, cx| {
this.serialize_layout(window, cx);
});
if let Some(running) = this.mode().as_running() {
running.update(cx, |this, cx| {
this.serialize_layout(window, cx);
});
}
});
let session_id = session.update(cx, |this, cx| this.session_id(cx));
let should_prompt = self
@@ -618,7 +775,7 @@ impl DebugPanel {
if let Some(running_state) = self
.active_session
.as_ref()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
{
let pane_items_status = running_state.read(cx).pane_items_status(cx);
let this = cx.weak_entity();
@@ -629,10 +786,10 @@ impl DebugPanel {
let this = this.clone();
move |window, cx| {
this.update(cx, |this, cx| {
if let Some(running_state) = this
.active_session
.as_ref()
.map(|session| session.read(cx).running_state().clone())
if let Some(running_state) =
this.active_session.as_ref().and_then(|session| {
session.read(cx).mode().as_running().cloned()
})
{
running_state.update(cx, |state, cx| {
if is_visible {
@@ -675,7 +832,7 @@ impl DebugPanel {
h_flex().gap_2().w_full().when_some(
active_session
.as_ref()
.map(|session| session.read(cx).running_state()),
.and_then(|session| session.read(cx).mode().as_running()),
|this, running_session| {
let thread_status = running_session
.read(cx)
@@ -913,7 +1070,7 @@ impl DebugPanel {
.when_some(
active_session
.as_ref()
.map(|session| session.read(cx).running_state())
.and_then(|session| session.read(cx).mode().as_running())
.cloned(),
|this, session| {
this.child(
@@ -976,10 +1133,12 @@ impl DebugPanel {
) {
if let Some(session) = self.active_session() {
session.update(cx, |session, cx| {
session.running_state().update(cx, |running, cx| {
running.activate_pane_in_direction(direction, window, cx);
})
});
if let Some(running) = session.mode().as_running() {
running.update(cx, |running, cx| {
running.activate_pane_in_direction(direction, window, cx);
})
}
})
}
}
@@ -991,10 +1150,12 @@ impl DebugPanel {
) {
if let Some(session) = self.active_session() {
session.update(cx, |session, cx| {
session.running_state().update(cx, |running, cx| {
running.activate_item(item, window, cx);
});
});
if let Some(running) = session.mode().as_running() {
running.update(cx, |running, cx| {
running.activate_item(item, window, cx);
})
}
})
}
}
@@ -1007,9 +1168,11 @@ impl DebugPanel {
debug_assert!(self.sessions.contains(&session_item));
session_item.focus_handle(cx).focus(window);
session_item.update(cx, |this, cx| {
this.running_state().update(cx, |this, cx| {
this.go_to_selected_stack_frame(window, cx);
});
if let Some(running) = this.mode().as_running() {
running.update(cx, |this, cx| {
this.go_to_selected_stack_frame(window, cx);
});
}
});
self.active_session = Some(session_item);
cx.notify();
@@ -1094,7 +1257,7 @@ impl Render for DebugPanel {
if self
.active_session
.as_ref()
.map(|session| session.read(cx).running_state())
.and_then(|session| session.read(cx).mode().as_running().cloned())
.map(|state| state.read(cx).has_open_context_menu(cx))
.unwrap_or(false)
{
@@ -1212,9 +1375,10 @@ impl Render for DebugPanel {
if this
.active_session
.as_ref()
.map(|session| {
let state = session.read(cx).running_state();
state.read(cx).has_pane_at_position(event.position)
.and_then(|session| {
session.read(cx).mode().as_running().map(|state| {
state.read(cx).has_pane_at_position(event.position)
})
})
.unwrap_or(false)
{

View File

@@ -64,7 +64,7 @@ pub fn init(cx: &mut App) {
if let Some(active_item) = debug_panel.read_with(cx, |panel, cx| {
panel
.active_session()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
}) {
active_item.update(cx, |item, cx| item.pause_thread(cx))
}
@@ -75,7 +75,7 @@ pub fn init(cx: &mut App) {
if let Some(active_item) = debug_panel.read_with(cx, |panel, cx| {
panel
.active_session()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
}) {
active_item.update(cx, |item, cx| item.restart_session(cx))
}
@@ -86,7 +86,7 @@ pub fn init(cx: &mut App) {
if let Some(active_item) = debug_panel.read_with(cx, |panel, cx| {
panel
.active_session()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
}) {
active_item.update(cx, |item, cx| item.step_in(cx))
}
@@ -97,7 +97,7 @@ pub fn init(cx: &mut App) {
if let Some(active_item) = debug_panel.read_with(cx, |panel, cx| {
panel
.active_session()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
}) {
active_item.update(cx, |item, cx| item.step_over(cx))
}
@@ -108,7 +108,7 @@ pub fn init(cx: &mut App) {
if let Some(active_item) = debug_panel.read_with(cx, |panel, cx| {
panel
.active_session()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
}) {
active_item.update(cx, |item, cx| item.step_back(cx))
}
@@ -119,7 +119,7 @@ pub fn init(cx: &mut App) {
if let Some(active_item) = debug_panel.read_with(cx, |panel, cx| {
panel
.active_session()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
}) {
cx.defer(move |cx| {
active_item.update(cx, |item, cx| item.stop_thread(cx))
@@ -132,7 +132,7 @@ pub fn init(cx: &mut App) {
if let Some(active_item) = debug_panel.read_with(cx, |panel, cx| {
panel
.active_session()
.map(|session| session.read(cx).running_state().clone())
.and_then(|session| session.read(cx).mode().as_running().cloned())
}) {
active_item.update(cx, |item, cx| item.toggle_ignore_breakpoints(cx))
}
@@ -209,8 +209,11 @@ pub fn init(cx: &mut App) {
state: debugger::breakpoint_store::BreakpointState::Enabled,
};
active_session.update(cx, |session, cx| {
session.running_state().update(cx, |state, cx| {
active_session
.update(cx, |session_item, _| {
session_item.mode().as_running().cloned()
})?
.update(cx, |state, cx| {
if let Some(thread_id) = state.selected_thread_id() {
state.session().update(cx, |session, cx| {
session.run_to_position(
@@ -221,7 +224,6 @@ pub fn init(cx: &mut App) {
})
}
});
});
Some(())
});
@@ -244,16 +246,17 @@ pub fn init(cx: &mut App) {
cx,
)?;
active_session.update(cx, |session, cx| {
session.running_state().update(cx, |state, cx| {
active_session
.update(cx, |session_item, _| {
session_item.mode().as_running().cloned()
})?
.update(cx, |state, cx| {
let stack_id = state.selected_stack_frame_id(cx);
state.session().update(cx, |session, cx| {
session.evaluate(text, None, stack_id, None, cx).detach();
});
});
});
Some(())
});
},

View File

@@ -5,7 +5,7 @@ use std::sync::OnceLock;
use dap::client::SessionId;
use gpui::{App, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity};
use project::Project;
use project::debugger::session::Session;
use project::debugger::{dap_store::DapStore, session::Session};
use project::worktree_store::WorktreeStore;
use rpc::proto::{self, PeerId};
use running::RunningState;
@@ -18,10 +18,23 @@ use workspace::{
use crate::debugger_panel::DebugPanel;
use crate::persistence::SerializedPaneLayout;
pub(crate) enum DebugSessionState {
Running(Entity<running::RunningState>),
}
impl DebugSessionState {
pub(crate) fn as_running(&self) -> Option<&Entity<running::RunningState>> {
match &self {
DebugSessionState::Running(entity) => Some(entity),
}
}
}
pub struct DebugSession {
remote_id: Option<workspace::ViewId>,
running_state: Entity<RunningState>,
mode: DebugSessionState,
label: OnceLock<SharedString>,
dap_store: WeakEntity<DapStore>,
_debug_panel: WeakEntity<DebugPanel>,
_worktree_store: WeakEntity<WorktreeStore>,
_workspace: WeakEntity<Workspace>,
@@ -44,7 +57,7 @@ impl DebugSession {
window: &mut Window,
cx: &mut App,
) -> Entity<Self> {
let running_state = cx.new(|cx| {
let mode = cx.new(|cx| {
RunningState::new(
session.clone(),
project.clone(),
@@ -56,12 +69,13 @@ impl DebugSession {
});
cx.new(|cx| Self {
_subscriptions: [cx.subscribe(&running_state, |_, _, _, cx| {
_subscriptions: [cx.subscribe(&mode, |_, _, _, cx| {
cx.notify();
})],
remote_id: None,
running_state,
mode: DebugSessionState::Running(mode),
label: OnceLock::new(),
dap_store: project.read(cx).dap_store().downgrade(),
_debug_panel,
_worktree_store: project.read(cx).worktree_store().downgrade(),
_workspace: workspace,
@@ -69,16 +83,31 @@ impl DebugSession {
}
pub(crate) fn session_id(&self, cx: &App) -> SessionId {
self.running_state.read(cx).session_id()
match &self.mode {
DebugSessionState::Running(entity) => entity.read(cx).session_id(),
}
}
pub fn session(&self, cx: &App) -> Entity<Session> {
self.running_state.read(cx).session().clone()
match &self.mode {
DebugSessionState::Running(entity) => entity.read(cx).session().clone(),
}
}
pub(crate) fn shutdown(&mut self, cx: &mut Context<Self>) {
self.running_state
.update(cx, |state, cx| state.shutdown(cx));
match &self.mode {
DebugSessionState::Running(state) => state.update(cx, |state, cx| state.shutdown(cx)),
}
}
pub(crate) fn mode(&self) -> &DebugSessionState {
&self.mode
}
pub(crate) fn running_state(&self) -> Entity<RunningState> {
match &self.mode {
DebugSessionState::Running(running_state) => running_state.clone(),
}
}
pub(crate) fn label(&self, cx: &App) -> SharedString {
@@ -86,40 +115,36 @@ impl DebugSession {
return label.clone();
}
let session = self.running_state.read(cx).session();
let session_id = match &self.mode {
DebugSessionState::Running(running_state) => running_state.read(cx).session_id(),
};
let Ok(Some(session)) = self
.dap_store
.read_with(cx, |store, _| store.session_by_id(session_id))
else {
return "".into();
};
self.label
.get_or_init(|| session.read(cx).label())
.to_owned()
}
pub(crate) fn running_state(&self) -> &Entity<RunningState> {
&self.running_state
}
pub(crate) fn label_element(&self, cx: &App) -> AnyElement {
let label = self.label(cx);
let icon = {
if self
.running_state
.read(cx)
.session()
.read(cx)
.is_terminated()
{
Some(Indicator::dot().color(Color::Error))
} else {
match self
.running_state
.read(cx)
.thread_status(cx)
.unwrap_or_default()
{
project::debugger::session::ThreadStatus::Stopped => {
Some(Indicator::dot().color(Color::Conflict))
let icon = match &self.mode {
DebugSessionState::Running(state) => {
if state.read(cx).session().read(cx).is_terminated() {
Some(Indicator::dot().color(Color::Error))
} else {
match state.read(cx).thread_status(cx).unwrap_or_default() {
project::debugger::session::ThreadStatus::Stopped => {
Some(Indicator::dot().color(Color::Conflict))
}
_ => Some(Indicator::dot().color(Color::Success)),
}
_ => Some(Indicator::dot().color(Color::Success)),
}
}
};
@@ -137,7 +162,9 @@ impl EventEmitter<DebugPanelItemEvent> for DebugSession {}
impl Focusable for DebugSession {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.running_state.focus_handle(cx)
match &self.mode {
DebugSessionState::Running(running_state) => running_state.focus_handle(cx),
}
}
}
@@ -216,7 +243,10 @@ impl FollowableItem for DebugSession {
impl Render for DebugSession {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
self.running_state
.update(cx, |this, cx| this.render(window, cx).into_any_element())
match &self.mode {
DebugSessionState::Running(running_state) => {
running_state.update(cx, |this, cx| this.render(window, cx).into_any_element())
}
}
}
}

View File

@@ -5,20 +5,15 @@ pub(crate) mod module_list;
pub mod stack_frame_list;
pub mod variable_list;
use std::{any::Any, ops::ControlFlow, path::PathBuf, sync::Arc, time::Duration};
use std::{any::Any, ops::ControlFlow, sync::Arc, time::Duration};
use crate::persistence::{self, DebuggerPaneItem, SerializedPaneLayout};
use super::DebugPanelItemEvent;
use anyhow::{Result, anyhow};
use breakpoint_list::BreakpointList;
use collections::{HashMap, IndexMap};
use console::Console;
use dap::{
Capabilities, RunInTerminalRequestArguments, Thread, client::SessionId,
debugger_settings::DebuggerSettings,
};
use futures::{SinkExt, channel::mpsc};
use dap::{Capabilities, Thread, client::SessionId, debugger_settings::DebuggerSettings};
use gpui::{
Action as _, AnyView, AppContext, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
NoAction, Pixels, Point, Subscription, Task, WeakEntity,
@@ -28,10 +23,8 @@ use module_list::ModuleList;
use project::{
Project,
debugger::session::{Session, SessionEvent, ThreadId, ThreadStatus},
terminals::TerminalKind,
};
use rpc::proto::ViewId;
use serde_json::Value;
use settings::Settings;
use stack_frame_list::StackFrameList;
use terminal_view::TerminalView;
@@ -39,7 +32,7 @@ use ui::{
ActiveTheme, AnyElement, App, ButtonCommon as _, Clickable as _, Context, ContextMenu,
DropdownMenu, FluentBuilder, IconButton, IconName, IconSize, InteractiveElement, IntoElement,
Label, LabelCommon as _, ParentElement, Render, SharedString, StatefulInteractiveElement,
Styled, Tab, Tooltip, VisibleOnHover, VisualContext, Window, div, h_flex, v_flex,
Styled, Tab, Tooltip, VisibleOnHover, Window, div, h_flex, v_flex,
};
use util::ResultExt;
use variable_list::VariableList;
@@ -566,9 +559,6 @@ impl RunningState {
this.remove_pane_item(DebuggerPaneItem::LoadedSources, window, cx);
}
}
SessionEvent::RunInTerminal { request, sender } => this
.handle_run_in_terminal(request, sender.clone(), window, cx)
.detach_and_log_err(cx),
_ => {}
}
@@ -667,111 +657,6 @@ impl RunningState {
self.panes.pane_at_pixel_position(position).is_some()
}
fn handle_run_in_terminal(
&self,
request: &RunInTerminalRequestArguments,
mut sender: mpsc::Sender<Result<u32>>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
let running = cx.entity();
let Ok(project) = self
.workspace
.update(cx, |workspace, _| workspace.project().clone())
else {
return Task::ready(Err(anyhow!("no workspace")));
};
let session = self.session.read(cx);
let cwd = Some(&request.cwd)
.filter(|cwd| cwd.len() > 0)
.map(PathBuf::from)
.or_else(|| session.binary().cwd.clone());
let mut args = request.args.clone();
// Handle special case for NodeJS debug adapter
// If only the Node binary path is provided, we set the command to None
// This prevents the NodeJS REPL from appearing, which is not the desired behavior
// The expected usage is for users to provide their own Node command, e.g., `node test.js`
// This allows the NodeJS debug client to attach correctly
let command = if args.len() > 1 {
Some(args.remove(0))
} else {
None
};
let mut envs: HashMap<String, String> = Default::default();
if let Some(Value::Object(env)) = &request.env {
for (key, value) in env {
let value_str = match (key.as_str(), value) {
(_, Value::String(value)) => value,
_ => continue,
};
envs.insert(key.clone(), value_str.clone());
}
}
let shell = project.read(cx).terminal_settings(&cwd, cx).shell.clone();
let kind = if let Some(command) = command {
let title = request.title.clone().unwrap_or(command.clone());
TerminalKind::Task(task::SpawnInTerminal {
id: task::TaskId("debug".to_string()),
full_label: title.clone(),
label: title.clone(),
command: command.clone(),
args,
command_label: title.clone(),
cwd,
env: envs,
use_new_terminal: true,
allow_concurrent_runs: true,
reveal: task::RevealStrategy::NoFocus,
reveal_target: task::RevealTarget::Dock,
hide: task::HideStrategy::Never,
shell,
show_summary: false,
show_command: false,
show_rerun: false,
})
} else {
TerminalKind::Shell(cwd.map(|c| c.to_path_buf()))
};
let workspace = self.workspace.clone();
let weak_project = project.downgrade();
let terminal_task = project.update(cx, |project, cx| {
project.create_terminal(kind, window.window_handle(), cx)
});
let terminal_task = cx.spawn_in(window, async move |_, cx| {
let terminal = terminal_task.await?;
let terminal_view = cx.new_window_entity(|window, cx| {
TerminalView::new(terminal.clone(), workspace, None, weak_project, window, cx)
})?;
running.update_in(cx, |running, window, cx| {
running.ensure_pane_item(DebuggerPaneItem::Terminal, window, cx);
running.debug_terminal.update(cx, |debug_terminal, cx| {
debug_terminal.terminal = Some(terminal_view);
cx.notify();
});
})?;
terminal.read_with(cx, |terminal, _| {
terminal
.pty_info
.pid()
.map(|pid| pid.as_u32())
.ok_or_else(|| anyhow!("Terminal was spawned but PID was not available"))
})?
});
cx.background_spawn(async move { anyhow::Ok(sender.send(terminal_task.await).await?) })
}
fn create_sub_view(
&self,
item_kind: DebuggerPaneItem,

View File

@@ -118,8 +118,8 @@ pub fn start_debug_session_with<T: Fn(&Arc<DebugAdapterClient>) + 'static>(
workspace
.panel::<DebugPanel>(cx)
.and_then(|panel| panel.read(cx).active_session())
.map(|session| session.read(cx).running_state().read(cx).session())
.cloned()
.and_then(|session| session.read(cx).mode().as_running().cloned())
.map(|running| running.read(cx).session().clone())
.ok_or_else(|| anyhow!("Failed to get active session"))
})??;

View File

@@ -7,7 +7,6 @@ use project::{FakeFs, Project};
use serde_json::json;
use task::{AttachRequest, TcpArgumentsTemplate};
use tests::{init_test, init_test_workspace};
use util::path;
#[gpui::test]
async fn test_direct_attach_to_process(executor: BackgroundExecutor, cx: &mut TestAppContext) {
@@ -16,14 +15,14 @@ async fn test_direct_attach_to_process(executor: BackgroundExecutor, cx: &mut Te
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -80,14 +79,14 @@ async fn test_show_attach_modal_and_select_process(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
// Set up handlers for sessions spawned via modal.

View File

@@ -7,7 +7,6 @@ use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::{FakeFs, Project};
use serde_json::json;
use tests::{init_test, init_test_workspace};
use util::path;
#[gpui::test]
async fn test_handle_output_event(executor: BackgroundExecutor, cx: &mut TestAppContext) {
@@ -16,14 +15,14 @@ async fn test_handle_output_event(executor: BackgroundExecutor, cx: &mut TestApp
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
workspace
@@ -87,7 +86,10 @@ async fn test_handle_output_event(executor: BackgroundExecutor, cx: &mut TestApp
let running_state =
active_debug_session_panel(workspace, cx).update_in(cx, |item, window, cx| {
cx.focus_self(window);
item.running_state().clone()
item.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
cx.run_until_parked();
@@ -102,7 +104,7 @@ async fn test_handle_output_event(executor: BackgroundExecutor, cx: &mut TestApp
assert_eq!(
"First console output line before thread stopped!\nFirst output line before thread stopped!\n",
active_debug_session_panel.read(cx).running_state().read(cx).console().read(cx).editor().read(cx).text(cx).as_str()
active_debug_session_panel.read(cx).mode().as_running().unwrap().read(cx).console().read(cx).editor().read(cx).text(cx).as_str()
);
})
.unwrap();
@@ -151,7 +153,7 @@ async fn test_handle_output_event(executor: BackgroundExecutor, cx: &mut TestApp
assert_eq!(
"First console output line before thread stopped!\nFirst output line before thread stopped!\nSecond output line after thread stopped!\nSecond console output line after thread stopped!\n",
active_session_panel.read(cx).running_state().read(cx).console().read(cx).editor().read(cx).text(cx).as_str()
active_session_panel.read(cx).mode().as_running().unwrap().read(cx).console().read(cx).editor().read(cx).text(cx).as_str()
);
})
.unwrap();

View File

@@ -5,7 +5,6 @@ use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::Project;
use serde_json::json;
use std::cell::OnceCell;
use util::path;
#[gpui::test]
async fn test_dap_logger_captures_all_session_rpc_messages(
@@ -29,7 +28,7 @@ async fn test_dap_logger_captures_all_session_rpc_messages(
// Create a filesystem with a simple project
let fs = project::FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "fn main() {\n println!(\"Hello, world!\");\n}"
}),
@@ -43,7 +42,7 @@ async fn test_dap_logger_captures_all_session_rpc_messages(
"log_store shouldn't contain any session IDs before any sessions were created"
);
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);

View File

@@ -44,14 +44,14 @@ async fn test_basic_show_debug_panel(executor: BackgroundExecutor, cx: &mut Test
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -84,7 +84,11 @@ async fn test_basic_show_debug_panel(executor: BackgroundExecutor, cx: &mut Test
debug_panel.update(cx, |debug_panel, _| debug_panel.active_session().unwrap());
let running_state = active_session.update(cx, |active_session, _| {
active_session.running_state().clone()
active_session
.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
debug_panel.update(cx, |this, cx| {
@@ -116,7 +120,11 @@ async fn test_basic_show_debug_panel(executor: BackgroundExecutor, cx: &mut Test
.unwrap();
let running_state = active_session.update(cx, |active_session, _| {
active_session.running_state().clone()
active_session
.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
assert_eq!(client.id(), running_state.read(cx).session_id());
@@ -145,7 +153,11 @@ async fn test_basic_show_debug_panel(executor: BackgroundExecutor, cx: &mut Test
.unwrap();
let running_state = active_session.update(cx, |active_session, _| {
active_session.running_state().clone()
active_session
.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
debug_panel.update(cx, |this, cx| {
@@ -169,14 +181,14 @@ async fn test_we_can_only_have_one_panel_per_debug_session(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -235,7 +247,11 @@ async fn test_we_can_only_have_one_panel_per_debug_session(
.unwrap();
let running_state = active_session.update(cx, |active_session, _| {
active_session.running_state().clone()
active_session
.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
assert_eq!(client.id(), active_session.read(cx).session_id(cx));
@@ -268,7 +284,11 @@ async fn test_we_can_only_have_one_panel_per_debug_session(
.unwrap();
let running_state = active_session.update(cx, |active_session, _| {
active_session.running_state().clone()
active_session
.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
assert_eq!(client.id(), active_session.read(cx).session_id(cx));
@@ -296,7 +316,11 @@ async fn test_we_can_only_have_one_panel_per_debug_session(
.unwrap();
let running_state = active_session.update(cx, |active_session, _| {
active_session.running_state().clone()
active_session
.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
debug_panel.update(cx, |this, cx| {
@@ -322,14 +346,14 @@ async fn test_handle_successful_run_in_terminal_reverse_request(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -393,14 +417,14 @@ async fn test_handle_start_debugging_request(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -469,14 +493,14 @@ async fn test_handle_error_run_in_terminal_reverse_request(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -499,8 +523,8 @@ async fn test_handle_error_run_in_terminal_reverse_request(
.fake_reverse_request::<RunInTerminal>(RunInTerminalRequestArguments {
kind: None,
title: None,
cwd: "".into(),
args: vec!["oops".into(), "oops".into()],
cwd: "/non-existing/path".into(), // invalid/non-existing path will cause the terminal spawn to fail
args: vec![],
env: None,
args_can_be_interpreted_by_shell: None,
})
@@ -537,14 +561,14 @@ async fn test_handle_start_debugging_reverse_request(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -633,14 +657,14 @@ async fn test_shutdown_children_when_parent_session_shutdown(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let dap_store = project.update(cx, |project, _| project.dap_store());
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -739,14 +763,14 @@ async fn test_shutdown_parent_session_if_all_children_are_shutdown(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let dap_store = project.update(cx, |project, _| project.dap_store());
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -859,14 +883,14 @@ async fn test_debug_panel_item_thread_status_reset_on_failure(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -985,8 +1009,12 @@ async fn test_debug_panel_item_thread_status_reset_on_failure(
cx.run_until_parked();
let running_state = active_debug_session_panel(workspace, cx)
.update(cx, |item, _| item.running_state().clone());
let running_state = active_debug_session_panel(workspace, cx).update_in(cx, |item, _, _| {
item.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
cx.run_until_parked();
let thread_id = ThreadId(1);
@@ -1298,7 +1326,7 @@ async fn test_unsetting_breakpoints_on_clear_breakpoint_action(
.expect("We should always send a breakpoint's path")
.as_str()
{
path!("/project/main.rs") | path!("/project/second.rs") => {}
"/project/main.rs" | "/project/second.rs" => {}
_ => {
panic!("Unset breakpoints for path that doesn't have any")
}
@@ -1326,14 +1354,14 @@ async fn test_debug_session_is_shutdown_when_attach_and_launch_request_fails(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -1374,14 +1402,14 @@ async fn test_we_send_arguments_from_user_config(
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
path!("/project"),
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let debug_definition = DebugTaskDefinition {
@@ -1389,7 +1417,7 @@ async fn test_we_send_arguments_from_user_config(
request: dap::DebugRequest::Launch(LaunchRequest {
program: "main.rs".to_owned(),
args: vec!["arg1".to_owned(), "arg2".to_owned()],
cwd: Some(path!("/Random_path").into()),
cwd: Some("/Random_path".into()),
env: HashMap::from_iter(vec![("KEY".to_owned(), "VALUE".to_owned())]),
}),
label: "test".into(),

View File

@@ -12,7 +12,6 @@ use std::sync::{
Arc,
atomic::{AtomicBool, AtomicI32, Ordering},
};
use util::path;
#[gpui::test]
async fn test_module_list(executor: BackgroundExecutor, cx: &mut TestAppContext) {
@@ -20,7 +19,7 @@ async fn test_module_list(executor: BackgroundExecutor, cx: &mut TestAppContext)
let fs = FakeFs::new(executor.clone());
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
workspace
.update(cx, |workspace, window, cx| {
@@ -106,7 +105,10 @@ async fn test_module_list(executor: BackgroundExecutor, cx: &mut TestAppContext)
let running_state =
active_debug_session_panel(workspace, cx).update_in(cx, |item, window, cx| {
cx.focus_self(window);
item.running_state().clone()
item.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
running_state.update_in(cx, |this, window, cx| {

View File

@@ -138,33 +138,43 @@ async fn test_fetch_initial_stack_frames_and_go_to_stack_frame(
// trigger to load threads
active_debug_session_panel(workspace, cx).update(cx, |session, cx| {
session.running_state().update(cx, |running_state, cx| {
running_state
.session()
.update(cx, |session, cx| session.threads(cx));
});
session
.mode()
.as_running()
.unwrap()
.update(cx, |running_state, cx| {
running_state
.session()
.update(cx, |session, cx| session.threads(cx));
});
});
cx.run_until_parked();
// select first thread
active_debug_session_panel(workspace, cx).update_in(cx, |session, window, cx| {
session.running_state().update(cx, |running_state, cx| {
running_state.select_current_thread(
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
session
.mode()
.as_running()
.unwrap()
.update(cx, |running_state, cx| {
running_state.select_current_thread(
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
});
cx.run_until_parked();
active_debug_session_panel(workspace, cx).update(cx, |session, cx| {
let stack_frame_list = session
.running_state()
.mode()
.as_running()
.unwrap()
.update(cx, |state, _| state.stack_frame_list().clone());
stack_frame_list.update(cx, |stack_frame_list, cx| {
@@ -299,26 +309,34 @@ async fn test_select_stack_frame(executor: BackgroundExecutor, cx: &mut TestAppC
// trigger threads to load
active_debug_session_panel(workspace, cx).update(cx, |session, cx| {
session.running_state().update(cx, |running_state, cx| {
running_state
.session()
.update(cx, |session, cx| session.threads(cx));
});
session
.mode()
.as_running()
.unwrap()
.update(cx, |running_state, cx| {
running_state
.session()
.update(cx, |session, cx| session.threads(cx));
});
});
cx.run_until_parked();
// select first thread
active_debug_session_panel(workspace, cx).update_in(cx, |session, window, cx| {
session.running_state().update(cx, |running_state, cx| {
running_state.select_current_thread(
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
session
.mode()
.as_running()
.unwrap()
.update(cx, |running_state, cx| {
running_state.select_current_thread(
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
});
cx.run_until_parked();
@@ -365,7 +383,9 @@ async fn test_select_stack_frame(executor: BackgroundExecutor, cx: &mut TestAppC
active_debug_panel_item
.read(cx)
.running_state()
.mode()
.as_running()
.unwrap()
.read(cx)
.stack_frame_list()
.clone()
@@ -656,26 +676,34 @@ async fn test_collapsed_entries(executor: BackgroundExecutor, cx: &mut TestAppCo
// trigger threads to load
active_debug_session_panel(workspace, cx).update(cx, |session, cx| {
session.running_state().update(cx, |running_state, cx| {
running_state
.session()
.update(cx, |session, cx| session.threads(cx));
});
session
.mode()
.as_running()
.unwrap()
.update(cx, |running_state, cx| {
running_state
.session()
.update(cx, |session, cx| session.threads(cx));
});
});
cx.run_until_parked();
// select first thread
active_debug_session_panel(workspace, cx).update_in(cx, |session, window, cx| {
session.running_state().update(cx, |running_state, cx| {
running_state.select_current_thread(
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
session
.mode()
.as_running()
.unwrap()
.update(cx, |running_state, cx| {
running_state.select_current_thread(
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
});
cx.run_until_parked();
@@ -683,7 +711,9 @@ async fn test_collapsed_entries(executor: BackgroundExecutor, cx: &mut TestAppCo
// trigger stack frames to loaded
active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| {
let stack_frame_list = debug_panel_item
.running_state()
.mode()
.as_running()
.unwrap()
.update(cx, |state, _| state.stack_frame_list().clone());
stack_frame_list.update(cx, |stack_frame_list, cx| {
@@ -695,7 +725,9 @@ async fn test_collapsed_entries(executor: BackgroundExecutor, cx: &mut TestAppCo
active_debug_session_panel(workspace, cx).update_in(cx, |debug_panel_item, window, cx| {
let stack_frame_list = debug_panel_item
.running_state()
.mode()
.as_running()
.unwrap()
.update(cx, |state, _| state.stack_frame_list().clone());
stack_frame_list.update(cx, |stack_frame_list, cx| {

View File

@@ -183,7 +183,10 @@ async fn test_basic_fetch_initial_scope_and_variables(
let running_state =
active_debug_session_panel(workspace, cx).update_in(cx, |item, window, cx| {
cx.focus_self(window);
item.running_state().clone()
item.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
cx.run_until_parked();
@@ -424,7 +427,10 @@ async fn test_fetch_variables_for_multiple_scopes(
let running_state =
active_debug_session_panel(workspace, cx).update_in(cx, |item, window, cx| {
cx.focus_self(window);
item.running_state().clone()
item.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
cx.run_until_parked();
@@ -704,7 +710,11 @@ async fn test_keyboard_navigation(executor: BackgroundExecutor, cx: &mut TestApp
let running_state =
active_debug_session_panel(workspace, cx).update_in(cx, |item, window, cx| {
cx.focus_self(window);
let running = item.running_state().clone();
let running = item
.mode()
.as_running()
.expect("Session should be running by this point")
.clone();
let variable_list = running.read_with(cx, |state, _| state.variable_list().clone());
variable_list.update(cx, |_, cx| cx.focus_self(window));
@@ -1430,7 +1440,11 @@ async fn test_variable_list_only_sends_requests_when_rendering(
cx.run_until_parked();
let running_state = active_debug_session_panel(workspace, cx).update_in(cx, |item, _, _| {
let state = item.running_state().clone();
let state = item
.mode()
.as_running()
.expect("Session should be running by this point")
.clone();
state
});
@@ -1728,7 +1742,10 @@ async fn test_it_fetches_scopes_variables_when_you_select_a_stack_frame(
let running_state =
active_debug_session_panel(workspace, cx).update_in(cx, |item, window, cx| {
cx.focus_self(window);
item.running_state().clone()
item.mode()
.as_running()
.expect("Session should be running by this point")
.clone()
});
running_state.update(cx, |running_state, cx| {

View File

@@ -14,16 +14,13 @@ doctest = false
[dependencies]
anyhow.workspace = true
cargo_metadata.workspace = true
collections.workspace = true
component.workspace = true
ctor.workspace = true
editor.workspace = true
env_logger.workspace = true
futures.workspace = true
gpui.workspace = true
indoc.workspace = true
itertools.workspace = true
language.workspace = true
linkme.workspace = true
log.workspace = true
@@ -32,9 +29,7 @@ markdown.workspace = true
project.workspace = true
rand.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
text.workspace = true
theme.workspace = true
ui.workspace = true

View File

@@ -1,603 +0,0 @@
use std::{
path::{Component, Path, Prefix},
process::Stdio,
sync::atomic::{self, AtomicUsize},
};
use cargo_metadata::{
Message,
diagnostic::{Applicability, Diagnostic as CargoDiagnostic, DiagnosticLevel, DiagnosticSpan},
};
use collections::HashMap;
use gpui::{AppContext, Entity, Task};
use itertools::Itertools as _;
use language::Diagnostic;
use project::{
Worktree, lsp_store::rust_analyzer_ext::CARGO_DIAGNOSTICS_SOURCE_NAME,
project_settings::ProjectSettings,
};
use serde::{Deserialize, Serialize};
use settings::Settings;
use smol::{
channel::Receiver,
io::{AsyncBufReadExt, BufReader},
process::Command,
};
use ui::App;
use util::ResultExt;
use crate::ProjectDiagnosticsEditor;
#[derive(Debug, serde::Deserialize)]
#[serde(untagged)]
enum CargoMessage {
Cargo(Message),
Rustc(CargoDiagnostic),
}
/// Appends formatted string to a `String`.
macro_rules! format_to {
($buf:expr) => ();
($buf:expr, $lit:literal $($arg:tt)*) => {
{
use ::std::fmt::Write as _;
// We can't do ::std::fmt::Write::write_fmt($buf, format_args!($lit $($arg)*))
// unfortunately, as that loses out on autoref behavior.
_ = $buf.write_fmt(format_args!($lit $($arg)*))
}
};
}
pub fn cargo_diagnostics_sources(
editor: &ProjectDiagnosticsEditor,
cx: &App,
) -> Vec<Entity<Worktree>> {
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
.diagnostics
.fetch_cargo_diagnostics();
if !fetch_cargo_diagnostics {
return Vec::new();
}
editor
.project
.read(cx)
.worktrees(cx)
.filter(|worktree| worktree.read(cx).entry_for_path("Cargo.toml").is_some())
.collect()
}
#[derive(Debug)]
pub enum FetchUpdate {
Diagnostic(CargoDiagnostic),
Progress(String),
}
#[derive(Debug)]
pub enum FetchStatus {
Started,
Progress { message: String },
Finished,
}
pub fn fetch_worktree_diagnostics(
worktree_root: &Path,
cx: &App,
) -> Option<(Task<()>, Receiver<FetchUpdate>)> {
let diagnostics_settings = ProjectSettings::get_global(cx)
.diagnostics
.cargo
.as_ref()
.filter(|cargo_diagnostics| cargo_diagnostics.fetch_cargo_diagnostics)?;
let command_string = diagnostics_settings
.diagnostics_fetch_command
.iter()
.join(" ");
let mut command_parts = diagnostics_settings.diagnostics_fetch_command.iter();
let mut command = Command::new(command_parts.next()?)
.args(command_parts)
.envs(diagnostics_settings.env.clone())
.current_dir(worktree_root)
.stdout(Stdio::piped())
.stderr(Stdio::null())
.kill_on_drop(true)
.spawn()
.log_err()?;
let stdout = command.stdout.take()?;
let mut reader = BufReader::new(stdout);
let (tx, rx) = smol::channel::unbounded();
let error_threshold = 10;
let cargo_diagnostics_fetch_task = cx.background_spawn(async move {
let _command = command;
let mut errors = 0;
loop {
let mut line = String::new();
match reader.read_line(&mut line).await {
Ok(0) => {
return;
},
Ok(_) => {
errors = 0;
let mut deserializer = serde_json::Deserializer::from_str(&line);
deserializer.disable_recursion_limit();
let send_result = match CargoMessage::deserialize(&mut deserializer) {
Ok(CargoMessage::Cargo(Message::CompilerMessage(message))) => tx.send(FetchUpdate::Diagnostic(message.message)).await,
Ok(CargoMessage::Cargo(Message::CompilerArtifact(artifact))) => tx.send(FetchUpdate::Progress(format!("Compiled {:?}", artifact.manifest_path.parent().unwrap_or(&artifact.manifest_path)))).await,
Ok(CargoMessage::Cargo(_)) => Ok(()),
Ok(CargoMessage::Rustc(rustc_message)) => tx.send(FetchUpdate::Diagnostic(rustc_message)).await,
Err(_) => {
log::debug!("Failed to parse cargo diagnostics from line '{line}'");
Ok(())
},
};
if send_result.is_err() {
return;
}
},
Err(e) => {
log::error!("Failed to read line from {command_string} command output when fetching cargo diagnostics: {e}");
errors += 1;
if errors >= error_threshold {
log::error!("Failed {error_threshold} times, aborting the diagnostics fetch");
return;
}
},
}
}
});
Some((cargo_diagnostics_fetch_task, rx))
}
static CARGO_DIAGNOSTICS_FETCH_GENERATION: AtomicUsize = AtomicUsize::new(0);
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
struct CargoFetchDiagnosticData {
generation: usize,
}
pub fn next_cargo_fetch_generation() {
CARGO_DIAGNOSTICS_FETCH_GENERATION.fetch_add(1, atomic::Ordering::Release);
}
pub fn is_outdated_cargo_fetch_diagnostic(diagnostic: &Diagnostic) -> bool {
if let Some(data) = diagnostic
.data
.clone()
.and_then(|data| serde_json::from_value::<CargoFetchDiagnosticData>(data).ok())
{
let current_generation = CARGO_DIAGNOSTICS_FETCH_GENERATION.load(atomic::Ordering::Acquire);
data.generation < current_generation
} else {
false
}
}
/// Converts a Rust root diagnostic to LSP form
///
/// This flattens the Rust diagnostic by:
///
/// 1. Creating a LSP diagnostic with the root message and primary span.
/// 2. Adding any labelled secondary spans to `relatedInformation`
/// 3. Categorising child diagnostics as either `SuggestedFix`es,
/// `relatedInformation` or additional message lines.
///
/// If the diagnostic has no primary span this will return `None`
///
/// Taken from https://github.com/rust-lang/rust-analyzer/blob/fe7b4f2ad96f7c13cc571f45edc2c578b35dddb4/crates/rust-analyzer/src/diagnostics/to_proto.rs#L275-L285
pub(crate) fn map_rust_diagnostic_to_lsp(
worktree_root: &Path,
cargo_diagnostic: &CargoDiagnostic,
) -> Vec<(lsp::Url, lsp::Diagnostic)> {
let primary_spans: Vec<&DiagnosticSpan> = cargo_diagnostic
.spans
.iter()
.filter(|s| s.is_primary)
.collect();
if primary_spans.is_empty() {
return Vec::new();
}
let severity = diagnostic_severity(cargo_diagnostic.level);
let mut source = String::from(CARGO_DIAGNOSTICS_SOURCE_NAME);
let mut code = cargo_diagnostic.code.as_ref().map(|c| c.code.clone());
if let Some(code_val) = &code {
// See if this is an RFC #2103 scoped lint (e.g. from Clippy)
let scoped_code: Vec<&str> = code_val.split("::").collect();
if scoped_code.len() == 2 {
source = String::from(scoped_code[0]);
code = Some(String::from(scoped_code[1]));
}
}
let mut needs_primary_span_label = true;
let mut subdiagnostics = Vec::new();
let mut tags = Vec::new();
for secondary_span in cargo_diagnostic.spans.iter().filter(|s| !s.is_primary) {
if let Some(label) = secondary_span.label.clone() {
subdiagnostics.push(lsp::DiagnosticRelatedInformation {
location: location(worktree_root, secondary_span),
message: label,
});
}
}
let mut message = cargo_diagnostic.message.clone();
for child in &cargo_diagnostic.children {
let child = map_rust_child_diagnostic(worktree_root, child);
match child {
MappedRustChildDiagnostic::SubDiagnostic(sub) => {
subdiagnostics.push(sub);
}
MappedRustChildDiagnostic::MessageLine(message_line) => {
format_to!(message, "\n{message_line}");
// These secondary messages usually duplicate the content of the
// primary span label.
needs_primary_span_label = false;
}
}
}
if let Some(code) = &cargo_diagnostic.code {
let code = code.code.as_str();
if matches!(
code,
"dead_code"
| "unknown_lints"
| "unreachable_code"
| "unused_attributes"
| "unused_imports"
| "unused_macros"
| "unused_variables"
) {
tags.push(lsp::DiagnosticTag::UNNECESSARY);
}
if matches!(code, "deprecated") {
tags.push(lsp::DiagnosticTag::DEPRECATED);
}
}
let code_description = match source.as_str() {
"rustc" => rustc_code_description(code.as_deref()),
"clippy" => clippy_code_description(code.as_deref()),
_ => None,
};
let generation = CARGO_DIAGNOSTICS_FETCH_GENERATION.load(atomic::Ordering::Acquire);
let data = Some(
serde_json::to_value(CargoFetchDiagnosticData { generation })
.expect("Serializing a regular Rust struct"),
);
primary_spans
.iter()
.flat_map(|primary_span| {
let primary_location = primary_location(worktree_root, primary_span);
let message = {
let mut message = message.clone();
if needs_primary_span_label {
if let Some(primary_span_label) = &primary_span.label {
format_to!(message, "\n{primary_span_label}");
}
}
message
};
// Each primary diagnostic span may result in multiple LSP diagnostics.
let mut diagnostics = Vec::new();
let mut related_info_macro_calls = vec![];
// If error occurs from macro expansion, add related info pointing to
// where the error originated
// Also, we would generate an additional diagnostic, so that exact place of macro
// will be highlighted in the error origin place.
let span_stack = std::iter::successors(Some(*primary_span), |span| {
Some(&span.expansion.as_ref()?.span)
});
for (i, span) in span_stack.enumerate() {
if is_dummy_macro_file(&span.file_name) {
continue;
}
// First span is the original diagnostic, others are macro call locations that
// generated that code.
let is_in_macro_call = i != 0;
let secondary_location = location(worktree_root, span);
if secondary_location == primary_location {
continue;
}
related_info_macro_calls.push(lsp::DiagnosticRelatedInformation {
location: secondary_location.clone(),
message: if is_in_macro_call {
"Error originated from macro call here".to_owned()
} else {
"Actual error occurred here".to_owned()
},
});
// For the additional in-macro diagnostic we add the inverse message pointing to the error location in code.
let information_for_additional_diagnostic =
vec![lsp::DiagnosticRelatedInformation {
location: primary_location.clone(),
message: "Exact error occurred here".to_owned(),
}];
let diagnostic = lsp::Diagnostic {
range: secondary_location.range,
// downgrade to hint if we're pointing at the macro
severity: Some(lsp::DiagnosticSeverity::HINT),
code: code.clone().map(lsp::NumberOrString::String),
code_description: code_description.clone(),
source: Some(source.clone()),
message: message.clone(),
related_information: Some(information_for_additional_diagnostic),
tags: if tags.is_empty() {
None
} else {
Some(tags.clone())
},
data: data.clone(),
};
diagnostics.push((secondary_location.uri, diagnostic));
}
// Emit the primary diagnostic.
diagnostics.push((
primary_location.uri.clone(),
lsp::Diagnostic {
range: primary_location.range,
severity,
code: code.clone().map(lsp::NumberOrString::String),
code_description: code_description.clone(),
source: Some(source.clone()),
message,
related_information: {
let info = related_info_macro_calls
.iter()
.cloned()
.chain(subdiagnostics.iter().cloned())
.collect::<Vec<_>>();
if info.is_empty() { None } else { Some(info) }
},
tags: if tags.is_empty() {
None
} else {
Some(tags.clone())
},
data: data.clone(),
},
));
// Emit hint-level diagnostics for all `related_information` entries such as "help"s.
// This is useful because they will show up in the user's editor, unlike
// `related_information`, which just produces hard-to-read links, at least in VS Code.
let back_ref = lsp::DiagnosticRelatedInformation {
location: primary_location,
message: "original diagnostic".to_owned(),
};
for sub in &subdiagnostics {
diagnostics.push((
sub.location.uri.clone(),
lsp::Diagnostic {
range: sub.location.range,
severity: Some(lsp::DiagnosticSeverity::HINT),
code: code.clone().map(lsp::NumberOrString::String),
code_description: code_description.clone(),
source: Some(source.clone()),
message: sub.message.clone(),
related_information: Some(vec![back_ref.clone()]),
tags: None, // don't apply modifiers again
data: data.clone(),
},
));
}
diagnostics
})
.collect()
}
fn rustc_code_description(code: Option<&str>) -> Option<lsp::CodeDescription> {
code.filter(|code| {
let mut chars = code.chars();
chars.next() == Some('E')
&& chars.by_ref().take(4).all(|c| c.is_ascii_digit())
&& chars.next().is_none()
})
.and_then(|code| {
lsp::Url::parse(&format!(
"https://doc.rust-lang.org/error-index.html#{code}"
))
.ok()
.map(|href| lsp::CodeDescription { href })
})
}
fn clippy_code_description(code: Option<&str>) -> Option<lsp::CodeDescription> {
code.and_then(|code| {
lsp::Url::parse(&format!(
"https://rust-lang.github.io/rust-clippy/master/index.html#{code}"
))
.ok()
.map(|href| lsp::CodeDescription { href })
})
}
/// Determines the LSP severity from a diagnostic
fn diagnostic_severity(level: DiagnosticLevel) -> Option<lsp::DiagnosticSeverity> {
let res = match level {
DiagnosticLevel::Ice => lsp::DiagnosticSeverity::ERROR,
DiagnosticLevel::Error => lsp::DiagnosticSeverity::ERROR,
DiagnosticLevel::Warning => lsp::DiagnosticSeverity::WARNING,
DiagnosticLevel::Note => lsp::DiagnosticSeverity::INFORMATION,
DiagnosticLevel::Help => lsp::DiagnosticSeverity::HINT,
_ => return None,
};
Some(res)
}
enum MappedRustChildDiagnostic {
SubDiagnostic(lsp::DiagnosticRelatedInformation),
MessageLine(String),
}
fn map_rust_child_diagnostic(
worktree_root: &Path,
cargo_diagnostic: &CargoDiagnostic,
) -> MappedRustChildDiagnostic {
let spans: Vec<&DiagnosticSpan> = cargo_diagnostic
.spans
.iter()
.filter(|s| s.is_primary)
.collect();
if spans.is_empty() {
// `rustc` uses these spanless children as a way to print multi-line
// messages
return MappedRustChildDiagnostic::MessageLine(cargo_diagnostic.message.clone());
}
let mut edit_map: HashMap<lsp::Url, Vec<lsp::TextEdit>> = HashMap::default();
let mut suggested_replacements = Vec::new();
for &span in &spans {
if let Some(suggested_replacement) = &span.suggested_replacement {
if !suggested_replacement.is_empty() {
suggested_replacements.push(suggested_replacement);
}
let location = location(worktree_root, span);
let edit = lsp::TextEdit::new(location.range, suggested_replacement.clone());
// Only actually emit a quickfix if the suggestion is "valid enough".
// We accept both "MaybeIncorrect" and "MachineApplicable". "MaybeIncorrect" means that
// the suggestion is *complete* (contains no placeholders where code needs to be
// inserted), but might not be what the user wants, or might need minor adjustments.
if matches!(
span.suggestion_applicability,
None | Some(Applicability::MaybeIncorrect | Applicability::MachineApplicable)
) {
edit_map.entry(location.uri).or_default().push(edit);
}
}
}
// rustc renders suggestion diagnostics by appending the suggested replacement, so do the same
// here, otherwise the diagnostic text is missing useful information.
let mut message = cargo_diagnostic.message.clone();
if !suggested_replacements.is_empty() {
message.push_str(": ");
let suggestions = suggested_replacements
.iter()
.map(|suggestion| format!("`{suggestion}`"))
.join(", ");
message.push_str(&suggestions);
}
MappedRustChildDiagnostic::SubDiagnostic(lsp::DiagnosticRelatedInformation {
location: location(worktree_root, spans[0]),
message,
})
}
/// Converts a Rust span to a LSP location
fn location(worktree_root: &Path, span: &DiagnosticSpan) -> lsp::Location {
let file_name = worktree_root.join(&span.file_name);
let uri = url_from_abs_path(&file_name);
let range = {
lsp::Range::new(
position(span, span.line_start, span.column_start.saturating_sub(1)),
position(span, span.line_end, span.column_end.saturating_sub(1)),
)
};
lsp::Location::new(uri, range)
}
/// Returns a `Url` object from a given path, will lowercase drive letters if present.
/// This will only happen when processing windows paths.
///
/// When processing non-windows path, this is essentially the same as `Url::from_file_path`.
pub(crate) fn url_from_abs_path(path: &Path) -> lsp::Url {
let url = lsp::Url::from_file_path(path).unwrap();
match path.components().next() {
Some(Component::Prefix(prefix))
if matches!(prefix.kind(), Prefix::Disk(_) | Prefix::VerbatimDisk(_)) =>
{
// Need to lowercase driver letter
}
_ => return url,
}
let driver_letter_range = {
let (scheme, drive_letter, _rest) = match url.as_str().splitn(3, ':').collect_tuple() {
Some(it) => it,
None => return url,
};
let start = scheme.len() + ':'.len_utf8();
start..(start + drive_letter.len())
};
// Note: lowercasing the `path` itself doesn't help, the `Url::parse`
// machinery *also* canonicalizes the drive letter. So, just massage the
// string in place.
let mut url: String = url.into();
url[driver_letter_range].make_ascii_lowercase();
lsp::Url::parse(&url).unwrap()
}
fn position(
span: &DiagnosticSpan,
line_number: usize,
column_offset_utf32: usize,
) -> lsp::Position {
let line_index = line_number - span.line_start;
let column_offset_encoded = match span.text.get(line_index) {
// Fast path.
Some(line) if line.text.is_ascii() => column_offset_utf32,
Some(line) => {
let line_prefix_len = line
.text
.char_indices()
.take(column_offset_utf32)
.last()
.map(|(pos, c)| pos + c.len_utf8())
.unwrap_or(0);
let line_prefix = &line.text[..line_prefix_len];
line_prefix.len()
}
None => column_offset_utf32,
};
lsp::Position {
line: (line_number as u32).saturating_sub(1),
character: column_offset_encoded as u32,
}
}
/// Checks whether a file name is from macro invocation and does not refer to an actual file.
fn is_dummy_macro_file(file_name: &str) -> bool {
file_name.starts_with('<') && file_name.ends_with('>')
}
/// Extracts a suitable "primary" location from a rustc diagnostic.
///
/// This takes locations pointing into the standard library, or generally outside the current
/// workspace into account and tries to avoid those, in case macros are involved.
fn primary_location(worktree_root: &Path, span: &DiagnosticSpan) -> lsp::Location {
let span_stack = std::iter::successors(Some(span), |span| Some(&span.expansion.as_ref()?.span));
for span in span_stack.clone() {
let abs_path = worktree_root.join(&span.file_name);
if !is_dummy_macro_file(&span.file_name) && abs_path.starts_with(worktree_root) {
return location(worktree_root, span);
}
}
// Fall back to the outermost macro invocation if no suitable span comes up.
let last_span = span_stack.last().unwrap();
location(worktree_root, last_span)
}

View File

@@ -3,11 +3,11 @@ use std::{ops::Range, sync::Arc};
use editor::{
Anchor, Editor, EditorSnapshot, ToOffset,
display_map::{BlockContext, BlockPlacement, BlockProperties, BlockStyle},
hover_popover::diagnostics_markdown_style,
hover_markdown_style,
scroll::Autoscroll,
};
use gpui::{AppContext, Entity, Focusable, WeakEntity};
use language::{BufferId, Diagnostic, DiagnosticEntry};
use language::{BufferId, DiagnosticEntry};
use lsp::DiagnosticSeverity;
use markdown::{Markdown, MarkdownElement};
use settings::Settings;
@@ -28,6 +28,7 @@ impl DiagnosticRenderer {
diagnostic_group: Vec<DiagnosticEntry<Point>>,
buffer_id: BufferId,
diagnostics_editor: Option<WeakEntity<ProjectDiagnosticsEditor>>,
merge_same_row: bool,
cx: &mut App,
) -> Vec<DiagnosticBlock> {
let Some(primary_ix) = diagnostic_group
@@ -37,87 +38,105 @@ impl DiagnosticRenderer {
return Vec::new();
};
let primary = diagnostic_group[primary_ix].clone();
let mut same_row = Vec::new();
let mut close = Vec::new();
let mut distant = Vec::new();
let group_id = primary.diagnostic.group_id;
let mut results = vec![];
for entry in diagnostic_group.iter() {
for (ix, entry) in diagnostic_group.into_iter().enumerate() {
if entry.diagnostic.is_primary {
let mut markdown = Self::markdown(&entry.diagnostic);
let diagnostic = &primary.diagnostic;
if diagnostic.source.is_some() || diagnostic.code.is_some() {
markdown.push_str(" (");
}
if let Some(source) = diagnostic.source.as_ref() {
markdown.push_str(&Markdown::escape(&source));
}
if diagnostic.source.is_some() && diagnostic.code.is_some() {
markdown.push(' ');
}
if let Some(code) = diagnostic.code.as_ref() {
if let Some(description) = diagnostic.code_description.as_ref() {
markdown.push('[');
markdown.push_str(&Markdown::escape(&code.to_string()));
markdown.push_str("](");
markdown.push_str(&Markdown::escape(description.as_ref()));
markdown.push(')');
} else {
markdown.push_str(&Markdown::escape(&code.to_string()));
}
}
if diagnostic.source.is_some() || diagnostic.code.is_some() {
markdown.push(')');
}
for (ix, entry) in diagnostic_group.iter().enumerate() {
if entry.range.start.row.abs_diff(primary.range.start.row) >= 5 {
markdown.push_str("\n- hint: [");
markdown.push_str(&Markdown::escape(&entry.diagnostic.message));
markdown.push_str(&format!(
"](file://#diagnostic-{buffer_id}-{group_id}-{ix})\n",
))
}
}
results.push(DiagnosticBlock {
initial_range: primary.range.clone(),
severity: primary.diagnostic.severity,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
});
continue;
}
if entry.range.start.row == primary.range.start.row && merge_same_row {
same_row.push(entry)
} else if entry.range.start.row.abs_diff(primary.range.start.row) < 5 {
let markdown = Self::markdown(&entry.diagnostic);
results.push(DiagnosticBlock {
initial_range: entry.range.clone(),
severity: entry.diagnostic.severity,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
});
close.push(entry)
} else {
let mut markdown = Self::markdown(&entry.diagnostic);
markdown.push_str(&format!(
" ([back](file://#diagnostic-{buffer_id}-{group_id}-{primary_ix}))"
));
results.push(DiagnosticBlock {
initial_range: entry.range.clone(),
severity: entry.diagnostic.severity,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
});
distant.push((ix, entry))
}
}
results
}
fn markdown(diagnostic: &Diagnostic) -> String {
let mut markdown = String::new();
let diagnostic = &primary.diagnostic;
markdown.push_str(&Markdown::escape(&diagnostic.message));
for entry in same_row {
markdown.push_str("\n- hint: ");
markdown.push_str(&Markdown::escape(&entry.diagnostic.message))
}
if diagnostic.source.is_some() || diagnostic.code.is_some() {
markdown.push_str(" (");
}
if let Some(source) = diagnostic.source.as_ref() {
markdown.push_str(&Markdown::escape(&source));
}
if diagnostic.source.is_some() && diagnostic.code.is_some() {
markdown.push(' ');
}
if let Some(code) = diagnostic.code.as_ref() {
if let Some(description) = diagnostic.code_description.as_ref() {
markdown.push('[');
markdown.push_str(&Markdown::escape(&code.to_string()));
markdown.push_str("](");
markdown.push_str(&Markdown::escape(description.as_ref()));
markdown.push(')');
} else {
markdown.push_str(&Markdown::escape(&code.to_string()));
}
}
if diagnostic.source.is_some() || diagnostic.code.is_some() {
markdown.push(')');
}
if let Some(md) = &diagnostic.markdown {
markdown.push_str(md);
} else {
markdown.push_str(&Markdown::escape(&diagnostic.message));
};
markdown
for (ix, entry) in &distant {
markdown.push_str("\n- hint: [");
markdown.push_str(&Markdown::escape(&entry.diagnostic.message));
markdown.push_str(&format!(
"](file://#diagnostic-{buffer_id}-{group_id}-{ix})\n",
))
}
let mut results = vec![DiagnosticBlock {
initial_range: primary.range,
severity: primary.diagnostic.severity,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
}];
for entry in close {
let markdown = if let Some(source) = entry.diagnostic.source.as_ref() {
format!("{}: {}", source, entry.diagnostic.message)
} else {
entry.diagnostic.message
};
let markdown = Markdown::escape(&markdown).to_string();
results.push(DiagnosticBlock {
initial_range: entry.range,
severity: entry.diagnostic.severity,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
});
}
for (_, entry) in distant {
let markdown = if let Some(source) = entry.diagnostic.source.as_ref() {
format!("{}: {}", source, entry.diagnostic.message)
} else {
entry.diagnostic.message
};
let mut markdown = Markdown::escape(&markdown).to_string();
markdown.push_str(&format!(
" ([back](file://#diagnostic-{buffer_id}-{group_id}-{primary_ix}))"
));
results.push(DiagnosticBlock {
initial_range: entry.range,
severity: entry.diagnostic.severity,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
});
}
results
}
}
@@ -130,7 +149,7 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer {
editor: WeakEntity<Editor>,
cx: &mut App,
) -> Vec<BlockProperties<Anchor>> {
let blocks = Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, cx);
let blocks = Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, true, cx);
blocks
.into_iter()
.map(|block| {
@@ -157,7 +176,8 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer {
buffer_id: BufferId,
cx: &mut App,
) -> Option<Entity<Markdown>> {
let blocks = Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, cx);
let blocks =
Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, false, cx);
blocks.into_iter().find_map(|block| {
if block.initial_range == range {
Some(block.markdown)
@@ -191,7 +211,7 @@ impl DiagnosticBlock {
let cx = &bcx.app;
let status_colors = bcx.app.theme().status();
let max_width = bcx.em_width * 120.;
let max_width = bcx.em_width * 100.;
let (background_color, border_color) = match self.severity {
DiagnosticSeverity::ERROR => (status_colors.error_background, status_colors.error),
@@ -215,19 +235,16 @@ impl DiagnosticBlock {
.border_color(border_color)
.max_w(max_width)
.child(
MarkdownElement::new(
self.markdown.clone(),
diagnostics_markdown_style(bcx.window, cx),
)
.on_url_click({
move |link, window, cx| {
editor
.update(cx, |editor, cx| {
Self::open_link(editor, &diagnostics_editor, link, window, cx)
})
.ok();
}
}),
MarkdownElement::new(self.markdown.clone(), hover_markdown_style(bcx.window, cx))
.on_url_click({
move |link, window, cx| {
editor
.update(cx, |editor, cx| {
Self::open_link(editor, &diagnostics_editor, link, window, cx)
})
.ok();
}
}),
)
.into_any_element()
}

View File

@@ -1,4 +1,3 @@
mod cargo;
pub mod items;
mod toolbar_controls;
@@ -8,12 +7,7 @@ mod diagnostic_renderer;
mod diagnostics_tests;
use anyhow::Result;
use cargo::{
FetchStatus, FetchUpdate, cargo_diagnostics_sources, fetch_worktree_diagnostics,
is_outdated_cargo_fetch_diagnostic, map_rust_diagnostic_to_lsp, next_cargo_fetch_generation,
url_from_abs_path,
};
use collections::{BTreeSet, HashMap, HashSet};
use collections::{BTreeSet, HashMap};
use diagnostic_renderer::DiagnosticBlock;
use editor::{
DEFAULT_MULTIBUFFER_CONTEXT, Editor, EditorEvent, ExcerptRange, MultiBuffer, PathKey,
@@ -28,16 +22,13 @@ use gpui::{
use language::{
Bias, Buffer, BufferRow, BufferSnapshot, DiagnosticEntry, Point, ToTreeSitterPoint,
};
use lsp::{DiagnosticSeverity, LanguageServerId};
use project::{
DiagnosticSummary, Project, ProjectPath, Worktree,
lsp_store::rust_analyzer_ext::{CARGO_DIAGNOSTICS_SOURCE_NAME, RUST_ANALYZER_NAME},
project_settings::ProjectSettings,
};
use lsp::DiagnosticSeverity;
use project::{DiagnosticSummary, Project, ProjectPath, project_settings::ProjectSettings};
use settings::Settings;
use std::{
any::{Any, TypeId},
cmp::{self, Ordering},
cmp,
cmp::Ordering,
ops::{Range, RangeInclusive},
sync::Arc,
time::Duration,
@@ -53,10 +44,7 @@ use workspace::{
searchable::SearchableItemHandle,
};
actions!(
diagnostics,
[Deploy, ToggleWarnings, ToggleDiagnosticsRefresh]
);
actions!(diagnostics, [Deploy, ToggleWarnings]);
#[derive(Default)]
pub(crate) struct IncludeWarnings(bool);
@@ -79,15 +67,9 @@ pub(crate) struct ProjectDiagnosticsEditor {
paths_to_update: BTreeSet<ProjectPath>,
include_warnings: bool,
update_excerpts_task: Option<Task<Result<()>>>,
cargo_diagnostics_fetch: CargoDiagnosticsFetchState,
_subscription: Subscription,
}
struct CargoDiagnosticsFetchState {
task: Option<Task<()>>,
rust_analyzer: Option<LanguageServerId>,
}
impl EventEmitter<EditorEvent> for ProjectDiagnosticsEditor {}
const DIAGNOSTICS_UPDATE_DELAY: Duration = Duration::from_millis(50);
@@ -143,7 +125,6 @@ impl Render for ProjectDiagnosticsEditor {
.track_focus(&self.focus_handle(cx))
.size_full()
.on_action(cx.listener(Self::toggle_warnings))
.on_action(cx.listener(Self::toggle_diagnostics_refresh))
.child(child)
}
}
@@ -230,11 +211,7 @@ impl ProjectDiagnosticsEditor {
cx.observe_global_in::<IncludeWarnings>(window, |this, window, cx| {
this.include_warnings = cx.global::<IncludeWarnings>().0;
this.diagnostics.clear();
this.update_all_diagnostics(window, cx);
})
.detach();
cx.observe_release(&cx.entity(), |editor, _, cx| {
editor.stop_cargo_diagnostics_fetch(cx);
this.update_all_excerpts(window, cx);
})
.detach();
@@ -251,13 +228,9 @@ impl ProjectDiagnosticsEditor {
editor,
paths_to_update: Default::default(),
update_excerpts_task: None,
cargo_diagnostics_fetch: CargoDiagnosticsFetchState {
task: None,
rust_analyzer: None,
},
_subscription: project_event_subscription,
};
this.update_all_diagnostics(window, cx);
this.update_all_excerpts(window, cx);
this
}
@@ -265,17 +238,15 @@ impl ProjectDiagnosticsEditor {
if self.update_excerpts_task.is_some() {
return;
}
let project_handle = self.project.clone();
self.update_excerpts_task = Some(cx.spawn_in(window, async move |this, cx| {
cx.background_executor()
.timer(DIAGNOSTICS_UPDATE_DELAY)
.await;
loop {
let Some(path) = this.update(cx, |this, cx| {
let Some(path) = this.update(cx, |this, _| {
let Some(path) = this.paths_to_update.pop_first() else {
this.update_excerpts_task = None;
cx.notify();
this.update_excerpts_task.take();
return None;
};
Some(path)
@@ -335,32 +306,6 @@ impl ProjectDiagnosticsEditor {
cx.set_global(IncludeWarnings(!self.include_warnings));
}
fn toggle_diagnostics_refresh(
&mut self,
_: &ToggleDiagnosticsRefresh,
window: &mut Window,
cx: &mut Context<Self>,
) {
let fetch_cargo_diagnostics = ProjectSettings::get_global(cx)
.diagnostics
.fetch_cargo_diagnostics();
if fetch_cargo_diagnostics {
if self.cargo_diagnostics_fetch.task.is_some() {
self.stop_cargo_diagnostics_fetch(cx);
} else {
self.update_all_diagnostics(window, cx);
}
} else {
if self.update_excerpts_task.is_some() {
self.update_excerpts_task = None;
} else {
self.update_all_diagnostics(window, cx);
}
}
cx.notify();
}
fn focus_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() {
self.editor.focus_handle(cx).focus(window)
@@ -374,303 +319,6 @@ impl ProjectDiagnosticsEditor {
}
}
fn update_all_diagnostics(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cargo_diagnostics_sources = cargo_diagnostics_sources(self, cx);
if cargo_diagnostics_sources.is_empty() {
self.update_all_excerpts(window, cx);
} else {
self.fetch_cargo_diagnostics(Arc::new(cargo_diagnostics_sources), window, cx);
}
}
fn fetch_cargo_diagnostics(
&mut self,
diagnostics_sources: Arc<Vec<Entity<Worktree>>>,
window: &mut Window,
cx: &mut Context<Self>,
) {
self.cargo_diagnostics_fetch.task = Some(cx.spawn_in(window, async move |editor, cx| {
let rust_analyzer_server = editor
.update(cx, |editor, cx| {
editor
.project
.read(cx)
.language_server_with_name(RUST_ANALYZER_NAME, cx)
})
.ok();
let rust_analyzer_server = match rust_analyzer_server {
Some(rust_analyzer_server) => rust_analyzer_server.await,
None => None,
};
let mut worktree_diagnostics_tasks = Vec::new();
let mut paths_with_reported_cargo_diagnostics = HashSet::default();
if let Some(rust_analyzer_server) = rust_analyzer_server {
let can_continue = editor
.update(cx, |editor, cx| {
editor.cargo_diagnostics_fetch.rust_analyzer = Some(rust_analyzer_server);
let status_inserted =
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
if let Some(rust_analyzer_status) = lsp_store
.language_server_statuses
.get_mut(&rust_analyzer_server)
{
rust_analyzer_status
.progress_tokens
.insert(fetch_cargo_diagnostics_token());
paths_with_reported_cargo_diagnostics.extend(editor.diagnostics.iter().filter_map(|(buffer_id, diagnostics)| {
if diagnostics.iter().any(|d| d.diagnostic.source.as_deref() == Some(CARGO_DIAGNOSTICS_SOURCE_NAME)) {
Some(*buffer_id)
} else {
None
}
}).filter_map(|buffer_id| {
let buffer = lsp_store.buffer_store().read(cx).get(buffer_id)?;
let path = buffer.read(cx).file()?.as_local()?.abs_path(cx);
Some(url_from_abs_path(&path))
}));
true
} else {
false
}
});
if status_inserted {
editor.update_cargo_fetch_status(FetchStatus::Started, cx);
next_cargo_fetch_generation();
true
} else {
false
}
})
.unwrap_or(false);
if can_continue {
for worktree in diagnostics_sources.iter() {
if let Some(((_task, worktree_diagnostics), worktree_root)) = cx
.update(|_, cx| {
let worktree_root = worktree.read(cx).abs_path();
log::info!("Fetching cargo diagnostics for {worktree_root:?}");
fetch_worktree_diagnostics(&worktree_root, cx)
.zip(Some(worktree_root))
})
.ok()
.flatten()
{
let editor = editor.clone();
worktree_diagnostics_tasks.push(cx.spawn(async move |cx| {
let _task = _task;
let mut file_diagnostics = HashMap::default();
let mut diagnostics_total = 0;
let mut updated_urls = HashSet::default();
while let Ok(fetch_update) = worktree_diagnostics.recv().await {
match fetch_update {
FetchUpdate::Diagnostic(diagnostic) => {
for (url, diagnostic) in map_rust_diagnostic_to_lsp(
&worktree_root,
&diagnostic,
) {
let file_diagnostics = file_diagnostics
.entry(url)
.or_insert_with(Vec::<lsp::Diagnostic>::new);
let i = file_diagnostics
.binary_search_by(|probe| {
probe.range.start.cmp(&diagnostic.range.start)
.then(probe.range.end.cmp(&diagnostic.range.end))
.then(Ordering::Greater)
})
.unwrap_or_else(|i| i);
file_diagnostics.insert(i, diagnostic);
}
let file_changed = file_diagnostics.len() > 1;
if file_changed {
if editor
.update_in(cx, |editor, window, cx| {
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
for (uri, mut diagnostics) in
file_diagnostics.drain()
{
diagnostics.dedup();
diagnostics_total += diagnostics.len();
updated_urls.insert(uri.clone());
lsp_store.merge_diagnostics(
rust_analyzer_server,
lsp::PublishDiagnosticsParams {
uri,
diagnostics,
version: None,
},
&[],
|diagnostic, _| {
!is_outdated_cargo_fetch_diagnostic(diagnostic)
},
cx,
)?;
}
anyhow::Ok(())
})?;
editor.update_all_excerpts(window, cx);
anyhow::Ok(())
})
.ok()
.transpose()
.ok()
.flatten()
.is_none()
{
break;
}
}
}
FetchUpdate::Progress(message) => {
if editor
.update(cx, |editor, cx| {
editor.update_cargo_fetch_status(
FetchStatus::Progress { message },
cx,
);
})
.is_err()
{
return updated_urls;
}
}
}
}
editor
.update_in(cx, |editor, window, cx| {
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
for (uri, mut diagnostics) in
file_diagnostics.drain()
{
diagnostics.dedup();
diagnostics_total += diagnostics.len();
updated_urls.insert(uri.clone());
lsp_store.merge_diagnostics(
rust_analyzer_server,
lsp::PublishDiagnosticsParams {
uri,
diagnostics,
version: None,
},
&[],
|diagnostic, _| {
!is_outdated_cargo_fetch_diagnostic(diagnostic)
},
cx,
)?;
}
anyhow::Ok(())
})?;
editor.update_all_excerpts(window, cx);
anyhow::Ok(())
})
.ok();
log::info!("Fetched {diagnostics_total} cargo diagnostics for worktree {worktree_root:?}");
updated_urls
}));
}
}
} else {
log::info!(
"No rust-analyzer language server found, skipping diagnostics fetch"
);
}
}
let updated_urls = futures::future::join_all(worktree_diagnostics_tasks).await.into_iter().flatten().collect();
if let Some(rust_analyzer_server) = rust_analyzer_server {
editor
.update_in(cx, |editor, window, cx| {
editor
.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
for uri_to_cleanup in paths_with_reported_cargo_diagnostics.difference(&updated_urls).cloned() {
lsp_store.merge_diagnostics(
rust_analyzer_server,
lsp::PublishDiagnosticsParams {
uri: uri_to_cleanup,
diagnostics: Vec::new(),
version: None,
},
&[],
|diagnostic, _| {
!is_outdated_cargo_fetch_diagnostic(diagnostic)
},
cx,
).ok();
}
});
editor.update_all_excerpts(window, cx);
editor.stop_cargo_diagnostics_fetch(cx);
cx.notify();
})
.ok();
}
}));
}
fn update_cargo_fetch_status(&self, status: FetchStatus, cx: &mut App) {
let Some(rust_analyzer) = self.cargo_diagnostics_fetch.rust_analyzer else {
return;
};
let work_done = match status {
FetchStatus::Started => lsp::WorkDoneProgress::Begin(lsp::WorkDoneProgressBegin {
title: "cargo".to_string(),
cancellable: None,
message: Some("Fetching cargo diagnostics".to_string()),
percentage: None,
}),
FetchStatus::Progress { message } => {
lsp::WorkDoneProgress::Report(lsp::WorkDoneProgressReport {
message: Some(message),
cancellable: None,
percentage: None,
})
}
FetchStatus::Finished => {
lsp::WorkDoneProgress::End(lsp::WorkDoneProgressEnd { message: None })
}
};
let progress = lsp::ProgressParams {
token: lsp::NumberOrString::String(fetch_cargo_diagnostics_token()),
value: lsp::ProgressParamsValue::WorkDone(work_done),
};
self.project
.read(cx)
.lsp_store()
.update(cx, |lsp_store, cx| {
lsp_store.on_lsp_progress(progress, rust_analyzer, None, cx)
});
}
fn stop_cargo_diagnostics_fetch(&mut self, cx: &mut App) {
self.update_cargo_fetch_status(FetchStatus::Finished, cx);
self.cargo_diagnostics_fetch.task = None;
log::info!("Finished fetching cargo diagnostics");
}
/// Enqueue an update of all excerpts. Updates all paths that either
/// currently have diagnostics or are currently present in this view.
fn update_all_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -768,22 +416,26 @@ impl ProjectDiagnosticsEditor {
group,
buffer_snapshot.remote_id(),
Some(this.clone()),
true,
cx,
)
})?;
for item in more {
let i = blocks
.binary_search_by(|probe| {
probe
.initial_range
.start
.cmp(&item.initial_range.start)
.then(probe.initial_range.end.cmp(&item.initial_range.end))
.then(Ordering::Greater)
let insert_pos = blocks
.binary_search_by(|existing| {
match existing.initial_range.start.cmp(&item.initial_range.start) {
Ordering::Equal => item
.initial_range
.end
.cmp(&existing.initial_range.end)
.reverse(),
other => other,
}
})
.unwrap_or_else(|i| i);
blocks.insert(i, item);
.unwrap_or_else(|pos| pos);
blocks.insert(insert_pos, item);
}
}
@@ -796,25 +448,10 @@ impl ProjectDiagnosticsEditor {
&mut cx,
)
.await;
let i = excerpt_ranges
.binary_search_by(|probe| {
probe
.context
.start
.cmp(&excerpt_range.start)
.then(probe.context.end.cmp(&excerpt_range.end))
.then(probe.primary.start.cmp(&b.initial_range.start))
.then(probe.primary.end.cmp(&b.initial_range.end))
.then(cmp::Ordering::Greater)
})
.unwrap_or_else(|i| i);
excerpt_ranges.insert(
i,
ExcerptRange {
context: excerpt_range,
primary: b.initial_range.clone(),
},
)
excerpt_ranges.push(ExcerptRange {
context: excerpt_range,
primary: b.initial_range.clone(),
})
}
this.update_in(cx, |this, window, cx| {
@@ -885,7 +522,7 @@ impl ProjectDiagnosticsEditor {
markdown::MarkdownElement::rendered_text(
markdown.clone(),
cx,
editor::hover_popover::diagnostics_markdown_style,
editor::hover_markdown_style,
)
},
);
@@ -1286,7 +923,3 @@ fn is_line_blank_or_indented_less(
let line_indent = snapshot.line_indent_for_row(row);
line_indent.is_line_blank() || line_indent.len(tab_size) < indent_level
}
fn fetch_cargo_diagnostics_token() -> String {
"fetch_cargo_diagnostics".to_string()
}

View File

@@ -1,7 +1,4 @@
use std::sync::Arc;
use crate::cargo::cargo_diagnostics_sources;
use crate::{ProjectDiagnosticsEditor, ToggleDiagnosticsRefresh};
use crate::ProjectDiagnosticsEditor;
use gpui::{Context, Entity, EventEmitter, ParentElement, Render, WeakEntity, Window};
use ui::prelude::*;
use ui::{IconButton, IconButtonShape, IconName, Tooltip};
@@ -16,28 +13,18 @@ impl Render for ToolbarControls {
let mut include_warnings = false;
let mut has_stale_excerpts = false;
let mut is_updating = false;
let cargo_diagnostics_sources = Arc::new(
self.diagnostics()
.map(|editor| cargo_diagnostics_sources(editor.read(cx), cx))
.unwrap_or_default(),
);
let fetch_cargo_diagnostics = !cargo_diagnostics_sources.is_empty();
if let Some(editor) = self.diagnostics() {
let diagnostics = editor.read(cx);
include_warnings = diagnostics.include_warnings;
has_stale_excerpts = !diagnostics.paths_to_update.is_empty();
is_updating = if fetch_cargo_diagnostics {
diagnostics.cargo_diagnostics_fetch.task.is_some()
} else {
diagnostics.update_excerpts_task.is_some()
|| diagnostics
.project
.read(cx)
.language_servers_running_disk_based_diagnostics(cx)
.next()
.is_some()
};
is_updating = diagnostics.update_excerpts_task.is_some()
|| diagnostics
.project
.read(cx)
.language_servers_running_disk_based_diagnostics(cx)
.next()
.is_some();
}
let tooltip = if include_warnings {
@@ -54,57 +41,21 @@ impl Render for ToolbarControls {
h_flex()
.gap_1()
.map(|div| {
if is_updating {
div.child(
IconButton::new("stop-updating", IconName::StopFilled)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.tooltip(Tooltip::for_action_title(
"Stop diagnostics update",
&ToggleDiagnosticsRefresh,
))
.on_click(cx.listener(move |toolbar_controls, _, _, cx| {
if let Some(diagnostics) = toolbar_controls.diagnostics() {
diagnostics.update(cx, |diagnostics, cx| {
diagnostics.stop_cargo_diagnostics_fetch(cx);
diagnostics.update_excerpts_task = None;
cx.notify();
});
}
})),
)
} else {
div.child(
IconButton::new("refresh-diagnostics", IconName::Update)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.disabled(!has_stale_excerpts && !fetch_cargo_diagnostics)
.tooltip(Tooltip::for_action_title(
"Refresh diagnostics",
&ToggleDiagnosticsRefresh,
))
.on_click(cx.listener({
move |toolbar_controls, _, window, cx| {
if let Some(diagnostics) = toolbar_controls.diagnostics() {
let cargo_diagnostics_sources =
Arc::clone(&cargo_diagnostics_sources);
diagnostics.update(cx, move |diagnostics, cx| {
if fetch_cargo_diagnostics {
diagnostics.fetch_cargo_diagnostics(
cargo_diagnostics_sources,
window,
cx,
);
} else {
diagnostics.update_all_excerpts(window, cx);
}
});
}
}
})),
)
}
.when(has_stale_excerpts, |div| {
div.child(
IconButton::new("update-excerpts", IconName::Update)
.icon_color(Color::Info)
.shape(IconButtonShape::Square)
.disabled(is_updating)
.tooltip(Tooltip::text("Update excerpts"))
.on_click(cx.listener(|this, _, window, cx| {
if let Some(diagnostics) = this.diagnostics() {
diagnostics.update(cx, |diagnostics, cx| {
diagnostics.update_all_excerpts(window, cx);
});
}
})),
)
})
.child(
IconButton::new("toggle-warnings", IconName::Warning)

View File

@@ -871,6 +871,7 @@ pub struct Editor {
show_breadcrumbs: bool,
show_gutter: bool,
show_scrollbars: bool,
disable_scrolling: bool,
disable_expand_excerpt_buttons: bool,
show_line_numbers: Option<bool>,
use_relative_line_numbers: Option<bool>,
@@ -1667,6 +1668,7 @@ impl Editor {
blink_manager: blink_manager.clone(),
show_local_selections: true,
show_scrollbars: true,
disable_scrolling: false,
mode,
show_breadcrumbs: EditorSettings::get_global(cx).toolbar.breadcrumbs,
show_gutter: mode.is_full(),
@@ -5005,11 +5007,11 @@ impl Editor {
range
};
ranges.push(range.clone());
ranges.push(range);
if !self.linked_edit_ranges.is_empty() {
let start_anchor = snapshot.anchor_before(range.start);
let end_anchor = snapshot.anchor_after(range.end);
let start_anchor = snapshot.anchor_before(selection.head());
let end_anchor = snapshot.anchor_after(selection.tail());
if let Some(ranges) = self
.linked_editing_ranges_for(start_anchor.text_anchor..end_anchor.text_anchor, cx)
{
@@ -16485,6 +16487,11 @@ impl Editor {
cx.notify();
}
pub fn disable_scrolling(&mut self, cx: &mut Context<Self>) {
self.disable_scrolling = true;
cx.notify();
}
pub fn set_show_line_numbers(&mut self, show_line_numbers: bool, cx: &mut Context<Self>) {
self.show_line_numbers = Some(show_line_numbers);
cx.notify();

View File

@@ -1,7 +1,6 @@
use super::*;
use crate::{
JoinLines,
linked_editing_ranges::LinkedEditingRanges,
scroll::scroll_amount::ScrollAmount,
test::{
assert_text_with_selections, build_editor,
@@ -19560,146 +19559,6 @@ async fn test_hide_mouse_context_menu_on_modal_opened(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_html_linked_edits_on_completion(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let fs = FakeFs::new(cx.executor());
fs.insert_file(path!("/file.html"), Default::default())
.await;
let project = Project::test(fs, [path!("/").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
let html_language = Arc::new(Language::new(
LanguageConfig {
name: "HTML".into(),
matcher: LanguageMatcher {
path_suffixes: vec!["html".to_string()],
..LanguageMatcher::default()
},
brackets: BracketPairConfig {
pairs: vec![BracketPair {
start: "<".into(),
end: ">".into(),
close: true,
..Default::default()
}],
..Default::default()
},
..Default::default()
},
Some(tree_sitter_html::LANGUAGE.into()),
));
language_registry.add(html_language);
let mut fake_servers = language_registry.register_fake_lsp(
"HTML",
FakeLspAdapter {
capabilities: lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions {
resolve_provider: Some(true),
..Default::default()
}),
..Default::default()
},
..Default::default()
},
);
let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let worktree_id = workspace
.update(cx, |workspace, _window, cx| {
workspace.project().update(cx, |project, cx| {
project.worktrees(cx).next().unwrap().read(cx).id()
})
})
.unwrap();
project
.update(cx, |project, cx| {
project.open_local_buffer_with_lsp(path!("/file.html"), cx)
})
.await
.unwrap();
let editor = workspace
.update(cx, |workspace, window, cx| {
workspace.open_path((worktree_id, "file.html"), None, true, window, cx)
})
.unwrap()
.await
.unwrap()
.downcast::<Editor>()
.unwrap();
let fake_server = fake_servers.next().await.unwrap();
editor.update_in(cx, |editor, window, cx| {
editor.set_text("<ad></ad>", window, cx);
editor.change_selections(None, window, cx, |selections| {
selections.select_ranges([Point::new(0, 3)..Point::new(0, 3)]);
});
let Some((buffer, _)) = editor
.buffer
.read(cx)
.text_anchor_for_position(editor.selections.newest_anchor().start, cx)
else {
panic!("Failed to get buffer for selection position");
};
let buffer = buffer.read(cx);
let buffer_id = buffer.remote_id();
let opening_range =
buffer.anchor_before(Point::new(0, 1))..buffer.anchor_after(Point::new(0, 3));
let closing_range =
buffer.anchor_before(Point::new(0, 6))..buffer.anchor_after(Point::new(0, 8));
let mut linked_ranges = HashMap::default();
linked_ranges.insert(
buffer_id,
vec![(opening_range.clone(), vec![closing_range.clone()])],
);
editor.linked_edit_ranges = LinkedEditingRanges(linked_ranges);
});
let mut completion_handle =
fake_server.set_request_handler::<lsp::request::Completion, _, _>(move |_, _| async move {
Ok(Some(lsp::CompletionResponse::Array(vec![
lsp::CompletionItem {
label: "head".to_string(),
text_edit: Some(lsp::CompletionTextEdit::InsertAndReplace(
lsp::InsertReplaceEdit {
new_text: "head".to_string(),
insert: lsp::Range::new(
lsp::Position::new(0, 1),
lsp::Position::new(0, 3),
),
replace: lsp::Range::new(
lsp::Position::new(0, 1),
lsp::Position::new(0, 3),
),
},
)),
..Default::default()
},
])))
});
editor.update_in(cx, |editor, window, cx| {
editor.show_completions(&ShowCompletions { trigger: None }, window, cx);
});
cx.run_until_parked();
completion_handle.next().await.unwrap();
editor.update(cx, |editor, _| {
assert!(
editor.context_menu_visible(),
"Completion menu should be visible"
);
});
editor.update_in(cx, |editor, window, cx| {
editor.confirm_completion(&ConfirmCompletion::default(), window, cx)
});
cx.executor().run_until_parked();
editor.update(cx, |editor, cx| {
assert_eq!(editor.text(cx), "<head></head>");
});
}
fn empty_range(row: usize, column: usize) -> Range<DisplayPoint> {
let point = DisplayPoint::new(DisplayRow(row as u32), column as u32);
point..point

View File

@@ -2276,9 +2276,6 @@ impl EditorElement {
}
let display_row = multibuffer_point.to_display_point(snapshot).row();
if !range.contains(&display_row) {
return None;
}
if row_infos
.get((display_row - range.start).0 as usize)
.is_some_and(|row_info| row_info.expand_info.is_some())
@@ -5681,7 +5678,9 @@ impl EditorElement {
}
fn paint_mouse_listeners(&mut self, layout: &EditorLayout, window: &mut Window, cx: &mut App) {
self.paint_scroll_wheel_listener(layout, window, cx);
if !self.editor.read(cx).disable_scrolling {
self.paint_scroll_wheel_listener(layout, window, cx);
}
window.on_mouse_event({
let position_map = layout.position_map.clone();

View File

@@ -655,59 +655,11 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
},
syntax: cx.theme().syntax().clone(),
selection_background_color: { cx.theme().players().local().selection },
heading: StyleRefinement::default()
.font_weight(FontWeight::BOLD)
.text_base()
.mt(rems(1.))
.mb_0(),
..Default::default()
}
}
pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let settings = ThemeSettings::get_global(cx);
let ui_font_family = settings.ui_font.family.clone();
let ui_font_fallbacks = settings.ui_font.fallbacks.clone();
let buffer_font_family = settings.buffer_font.family.clone();
let buffer_font_fallbacks = settings.buffer_font.fallbacks.clone();
let mut base_text_style = window.text_style();
base_text_style.refine(&TextStyleRefinement {
font_family: Some(ui_font_family.clone()),
font_fallbacks: ui_font_fallbacks,
color: Some(cx.theme().colors().editor_foreground),
..Default::default()
});
MarkdownStyle {
base_text_style,
code_block: StyleRefinement::default().my(rems(1.)).font_buffer(cx),
inline_code: TextStyleRefinement {
background_color: Some(cx.theme().colors().editor_background.opacity(0.5)),
font_family: Some(buffer_font_family),
font_fallbacks: buffer_font_fallbacks,
..Default::default()
},
rule_color: cx.theme().colors().border,
block_quote_border_color: Color::Muted.color(cx),
block_quote: TextStyleRefinement {
color: Some(Color::Muted.color(cx)),
..Default::default()
},
link: TextStyleRefinement {
color: Some(cx.theme().colors().editor_foreground),
underline: Some(gpui::UnderlineStyle {
thickness: px(1.),
color: Some(cx.theme().colors().editor_foreground),
wavy: false,
}),
..Default::default()
},
syntax: cx.theme().syntax().clone(),
selection_background_color: { cx.theme().players().local().selection },
height_is_multiple_of_line_height: true,
heading: StyleRefinement::default()
.font_weight(FontWeight::BOLD)
.text_base()
.mt(rems(1.))
.mb_0(),
..Default::default()
}
@@ -999,7 +951,7 @@ impl DiagnosticPopover {
.child(
MarkdownElement::new(
self.markdown.clone(),
diagnostics_markdown_style(window, cx),
hover_markdown_style(window, cx),
)
.on_url_click(move |link, window, cx| {
if let Some(renderer) = GlobalDiagnosticRenderer::global(cx) {

View File

@@ -233,7 +233,7 @@ pub fn deploy_context_menu(
.separator()
.action("Cut", Box::new(Cut))
.action("Copy", Box::new(Copy))
.action("Copy and Trim", Box::new(CopyAndTrim))
.action("Copy and trim", Box::new(CopyAndTrim))
.action("Paste", Box::new(Paste))
.separator()
.map(|builder| {

View File

@@ -184,6 +184,9 @@ impl ScrollManager {
window: &mut Window,
cx: &mut Context<Editor>,
) {
if self.forbid_vertical_scroll {
return;
}
let (new_anchor, top_row) = if scroll_position.y <= 0. {
(
ScrollAnchor {
@@ -255,16 +258,10 @@ impl ScrollManager {
window: &mut Window,
cx: &mut Context<Editor>,
) {
let adjusted_anchor = if self.forbid_vertical_scroll {
ScrollAnchor {
offset: gpui::Point::new(anchor.offset.x, self.anchor.offset.y),
anchor: self.anchor.anchor,
}
} else {
anchor
};
self.anchor = adjusted_anchor;
if self.forbid_vertical_scroll {
return;
}
self.anchor = anchor;
cx.emit(EditorEvent::ScrollPositionChanged { local, autoscroll });
self.show_scrollbars(window, cx);
self.autoscroll_request.take();
@@ -407,12 +404,11 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let mut delta = scroll_delta;
if self.scroll_manager.forbid_vertical_scroll {
delta.y = 0.0;
return;
}
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let position = self.scroll_manager.anchor.scroll_position(&display_map) + delta;
let position = self.scroll_manager.anchor.scroll_position(&display_map) + scroll_delta;
self.set_scroll_position_taking_display_map(position, true, false, display_map, window, cx);
}
@@ -422,12 +418,10 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let mut position = scroll_position;
if self.scroll_manager.forbid_vertical_scroll {
let current_position = self.scroll_position(cx);
position.y = current_position.y;
return;
}
self.set_scroll_position_internal(position, true, false, window, cx);
self.set_scroll_position_internal(scroll_position, true, false, window, cx);
}
/// Scrolls so that `row` is at the top of the editor view.
@@ -486,15 +480,8 @@ impl Editor {
self.edit_prediction_preview
.set_previous_scroll_position(None);
let adjusted_position = if self.scroll_manager.forbid_vertical_scroll {
let current_position = self.scroll_manager.anchor.scroll_position(&display_map);
gpui::Point::new(scroll_position.x, current_position.y)
} else {
scroll_position
};
self.scroll_manager.set_scroll_position(
adjusted_position,
scroll_position,
&display_map,
local,
autoscroll,

Some files were not shown because too many files have changed in this diff Show More