Compare commits

..

10 Commits

Author SHA1 Message Date
Nathan Sobo
447eb8e1c9 Checkpoint 2025-04-21 07:08:03 -06:00
Nathan Sobo
e434117018 Checkpoint: Still a failing test for concurrent tool calls.
Seems like I'm surfacing a bug in Anthropic.
2025-04-20 20:50:20 -06:00
Nathan Sobo
36271b79b3 Failing test proving we need to batch tools per message 2025-04-20 19:04:37 -06:00
Nathan Sobo
41644a53cc Checkpoint 2025-04-20 17:56:42 -06:00
Nathan Sobo
08a9c4af09 Checkpoint 2025-04-20 17:54:33 -06:00
Nathan Sobo
3187f28405 Checkpoint 2025-04-20 17:28:44 -06:00
Nathan Sobo
101f3b100f Get a basic request/reply tested in AgentThread 2025-04-20 00:41:03 -06:00
Nathan Sobo
39c8b7bf5f Add agent_thread crate
Experimental for now, I want to try really integration testing it
against the real APIs in a more "eval style", meaning embrace the
stochastic nature of it.
2025-04-20 00:17:38 -06:00
Nathan Sobo
08b41252f6 Include role in start message 2025-04-20 00:16:49 -06:00
Nathan Sobo
152bbca238 Add gpui helpers 2025-04-20 00:16:08 -06:00
348 changed files with 6620 additions and 12637 deletions

View File

@@ -1,71 +0,0 @@
name: Run Agent Eval
on:
schedule:
- cron: "0 * * * *"
pull_request:
branches:
- "**"
types: [opened, synchronize, reopened, labeled]
workflow_dispatch:
concurrency:
# Allow only one workflow per any non-`main` branch.
group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
RUST_BACKTRACE: 1
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
ZED_CLIENT_CHECKSUM_SEED: ${{ secrets.ZED_CLIENT_CHECKSUM_SEED }}
ZED_EVAL_TELEMETRY: 1
jobs:
run_eval:
timeout-minutes: 60
name: Run Agent Eval
if: >
github.repository_owner == 'zed-industries' &&
(github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-eval'))
runs-on:
- buildjet-16vcpu-ubuntu-2204
steps:
- name: Add Rust to the PATH
run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
- name: Cache dependencies
uses: swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2
with:
save-if: ${{ github.ref == 'refs/heads/main' }}
cache-provider: "buildjet"
- name: Install Linux dependencies
run: ./script/linux
- name: Configure CI
run: |
mkdir -p ./../.cargo
cp ./.cargo/ci-config.toml ./../.cargo/config.toml
- name: Compile eval
run: cargo build --package=eval
- name: Run eval
run: cargo run --package=eval -- --repetitions=3 --concurrency=1
# Even the Linux runner is not stateful, in theory there is no need to do this cleanup.
# But, to avoid potential issues in the future if we choose to use a stateful Linux runner and forget to add code
# to clean up the config file, Ive included the cleanup code here as a precaution.
# While its not strictly necessary at this moment, I believe its better to err on the side of caution.
- name: Clean CI config file
if: always()
run: rm -rf ./../.cargo

View File

@@ -0,0 +1,28 @@
name: Run Eval Daily
on:
schedule:
- cron: "0 2 * * *"
workflow_dispatch:
env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0
RUST_BACKTRACE: 1
jobs:
run_eval:
name: Run Eval
if: github.repository_owner == 'zed-industries'
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
clean: false
- name: Setup Rust
uses: dtolnay/rust-toolchain@stable
- name: Run cargo eval
run: cargo run -p eval

11
.rules
View File

@@ -5,7 +5,6 @@
* Prefer implementing functionality in existing files unless it is a new logical component. Avoid creating many small files.
* Avoid using functions that panic like `unwrap()`, instead use mechanisms like `?` to propagate errors.
* Be careful with operations like indexing which may panic if the indexes are out of bounds.
* Never create files with `mod.rs` paths - prefer `src/some_module.rs` instead of `src/some_module/mod.rs`.
# GPUI
@@ -109,13 +108,3 @@ When a view's state has changed in a way that may affect its rendering, it shoul
While updating an entity (`cx: Context<T>`), it can emit an event using `cx.emit(event)`. Entities register which events they can emit by declaring `impl EventEmittor<EventType> for EntityType {}`.
Other entities can then register a callback to handle these events by doing `cx.subscribe(other_entity, |this, other_entity, event, cx| ...)`. This will return a `Subscription` which deregisters the callback when dropped. Typically `cx.subscribe` happens when creating a new entity and the subscriptions are stored in a `_subscriptions: Vec<Subscription>` field.
## Recent API changes
GPUI has had some changes to its APIs. Always write code using the new APIs:
* `spawn` methods now take async closures (`AsyncFn`), and so should be called like `cx.spawn(async move |cx| ...)`.
* Use `Entity<T>`. This replaces `Model<T>` and `View<T>` which longer exists and should NEVER be used.
* Use `App` references. This replaces `AppContext` which no longer exists and should NEVER be used.
* Use `Context<T>` references. This replaces `ModelContext<T>` which no longer exists and should NEVER be used.
* `Window` is now passed around explicitly. The new interface adds a `Window` reference parameter to some methods, and adds some new "*_in" methods for plumbing `Window`. The old types `WindowContext` and `ViewContext<T>` should NEVER be used.

52
Cargo.lock generated
View File

@@ -128,6 +128,36 @@ dependencies = [
"zed_llm_client",
]
[[package]]
name = "agent2"
version = "0.1.0"
dependencies = [
"anyhow",
"assistant_tool",
"assistant_tools",
"chrono",
"client",
"collections",
"ctor",
"env_logger 0.11.8",
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"language_model",
"language_models",
"parking_lot",
"project",
"reqwest_client",
"schemars",
"serde",
"serde_json",
"settings",
"smol",
"thiserror 2.0.12",
"util",
]
[[package]]
name = "ahash"
version = "0.7.8"
@@ -3042,7 +3072,6 @@ dependencies = [
"strum 0.27.1",
"subtle",
"supermaven_api",
"task",
"telemetry_events",
"text",
"theme",
@@ -4014,7 +4043,6 @@ dependencies = [
"node_runtime",
"parking_lot",
"paths",
"proto",
"schemars",
"serde",
"serde_json",
@@ -4190,7 +4218,6 @@ dependencies = [
"command_palette_hooks",
"dap",
"db",
"debugger_tools",
"editor",
"env_logger 0.11.8",
"feature_flags",
@@ -4200,7 +4227,6 @@ dependencies = [
"language",
"log",
"menu",
"parking_lot",
"picker",
"pretty_assertions",
"project",
@@ -4895,13 +4921,13 @@ dependencies = [
"anyhow",
"assistant_tool",
"assistant_tools",
"async-trait",
"async-watch",
"chrono",
"clap",
"client",
"collections",
"context_server",
"dap",
"dirs 5.0.1",
"env_logger 0.11.8",
"extension",
@@ -4919,15 +4945,11 @@ dependencies = [
"paths",
"project",
"prompt_store",
"regex",
"release_channel",
"reqwest_client",
"rust-embed",
"serde",
"serde_json",
"settings",
"shellexpand 2.1.2",
"smol",
"telemetry",
"toml 0.8.20",
"unindent",
@@ -6282,7 +6304,6 @@ dependencies = [
"log",
"pest",
"pest_derive",
"rust-embed",
"serde",
"serde_json",
"thiserror 1.0.69",
@@ -7722,7 +7743,6 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
"partial-json-fixer",
"project",
"proto",
"schemars",
@@ -9838,12 +9858,6 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "partial-json-fixer"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
[[package]]
name = "password-hash"
version = "0.4.2"
@@ -11759,7 +11773,6 @@ dependencies = [
"client",
"clock",
"dap",
"dap_adapters",
"env_logger 0.11.8",
"extension",
"extension_host",
@@ -14231,12 +14244,11 @@ version = "0.1.0"
dependencies = [
"anyhow",
"collections",
"dap-types",
"futures 0.3.31",
"gpui",
"hex",
"parking_lot",
"pretty_assertions",
"proto",
"schemars",
"serde",
"serde_json",

View File

@@ -3,6 +3,7 @@ resolver = "2"
members = [
"crates/activity_indicator",
"crates/agent",
"crates/agent2",
"crates/anthropic",
"crates/askpass",
"crates/assets",
@@ -480,7 +481,6 @@ num-format = "0.4.4"
ordered-float = "2.1.1"
palette = { version = "0.7.5", default-features = false, features = ["std"] }
parking_lot = "0.12.1"
partial-json-fixer = "0.5.3"
pathdiff = "0.2"
pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-image-icon lucide-image"><rect width="18" height="18" x="3" y="3" rx="2" ry="2"/><circle cx="9" cy="9" r="2"/><path d="m21 15-3.086-3.086a2 2 0 0 0-2.828 0L6 21"/></svg>

Before

Width:  |  Height:  |  Size: 372 B

View File

@@ -1028,10 +1028,10 @@
// Using `ctrl-shift-space` in Zed requires disabling the macOS global shortcut.
// System Preferences->Keyboard->Keyboard Shortcuts->Input Sources->Select the previous input source (uncheck)
"ctrl-shift-space": "terminal::ToggleViMode",
"ctrl-alt-up": "pane::SplitUp",
"ctrl-alt-down": "pane::SplitDown",
"ctrl-alt-left": "pane::SplitLeft",
"ctrl-alt-right": "pane::SplitRight"
"ctrl-k up": "pane::SplitUp",
"ctrl-k down": "pane::SplitDown",
"ctrl-k left": "pane::SplitLeft",
"ctrl-k right": "pane::SplitRight"
}
},
{

View File

@@ -830,13 +830,5 @@
// and Windows.
"alt-l": "editor::AcceptEditPrediction"
}
},
{
// Fixes https://github.com/zed-industries/zed/issues/29095 by ensuring that
// the last binding for editor::ToggleComments is not ctrl-c.
"context": "hack_to_fix_ctrl-c",
"bindings": {
"g c": "editor::ToggleComments"
}
}
]

View File

@@ -73,9 +73,9 @@ There are project rules that apply to these root directories:
{{/each}}
{{/if}}
{{#if has_user_rules}}
{{#if has_default_user_rules}}
The user has specified the following rules that should be applied:
{{#each user_rules}}
{{#each default_user_rules}}
{{#if title}}
Rules title: {{title}}

View File

@@ -1489,12 +1489,7 @@
"use_multiline_find": false,
"use_smartcase_find": false,
"highlight_on_yank_duration": 200,
"custom_digraphs": {},
// Cursor shape for the each mode.
// Specify the mode as the key and the shape as the value.
// The mode can be one of the following: "normal", "replace", "insert", "visual".
// The shape can be one of the following: "block", "bar", "underline", "hollow".
"cursor_shape": {}
"custom_digraphs": {}
},
// The server to connect to. If the environment variable
// ZED_SERVER_URL is set, it will override this setting.

View File

@@ -1,4 +1,4 @@
use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string};
use crate::context::{AssistantContext, ContextId, format_context_as_string};
use crate::context_picker::MentionLink;
use crate::thread::{
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
@@ -6,9 +6,7 @@ use crate::thread::{
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::ui::{
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
@@ -268,6 +266,14 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
}
}
fn render_tool_use_markdown(
text: SharedString,
language_registry: Arc<LanguageRegistry>,
cx: &mut App,
) -> Entity<Markdown> {
cx.new(|cx| Markdown::new(text, Some(language_registry), None, cx))
}
fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
@@ -672,26 +678,6 @@ fn open_markdown_link(
})
.detach_and_log_err(cx);
}
Some(MentionLink::Selection(path, line_range)) => {
let open_task = workspace.update(cx, |workspace, cx| {
workspace.open_path(path, None, true, window, cx)
});
window
.spawn(cx, async move |cx| {
let active_editor = open_task
.await?
.downcast::<Editor>()
.context("Item is not an editor")?;
active_editor.update_in(cx, |editor, window, cx| {
editor.change_selections(Some(Autoscroll::center()), window, cx, |s| {
s.select_ranges([Point::new(line_range.start as u32, 0)
..Point::new(line_range.start as u32, 0)])
});
anyhow::Ok(())
})
})
.detach_and_log_err(cx);
}
Some(MentionLink::Thread(thread_id)) => workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
panel.update(cx, |panel, cx| {
@@ -702,12 +688,6 @@ fn open_markdown_link(
}
}),
Some(MentionLink::Fetch(url)) => cx.open_url(&url),
Some(MentionLink::Rules(prompt_id)) => window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: Some(prompt_id.0),
}),
cx,
),
None => cx.open_url(&text),
}
}
@@ -881,34 +861,21 @@ impl ActiveThread {
tool_output: SharedString,
cx: &mut Context<Self>,
) {
let rendered = self
.rendered_tool_uses
.entry(tool_use_id.clone())
.or_insert_with(|| RenderedToolUse {
label: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
input: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
output: cx.new(|cx| {
Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
}),
});
rendered.label.update(cx, |this, cx| {
this.replace(tool_label, cx);
});
rendered.input.update(cx, |this, cx| {
let input = format!(
"```json\n{}\n```",
serde_json::to_string_pretty(tool_input).unwrap_or_default()
);
this.replace(input, cx);
});
rendered.output.update(cx, |this, cx| {
this.replace(tool_output, cx);
});
let rendered = RenderedToolUse {
label: render_tool_use_markdown(tool_label.into(), self.language_registry.clone(), cx),
input: render_tool_use_markdown(
format!(
"```json\n{}\n```",
serde_json::to_string_pretty(tool_input).unwrap_or_default()
)
.into(),
self.language_registry.clone(),
cx,
),
output: render_tool_use_markdown(tool_output, self.language_registry.clone(), cx),
};
self.rendered_tool_uses
.insert(tool_use_id.clone(), rendered);
}
fn handle_thread_event(
@@ -1001,19 +968,6 @@ impl ActiveThread {
);
}
}
ThreadEvent::StreamedToolUse {
tool_use_id,
ui_text,
input,
} => {
self.render_tool_use_markdown(
tool_use_id.clone(),
ui_text.clone(),
input,
"".into(),
cx,
);
}
ThreadEvent::ToolFinished {
pending_tool_use, ..
} => {
@@ -1032,7 +986,6 @@ impl ActiveThread {
}
}
ThreadEvent::CheckpointChanged => cx.notify(),
ThreadEvent::ReceivedTextChunk => {}
}
}
@@ -1095,21 +1048,9 @@ impl ActiveThread {
) {
let options = AgentNotification::window_options(screen, cx);
let project_name = self.workspace.upgrade().and_then(|workspace| {
workspace
.read(cx)
.project()
.read(cx)
.visible_worktrees(cx)
.next()
.map(|worktree| worktree.read(cx).root_name().to_string())
});
if let Some(screen_window) = cx
.open_window(options, |_, cx| {
cx.new(|_| {
AgentNotification::new(title.clone(), caption.clone(), icon, project_name)
})
cx.new(|_| AgentNotification::new(title.clone(), caption.clone(), icon))
})
.log_err()
{
@@ -1506,8 +1447,45 @@ impl ActiveThread {
let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);
let generating_label = (is_generating && is_last_message)
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
let generating_label = (is_generating && is_last_message).then(|| {
Label::new("Generating")
.color(Color::Muted)
.size(LabelSize::Small)
.with_animations(
"generating-label",
vec![
Animation::new(Duration::from_secs(1)),
Animation::new(Duration::from_secs(1)).repeat(),
],
|mut label, animation_ix, delta| {
match animation_ix {
0 => {
let chars_to_show = (delta * 10.).ceil() as usize;
let text = &"Generating"[0..chars_to_show];
label.set_text(text);
}
1 => {
let text = match delta {
d if d < 0.25 => "Generating",
d if d < 0.5 => "Generating.",
d if d < 0.75 => "Generating..",
_ => "Generating...",
};
label.set_text(text);
}
_ => {}
}
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.map_element(|label| label.alpha(delta)),
)
});
// Don't render user messages that are just there for returning tool results.
if message.role == Role::User && thread.message_has_tool_results(message_id) {
@@ -1534,7 +1512,9 @@ impl ActiveThread {
.map(|(_, state)| state.editor.clone());
let colors = cx.theme().colors();
let active_color = colors.element_active;
let editor_bg_color = colors.editor_background;
let bg_user_message_header = editor_bg_color.blend(active_color.opacity(0.25));
let open_as_markdown = IconButton::new(("open-as-markdown", ix), IconName::FileCode)
.shape(ui::IconButtonShape::Square)
@@ -1661,7 +1641,7 @@ impl ActiveThread {
let message_content =
has_content.then(|| {
v_flex()
.gap_1()
.gap_1p5()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
@@ -1684,6 +1664,7 @@ impl ActiveThread {
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.pt_1()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
@@ -1698,6 +1679,7 @@ impl ActiveThread {
} else {
div()
.min_h_6()
.text_ui(cx)
.child(self.render_message_content(
message_id,
rendered_message,
@@ -1756,18 +1738,35 @@ impl ActiveThread {
.pb_4()
.child(
v_flex()
.bg(editor_bg_color)
.bg(colors.editor_background)
.rounded_lg()
.border_1()
.border_color(colors.border)
.shadow_md()
.child(div().p_2().children(message_content))
.child(
h_flex()
.p_1()
.border_t_1()
.border_color(colors.border_variant)
.justify_end()
.py_1()
.pl_2()
.pr_1()
.bg(bg_user_message_header)
.border_b_1()
.border_color(colors.border)
.justify_between()
.rounded_t_md()
.child(
h_flex()
.gap_1p5()
.child(
Icon::new(IconName::PersonCircle)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
Label::new("You")
.size(LabelSize::Small)
.color(Color::Muted),
),
)
.child(
h_flex()
.gap_1()
@@ -1820,12 +1819,8 @@ impl ActiveThread {
edit_message_editor.is_none() && allow_editing_message,
|this| {
this.child(
Button::new("edit-message", "Edit Message")
Button::new("edit-message", "Edit")
.label_size(LabelSize::Small)
.icon(IconName::Pencil)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.on_click(cx.listener({
let message_segments =
message.segments.clone();
@@ -1842,7 +1837,8 @@ impl ActiveThread {
},
),
),
),
)
.child(div().p_2().children(message_content)),
),
Role::Assistant => v_flex()
.id(("message-container", ix))
@@ -2081,13 +2077,11 @@ impl ActiveThread {
.map(|m| m.role)
.unwrap_or(Role::User);
let is_assistant_message = message_role == Role::Assistant;
let is_user_message = message_role == Role::User;
let is_assistant = message_role == Role::Assistant;
v_flex()
.text_ui(cx)
.gap_2()
.when(is_user_message, |this| this.text_xs())
.children(
rendered_message.segments.iter().enumerate().map(
|(index, segment)| match segment {
@@ -2108,28 +2102,10 @@ impl ActiveThread {
RenderedMessageSegment::Text(markdown) => {
let markdown_element = MarkdownElement::new(
markdown.clone(),
if is_user_message {
let mut style = default_markdown_style(window, cx);
let mut text_style = window.text_style();
let theme_settings = ThemeSettings::get_global(cx);
let buffer_font = theme_settings.buffer_font.family.clone();
let buffer_font_size = TextSize::Small.rems(cx);
text_style.refine(&TextStyleRefinement {
font_family: Some(buffer_font),
font_size: Some(buffer_font_size.into()),
..Default::default()
});
style.base_text_style = text_style;
style
} else {
default_markdown_style(window, cx)
},
default_markdown_style(window, cx),
);
let markdown_element = if is_assistant_message {
let markdown_element = if is_assistant {
markdown_element.code_block_renderer(
markdown::CodeBlockRenderer::Custom {
render: Arc::new({
@@ -2268,7 +2244,34 @@ impl ActiveThread {
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(AnimatedLabel::new("Thinking").size(LabelSize::Small)),
.child({
Label::new("Thinking")
.color(Color::Muted)
.size(LabelSize::Small)
.with_animation(
"generating-label",
Animation::new(Duration::from_secs(1)).repeat(),
|mut label, delta| {
let text = match delta {
d if d < 0.25 => "Thinking",
d if d < 0.5 => "Thinking.",
d if d < 0.75 => "Thinking..",
_ => "Thinking...",
};
label.set_text(text);
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| {
label.map_element(|label| label.alpha(delta))
},
)
}),
)
.child(
h_flex()
@@ -2466,18 +2469,16 @@ impl ActiveThread {
.upgrade()
.map(|workspace| workspace.read(cx).app_state().fs.clone());
let needs_confirmation = matches!(&tool_use.status, ToolUseStatus::NeedsConfirmation);
let needs_confirmation_tools = tool_use.needs_confirmation;
let edit_tools = tool_use.needs_confirmation;
let status_icons = div().child(match &tool_use.status {
ToolUseStatus::NeedsConfirmation => {
ToolUseStatus::Pending | ToolUseStatus::NeedsConfirmation => {
let icon = Icon::new(IconName::Warning)
.color(Color::Warning)
.size(IconSize::Small);
icon.into_any_element()
}
ToolUseStatus::Pending
| ToolUseStatus::InputStillStreaming
| ToolUseStatus::Running => {
ToolUseStatus::Running => {
let icon = Icon::new(IconName::ArrowCircle)
.color(Color::Accent)
.size(IconSize::Small);
@@ -2563,34 +2564,34 @@ impl ActiveThread {
}),
)),
),
ToolUseStatus::InputStillStreaming | ToolUseStatus::Running => container.child(
results_content_container()
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Accent)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(
delta,
)))
},
),
)
.child(
Label::new("Running…")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
),
),
ToolUseStatus::Running => container.child(
results_content_container().child(
h_flex()
.gap_1()
.pb_1()
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Accent)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(
delta,
)))
},
),
)
.child(
Label::new("Running…")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
),
),
),
ToolUseStatus::Error(_) => container.child(
results_content_container()
@@ -2654,8 +2655,8 @@ impl ActiveThread {
))
};
v_flex().gap_1().mb_2().map(|element| {
if !needs_confirmation_tools {
v_flex().gap_1().mb_3().map(|element| {
if !edit_tools {
element.child(
v_flex()
.child(
@@ -2833,7 +2834,30 @@ impl ActiveThread {
.border_color(self.tool_card_border_color(cx))
.rounded_b_lg()
.child(
AnimatedLabel::new("Waiting for Confirmation").size(LabelSize::Small)
Label::new("Waiting for Confirmation")
.color(Color::Muted)
.size(LabelSize::Small)
.with_animation(
"generating-label",
Animation::new(Duration::from_secs(1)).repeat(),
|mut label, delta| {
let text = match delta {
d if d < 0.25 => "Waiting for Confirmation",
d if d < 0.5 => "Waiting for Confirmation.",
d if d < 0.75 => "Waiting for Confirmation..",
_ => "Waiting for Confirmation...",
};
label.set_text(text);
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.map_element(|label| label.alpha(delta)),
),
)
.child(
h_flex()
@@ -2933,10 +2957,10 @@ impl ActiveThread {
return div().into_any();
};
let user_rules_text = if project_context.user_rules.is_empty() {
let default_user_rules_text = if project_context.default_user_rules.is_empty() {
None
} else if project_context.user_rules.len() == 1 {
let user_rules = &project_context.user_rules[0];
} else if project_context.default_user_rules.len() == 1 {
let user_rules = &project_context.default_user_rules[0];
match user_rules.title.as_ref() {
Some(title) => Some(format!("Using \"{title}\" user rule")),
@@ -2945,14 +2969,14 @@ impl ActiveThread {
} else {
Some(format!(
"Using {} user rules",
project_context.user_rules.len()
project_context.default_user_rules.len()
))
};
let first_user_rules_id = project_context
.user_rules
let first_default_user_rules_id = project_context
.default_user_rules
.first()
.map(|user_rules| user_rules.uuid.0);
.map(|user_rules| user_rules.uuid);
let rules_files = project_context
.worktrees
@@ -2969,7 +2993,7 @@ impl ActiveThread {
rules_files => Some(format!("Using {} project rules files", rules_files.len())),
};
if user_rules_text.is_none() && rules_file_text.is_none() {
if default_user_rules_text.is_none() && rules_file_text.is_none() {
return div().into_any();
}
@@ -2977,42 +3001,45 @@ impl ActiveThread {
.pt_2()
.px_2p5()
.gap_1()
.when_some(user_rules_text, |parent, user_rules_text| {
parent.child(
h_flex()
.w_full()
.child(
Icon::new(RULES_ICON)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
Label::new(user_rules_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
.truncate()
.buffer_font(cx)
.ml_1p5()
.mr_0p5(),
)
.child(
IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
.tooltip(Tooltip::text("View User Rules"))
.on_click(move |_event, window, cx| {
window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: first_user_rules_id,
}),
cx,
)
}),
),
)
})
.when_some(
default_user_rules_text,
|parent, default_user_rules_text| {
parent.child(
h_flex()
.w_full()
.child(
Icon::new(IconName::File)
.size(IconSize::XSmall)
.color(Color::Disabled),
)
.child(
Label::new(default_user_rules_text)
.size(LabelSize::XSmall)
.color(Color::Muted)
.truncate()
.buffer_font(cx)
.ml_1p5()
.mr_0p5(),
)
.child(
IconButton::new("open-prompt-library", IconName::ArrowUpRightAlt)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Ignored)
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
.tooltip(Tooltip::text("View User Rules"))
.on_click(move |_event, window, cx| {
window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_focus: first_default_user_rules_id,
}),
cx,
)
}),
),
)
},
)
.when_some(rules_file_text, |parent, rules_file_text| {
parent.child(
h_flex()
@@ -3232,10 +3259,12 @@ pub(crate) fn open_context(
}
}
AssistantContext::Directory(directory_context) => {
let entry_id = directory_context.entry_id;
let project_path = directory_context.project_path(cx);
workspace.update(cx, |workspace, cx| {
workspace.project().update(cx, |_project, cx| {
cx.emit(project::Event::RevealInProjectPanel(entry_id));
workspace.project().update(cx, |project, cx| {
if let Some(entry) = project.entry_for_path(&project_path, cx) {
cx.emit(project::Event::RevealInProjectPanel(entry.id));
}
})
})
}
@@ -3258,15 +3287,15 @@ pub(crate) fn open_context(
.detach();
}
}
AssistantContext::Selection(selection_context) => {
if let Some(project_path) = selection_context
AssistantContext::Excerpt(excerpt_context) => {
if let Some(project_path) = excerpt_context
.context_buffer
.buffer
.read(cx)
.project_path(cx)
{
let snapshot = selection_context.context_buffer.buffer.read(cx).snapshot();
let target_position = selection_context.range.start.to_point(&snapshot);
let snapshot = excerpt_context.context_buffer.buffer.read(cx).snapshot();
let target_position = excerpt_context.range.start.to_point(&snapshot);
open_editor_at_position(project_path, target_position, &workspace, window, cx)
.detach();
@@ -3287,13 +3316,6 @@ pub(crate) fn open_context(
}
})
}
AssistantContext::Rules(rules_context) => window.dispatch_action(
Box::new(OpenPromptLibrary {
prompt_to_select: Some(rules_context.prompt_id.0),
}),
cx,
),
AssistantContext::Image(_) => {}
}
}

View File

@@ -1,4 +1,4 @@
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent, ui::AnimatedLabel};
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::{HashMap, HashSet};
@@ -8,8 +8,8 @@ use editor::{
scroll::Autoscroll,
};
use gpui::{
Action, AnyElement, AnyView, App, Empty, Entity, EventEmitter, FocusHandle, Focusable,
SharedString, Subscription, Task, WeakEntity, Window, prelude::*,
Action, AnyElement, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, SharedString,
Subscription, Task, WeakEntity, Window, prelude::*,
};
use language::{Capability, DiskState, OffsetRangeExt, Point};
use multi_buffer::PathKey;
@@ -307,10 +307,6 @@ impl AgentDiff {
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.thread.read(cx).is_generating() {
return;
}
let snapshot = self.multibuffer.read(cx).snapshot(cx);
let diff_hunks_in_ranges = self
.editor
@@ -343,10 +339,6 @@ impl AgentDiff {
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.thread.read(cx).is_generating() {
return;
}
let snapshot = self.multibuffer.read(cx).snapshot(cx);
let diff_hunks_in_ranges = self
.editor
@@ -658,11 +650,6 @@ fn render_diff_hunk_controls(
cx: &mut App,
) -> AnyElement {
let editor = editor.clone();
if agent_diff.read(cx).thread.read(cx).is_generating() {
return Empty.into_any();
}
h_flex()
.h(line_height)
.mr_0p5()
@@ -870,14 +857,8 @@ impl Render for AgentDiffToolbar {
None => return div(),
};
let is_generating = agent_diff.read(cx).thread.read(cx).is_generating();
if is_generating {
return div()
.w(rems(6.5625)) // Arbitrary 105px size—so the label doesn't dance around
.child(AnimatedLabel::new("Generating"));
}
let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
if is_empty {
return div();
}
@@ -988,7 +969,7 @@ mod tests {
.await
.unwrap();
cx.update(|_, cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit(

View File

@@ -16,7 +16,7 @@ use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageMod
use settings::{Settings, update_settings_file};
use ui::{
Disclosure, Divider, DividerColor, ElevationIndex, Indicator, Scrollbar, ScrollbarState,
Switch, SwitchColor, Tooltip, prelude::*,
Switch, Tooltip, prelude::*,
};
use util::ResultExt as _;
use zed_actions::ExtensionCategoryFilter;
@@ -236,7 +236,6 @@ impl AssistantConfiguration {
"always-allow-tool-actions-switch",
always_allow_tool_actions.into(),
)
.color(SwitchColor::Accent)
.on_click({
let fs = self.fs.clone();
move |state, _window, cx| {
@@ -333,44 +332,41 @@ impl AssistantConfiguration {
),
)
.child(
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager =
self.context_server_manager.clone();
let context_server = context_server.clone();
move |state, _window, cx| match state {
ToggleState::Unselected
| ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx)
.log_err();
});
}
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(
context_server,
cx,
)
})
.log_err()
{
start_server_task.await.log_err();
}
}
})
.detach();
}
Switch::new("context-server-switch", is_running.into()).on_click({
let context_server_manager =
self.context_server_manager.clone();
let context_server = context_server.clone();
move |state, _window, cx| match state {
ToggleState::Unselected | ToggleState::Indeterminate => {
context_server_manager.update(cx, |this, cx| {
this.stop_server(context_server.clone(), cx)
.log_err();
});
}
}),
ToggleState::Selected => {
cx.spawn({
let context_server_manager =
context_server_manager.clone();
let context_server = context_server.clone();
async move |cx| {
if let Some(start_server_task) =
context_server_manager
.update(cx, |this, cx| {
this.start_server(
context_server,
cx,
)
})
.log_err()
{
start_server_task.await.log_err();
}
}
})
.detach();
}
}
}),
),
)
.map(|parent| {

View File

@@ -1,7 +1,7 @@
use assistant_settings::AssistantSettings;
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
use language_model::LanguageModelRegistry;
use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
@@ -9,12 +9,17 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
pub use language_model_selector::ModelType;
#[derive(Clone, Copy)]
pub enum ModelType {
Default,
InlineAssistant,
}
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
focus_handle: FocusHandle,
model_type: ModelType,
}
impl AssistantModelSelector {
@@ -58,13 +63,13 @@ impl AssistantModelSelector {
}
}
},
model_type,
window,
cx,
)
}),
menu_handle,
focus_handle,
model_type,
}
}
@@ -77,7 +82,11 @@ impl Render for AssistantModelSelector {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle.clone();
let model = self.selector.read(cx).active_model(cx);
let model_registry = LanguageModelRegistry::read_global(cx);
let model = match self.model_type {
ModelType::Default => model_registry.default_model(),
ModelType::InlineAssistant => model_registry.inline_assistant_model(),
};
let (model_name, model_icon) = match model {
Some(model) => (model.model.name().0, Some(model.provider.icon())),
_ => (SharedString::from("No model selected"), None),

View File

@@ -25,7 +25,7 @@ use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
use language_model_selector::ToggleModelSelector;
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, PromptId};
use proto::Plan;
use settings::{Settings, update_settings_file};
use time::UtcOffset;
@@ -79,11 +79,11 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
}
})
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(action, window, cx)
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx)
});
}
})
@@ -502,9 +502,7 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }),
cx,
)
.detach_and_log_err(cx);
@@ -1550,7 +1548,7 @@ impl AssistantPanel {
fn render_usage_banner(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let usage = self.thread.read(cx).last_usage()?;
Some(UsageBanner::new(zed_llm_client::Plan::ZedProTrial, usage).into_any_element())
Some(UsageBanner::new(zed_llm_client::Plan::ZedProTrial, usage.amount).into_any_element())
}
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
@@ -1951,9 +1949,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
.collect::<Vec<_>>();
for (buffer, range) in selection_ranges {
store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
store.add_excerpt(range, buffer, cx).detach_and_log_err(cx);
}
})
})

View File

@@ -1,7 +1,7 @@
use crate::context::attach_context_to_message;
use crate::context_store::ContextStore;
use crate::inline_prompt_editor::CodegenStatus;
use anyhow::Result;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
@@ -131,12 +131,7 @@ impl BufferCodegen {
cx.notify();
}
pub fn start(
&mut self,
primary_model: Arc<dyn LanguageModel>,
user_prompt: String,
cx: &mut Context<Self>,
) -> Result<()> {
pub fn start(&mut self, user_prompt: String, cx: &mut Context<Self>) -> Result<()> {
let alternative_models = LanguageModelRegistry::read_global(cx)
.inline_alternative_models()
.to_vec();
@@ -160,6 +155,11 @@ impl BufferCodegen {
}));
}
let primary_model = LanguageModelRegistry::read_global(cx)
.default_model()
.context("no active model")?
.model;
for (model, alternative) in iter::once(primary_model)
.chain(alternative_models)
.zip(&self.alternatives)

View File

@@ -1,15 +1,9 @@
use std::{
ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
use std::{ops::Range, path::Path, sync::Arc};
use futures::{FutureExt, future::Shared};
use gpui::{App, Entity, SharedString, Task};
use language::Buffer;
use language_model::{LanguageModelImage, LanguageModelRequestMessage};
use project::{ProjectEntryId, ProjectPath, Worktree};
use prompt_store::UserPromptId;
use gpui::{App, Entity, SharedString};
use language::{Buffer, File};
use language_model::LanguageModelRequestMessage;
use project::{ProjectPath, Worktree};
use rope::Point;
use serde::{Deserialize, Serialize};
use text::{Anchor, BufferId};
@@ -18,8 +12,6 @@ use util::post_inc;
use crate::thread::Thread;
pub const RULES_ICON: IconName = IconName::Context;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct ContextId(pub(crate) usize);
@@ -28,16 +20,13 @@ impl ContextId {
Self(post_inc(&mut self.0))
}
}
pub enum ContextKind {
File,
Directory,
Symbol,
Selection,
Excerpt,
FetchedUrl,
Thread,
Rules,
Image,
}
impl ContextKind {
@@ -46,11 +35,9 @@ impl ContextKind {
ContextKind::File => IconName::File,
ContextKind::Directory => IconName::Folder,
ContextKind::Symbol => IconName::Code,
ContextKind::Selection => IconName::Context,
ContextKind::Excerpt => IconName::Code,
ContextKind::FetchedUrl => IconName::Globe,
ContextKind::Thread => IconName::MessageBubbles,
ContextKind::Rules => RULES_ICON,
ContextKind::Image => IconName::Image,
}
}
}
@@ -62,9 +49,7 @@ pub enum AssistantContext {
Symbol(SymbolContext),
FetchedUrl(FetchedUrlContext),
Thread(ThreadContext),
Selection(SelectionContext),
Rules(RulesContext),
Image(ImageContext),
Excerpt(ExcerptContext),
}
impl AssistantContext {
@@ -75,9 +60,7 @@ impl AssistantContext {
Self::Symbol(symbol) => symbol.id,
Self::FetchedUrl(url) => url.id,
Self::Thread(thread) => thread.id,
Self::Selection(selection) => selection.id,
Self::Rules(rules) => rules.id,
Self::Image(image) => image.id,
Self::Excerpt(excerpt) => excerpt.id,
}
}
}
@@ -92,25 +75,17 @@ pub struct FileContext {
pub struct DirectoryContext {
pub id: ContextId,
pub worktree: Entity<Worktree>,
pub entry_id: ProjectEntryId,
pub last_path: Arc<Path>,
pub path: Arc<Path>,
/// Buffers of the files within the directory.
pub context_buffers: Vec<ContextBuffer>,
}
impl DirectoryContext {
pub fn entry<'a>(&self, cx: &'a App) -> Option<&'a project::Entry> {
self.worktree.read(cx).entry_for_id(self.entry_id)
}
pub fn project_path(&self, cx: &App) -> Option<ProjectPath> {
let worktree = self.worktree.read(cx);
worktree
.entry_for_id(self.entry_id)
.map(|entry| ProjectPath {
worktree_id: worktree.id(),
path: entry.path.clone(),
})
pub fn project_path(&self, cx: &App) -> ProjectPath {
ProjectPath {
worktree_id: self.worktree.read(cx).id(),
path: self.path.clone(),
}
}
}
@@ -145,51 +120,17 @@ impl ThreadContext {
}
}
#[derive(Debug, Clone)]
pub struct ImageContext {
pub id: ContextId,
pub original_image: Arc<gpui::Image>,
pub image_task: Shared<Task<Option<LanguageModelImage>>>,
}
impl ImageContext {
pub fn image(&self) -> Option<LanguageModelImage> {
self.image_task.clone().now_or_never().flatten()
}
pub fn is_loading(&self) -> bool {
self.image_task.clone().now_or_never().is_none()
}
pub fn is_error(&self) -> bool {
self.image_task
.clone()
.now_or_never()
.map(|result| result.is_none())
.unwrap_or(false)
}
}
#[derive(Clone)]
pub struct ContextBuffer {
pub id: BufferId,
// TODO: Entity<Buffer> holds onto the buffer even if the buffer is deleted. Should probably be
// TODO: Entity<Buffer> holds onto the thread even if the thread is deleted. Should probably be
// a WeakEntity and handle removal from the UI when it has dropped.
pub buffer: Entity<Buffer>,
pub last_full_path: Arc<Path>,
pub file: Arc<dyn File>,
pub version: clock::Global,
pub text: SharedString,
}
impl ContextBuffer {
pub fn full_path(&self, cx: &App) -> PathBuf {
let file = self.buffer.read(cx).file();
// Note that in practice file can't be `None` because it is present when this is created and
// there's no way for buffers to go from having a file to not.
file.map_or(self.last_full_path.to_path_buf(), |file| file.full_path(cx))
}
}
impl std::fmt::Debug for ContextBuffer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ContextBuffer")
@@ -220,21 +161,13 @@ pub struct ContextSymbolId {
}
#[derive(Debug, Clone)]
pub struct SelectionContext {
pub struct ExcerptContext {
pub id: ContextId,
pub range: Range<Anchor>,
pub line_range: Range<Point>,
pub context_buffer: ContextBuffer,
}
#[derive(Debug, Clone)]
pub struct RulesContext {
pub id: ContextId,
pub prompt_id: UserPromptId,
pub title: SharedString,
pub text: SharedString,
}
/// Formats a collection of contexts into a string representation
pub fn format_context_as_string<'a>(
contexts: impl Iterator<Item = &'a AssistantContext>,
@@ -243,31 +176,27 @@ pub fn format_context_as_string<'a>(
let mut file_context = Vec::new();
let mut directory_context = Vec::new();
let mut symbol_context = Vec::new();
let mut selection_context = Vec::new();
let mut excerpt_context = Vec::new();
let mut fetch_context = Vec::new();
let mut thread_context = Vec::new();
let mut rules_context = Vec::new();
for context in contexts {
match context {
AssistantContext::File(context) => file_context.push(context),
AssistantContext::Directory(context) => directory_context.push(context),
AssistantContext::Symbol(context) => symbol_context.push(context),
AssistantContext::Selection(context) => selection_context.push(context),
AssistantContext::Excerpt(context) => excerpt_context.push(context),
AssistantContext::FetchedUrl(context) => fetch_context.push(context),
AssistantContext::Thread(context) => thread_context.push(context),
AssistantContext::Rules(context) => rules_context.push(context),
AssistantContext::Image(_) => {}
}
}
if file_context.is_empty()
&& directory_context.is_empty()
&& symbol_context.is_empty()
&& selection_context.is_empty()
&& excerpt_context.is_empty()
&& fetch_context.is_empty()
&& thread_context.is_empty()
&& rules_context.is_empty()
{
return None;
}
@@ -303,13 +232,13 @@ pub fn format_context_as_string<'a>(
result.push_str("</symbols>\n");
}
if !selection_context.is_empty() {
result.push_str("<selections>\n");
for context in selection_context {
if !excerpt_context.is_empty() {
result.push_str("<excerpts>\n");
for context in excerpt_context {
result.push_str(&context.context_buffer.text);
result.push('\n');
}
result.push_str("</selections>\n");
result.push_str("</excerpts>\n");
}
if !fetch_context.is_empty() {
@@ -334,18 +263,6 @@ pub fn format_context_as_string<'a>(
result.push_str("</conversation_threads>\n");
}
if !rules_context.is_empty() {
result.push_str(
"<user_rules>\n\
The user has specified the following rules that should be applied:\n\n",
);
for context in &rules_context {
result.push_str(&context.text);
result.push('\n');
}
result.push_str("</user_rules>\n");
}
result.push_str("</context>\n");
Some(result)
}

View File

@@ -1,7 +1,6 @@
mod completion_provider;
mod fetch_context_picker;
mod file_context_picker;
mod rules_context_picker;
mod symbol_context_picker;
mod thread_context_picker;
@@ -17,91 +16,30 @@ use gpui::{
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
WeakEntity,
};
use language::Buffer;
use multi_buffer::MultiBufferRow;
use project::{Entry, ProjectPath};
use prompt_store::UserPromptId;
use rules_context_picker::RulesContextEntry;
use symbol_context_picker::SymbolContextPicker;
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
use ui::{
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
};
use uuid::Uuid;
use workspace::{Workspace, notifications::NotifyResultExt};
use crate::AssistantPanel;
use crate::context::RULES_ICON;
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
use crate::context_picker::fetch_context_picker::FetchContextPicker;
use crate::context_picker::file_context_picker::FileContextPicker;
use crate::context_picker::rules_context_picker::RulesContextPicker;
use crate::context_picker::thread_context_picker::ThreadContextPicker;
use crate::context_store::ContextStore;
use crate::thread::ThreadId;
use crate::thread_store::ThreadStore;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerEntry {
Mode(ContextPickerMode),
Action(ContextPickerAction),
}
impl ContextPickerEntry {
pub fn keyword(&self) -> &'static str {
match self {
Self::Mode(mode) => mode.keyword(),
Self::Action(action) => action.keyword(),
}
}
pub fn label(&self) -> &'static str {
match self {
Self::Mode(mode) => mode.label(),
Self::Action(action) => action.label(),
}
}
pub fn icon(&self) -> IconName {
match self {
Self::Mode(mode) => mode.icon(),
Self::Action(action) => action.icon(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerMode {
File,
Symbol,
Fetch,
Thread,
Rules,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ContextPickerAction {
AddSelections,
}
impl ContextPickerAction {
pub fn keyword(&self) -> &'static str {
match self {
Self::AddSelections => "selection",
}
}
pub fn label(&self) -> &'static str {
match self {
Self::AddSelections => "Selection",
}
}
pub fn icon(&self) -> IconName {
match self {
Self::AddSelections => IconName::Context,
}
}
}
impl TryFrom<&str> for ContextPickerMode {
@@ -113,20 +51,18 @@ impl TryFrom<&str> for ContextPickerMode {
"symbol" => Ok(Self::Symbol),
"fetch" => Ok(Self::Fetch),
"thread" => Ok(Self::Thread),
"rules" => Ok(Self::Rules),
_ => Err(format!("Invalid context picker mode: {}", value)),
}
}
}
impl ContextPickerMode {
pub fn keyword(&self) -> &'static str {
pub fn mention_prefix(&self) -> &'static str {
match self {
Self::File => "file",
Self::Symbol => "symbol",
Self::Fetch => "fetch",
Self::Thread => "thread",
Self::Rules => "rules",
}
}
@@ -136,7 +72,6 @@ impl ContextPickerMode {
Self::Symbol => "Symbols",
Self::Fetch => "Fetch",
Self::Thread => "Threads",
Self::Rules => "Rules",
}
}
@@ -146,7 +81,6 @@ impl ContextPickerMode {
Self::Symbol => IconName::Code,
Self::Fetch => IconName::Globe,
Self::Thread => IconName::MessageBubbles,
Self::Rules => RULES_ICON,
}
}
}
@@ -158,7 +92,6 @@ enum ContextPickerState {
Symbol(Entity<SymbolContextPicker>),
Fetch(Entity<FetchContextPicker>),
Thread(Entity<ThreadContextPicker>),
Rules(Entity<RulesContextPicker>),
}
pub(super) struct ContextPicker {
@@ -222,13 +155,7 @@ impl ContextPicker {
.enumerate()
.map(|(ix, entry)| self.recent_menu_item(context_picker.clone(), ix, entry));
let entries = self
.workspace
.upgrade()
.map(|workspace| {
available_context_picker_entries(&self.thread_store, &workspace, cx)
})
.unwrap_or_default();
let modes = supported_context_picker_modes(&self.thread_store);
menu.when(has_recent, |menu| {
menu.custom_row(|_, _| {
@@ -244,15 +171,15 @@ impl ContextPicker {
})
.extend(recent_entries)
.when(has_recent, |menu| menu.separator())
.extend(entries.into_iter().map(|entry| {
.extend(modes.into_iter().map(|mode| {
let context_picker = context_picker.clone();
ContextMenuEntry::new(entry.label())
.icon(entry.icon())
ContextMenuEntry::new(mode.label())
.icon(mode.icon())
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.handler(move |window, cx| {
context_picker.update(cx, |this, cx| this.select_entry(entry, window, cx))
context_picker.update(cx, |this, cx| this.select_mode(mode, window, cx))
})
}))
.keep_open_on_confirm()
@@ -271,87 +198,61 @@ impl ContextPicker {
self.thread_store.is_some()
}
fn select_entry(
fn select_mode(
&mut self,
entry: ContextPickerEntry,
mode: ContextPickerMode,
window: &mut Window,
cx: &mut Context<Self>,
) {
let context_picker = cx.entity().downgrade();
match entry {
ContextPickerEntry::Mode(mode) => match mode {
ContextPickerMode::File => {
self.mode = ContextPickerState::File(cx.new(|cx| {
FileContextPicker::new(
match mode {
ContextPickerMode::File => {
self.mode = ContextPickerState::File(cx.new(|cx| {
FileContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
ContextPickerMode::Symbol => {
self.mode = ContextPickerState::Symbol(cx.new(|cx| {
SymbolContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
ContextPickerMode::Fetch => {
self.mode = ContextPickerState::Fetch(cx.new(|cx| {
FetchContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
ContextPickerMode::Thread => {
if let Some(thread_store) = self.thread_store.as_ref() {
self.mode = ContextPickerState::Thread(cx.new(|cx| {
ThreadContextPicker::new(
thread_store.clone(),
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
ContextPickerMode::Symbol => {
self.mode = ContextPickerState::Symbol(cx.new(|cx| {
SymbolContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
ContextPickerMode::Rules => {
if let Some(thread_store) = self.thread_store.as_ref() {
self.mode = ContextPickerState::Rules(cx.new(|cx| {
RulesContextPicker::new(
thread_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
}
ContextPickerMode::Fetch => {
self.mode = ContextPickerState::Fetch(cx.new(|cx| {
FetchContextPicker::new(
context_picker.clone(),
self.workspace.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
ContextPickerMode::Thread => {
if let Some(thread_store) = self.thread_store.as_ref() {
self.mode = ContextPickerState::Thread(cx.new(|cx| {
ThreadContextPicker::new(
thread_store.clone(),
context_picker.clone(),
self.context_store.clone(),
window,
cx,
)
}));
}
}
},
ContextPickerEntry::Action(action) => match action {
ContextPickerAction::AddSelections => {
if let Some((context_store, workspace)) =
self.context_store.upgrade().zip(self.workspace.upgrade())
{
add_selections_as_context(&context_store, &workspace, cx);
}
cx.emit(DismissEvent);
}
},
}
}
cx.notify();
@@ -480,7 +381,6 @@ impl ContextPicker {
ContextPickerState::Symbol(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Fetch(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Thread(entity) => entity.update(cx, |_, cx| cx.notify()),
ContextPickerState::Rules(entity) => entity.update(cx, |_, cx| cx.notify()),
}
}
}
@@ -495,7 +395,6 @@ impl Focusable for ContextPicker {
ContextPickerState::Symbol(symbol_picker) => symbol_picker.focus_handle(cx),
ContextPickerState::Fetch(fetch_picker) => fetch_picker.focus_handle(cx),
ContextPickerState::Thread(thread_picker) => thread_picker.focus_handle(cx),
ContextPickerState::Rules(user_rules_picker) => user_rules_picker.focus_handle(cx),
}
}
}
@@ -511,9 +410,6 @@ impl Render for ContextPicker {
ContextPickerState::Symbol(symbol_picker) => parent.child(symbol_picker.clone()),
ContextPickerState::Fetch(fetch_picker) => parent.child(fetch_picker.clone()),
ContextPickerState::Thread(thread_picker) => parent.child(thread_picker.clone()),
ContextPickerState::Rules(user_rules_picker) => {
parent.child(user_rules_picker.clone())
}
})
}
}
@@ -525,37 +421,18 @@ enum RecentEntry {
Thread(ThreadContextEntry),
}
fn available_context_picker_entries(
fn supported_context_picker_modes(
thread_store: &Option<WeakEntity<ThreadStore>>,
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Vec<ContextPickerEntry> {
let mut entries = vec![
ContextPickerEntry::Mode(ContextPickerMode::File),
ContextPickerEntry::Mode(ContextPickerMode::Symbol),
) -> Vec<ContextPickerMode> {
let mut modes = vec![
ContextPickerMode::File,
ContextPickerMode::Symbol,
ContextPickerMode::Fetch,
];
let has_selection = workspace
.read(cx)
.active_item(cx)
.and_then(|item| item.downcast::<Editor>())
.map_or(false, |editor| {
editor.update(cx, |editor, cx| editor.has_non_empty_selection(cx))
});
if has_selection {
entries.push(ContextPickerEntry::Action(
ContextPickerAction::AddSelections,
));
}
if thread_store.is_some() {
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
modes.push(ContextPickerMode::Thread);
}
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Fetch));
entries
modes
}
fn recent_context_picker_entries(
@@ -614,54 +491,6 @@ fn recent_context_picker_entries(
recent
}
fn add_selections_as_context(
context_store: &Entity<ContextStore>,
workspace: &Entity<Workspace>,
cx: &mut App,
) {
let selection_ranges = selection_ranges(workspace, cx);
context_store.update(cx, |context_store, cx| {
for (buffer, range) in selection_ranges {
context_store
.add_selection(buffer, range, cx)
.detach_and_log_err(cx);
}
})
}
fn selection_ranges(
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Vec<(Entity<Buffer>, Range<text::Anchor>)> {
let Some(editor) = workspace
.read(cx)
.active_item(cx)
.and_then(|item| item.act_as::<Editor>(cx))
else {
return Vec::new();
};
editor.update(cx, |editor, cx| {
let selections = editor.selections.all_adjusted(cx);
let buffer = editor.buffer().clone().read(cx);
let snapshot = buffer.snapshot(cx);
selections
.into_iter()
.map(|s| snapshot.anchor_after(s.start)..snapshot.anchor_before(s.end))
.flat_map(|range| {
let (start_buffer, start) = buffer.text_anchor_for_position(range.start, cx)?;
let (end_buffer, end) = buffer.text_anchor_for_position(range.end, cx)?;
if start_buffer != end_buffer {
return None;
}
Some((start_buffer, start..end))
})
.collect::<Vec<_>>()
})
}
pub(crate) fn insert_fold_for_mention(
excerpt_id: ExcerptId,
crease_start: text::Anchor,
@@ -681,11 +510,24 @@ pub(crate) fn insert_fold_for_mention(
let start = start.bias_right(&snapshot);
let end = snapshot.anchor_before(start.to_offset(&snapshot) + content_len);
let crease = crease_for_mention(
crease_label,
crease_icon_path,
let placeholder = FoldPlaceholder {
render: render_fold_icon_button(
crease_icon_path,
crease_label,
editor_entity.downgrade(),
),
merge_adjacent: false,
..Default::default()
};
let render_trailer =
move |_row, _unfold, _window: &mut Window, _cx: &mut App| Empty.into_any();
let crease = Crease::inline(
start..end,
editor_entity.downgrade(),
placeholder.clone(),
fold_toggle("mention"),
render_trailer,
);
editor.display_map.update(cx, |display_map, cx| {
@@ -694,29 +536,6 @@ pub(crate) fn insert_fold_for_mention(
});
}
pub fn crease_for_mention(
label: SharedString,
icon_path: SharedString,
range: Range<Anchor>,
editor_entity: WeakEntity<Editor>,
) -> Crease<Anchor> {
let placeholder = FoldPlaceholder {
render: render_fold_icon_button(icon_path, label, editor_entity),
merge_adjacent: false,
..Default::default()
};
let render_trailer = move |_row, _unfold, _window: &mut Window, _cx: &mut App| Empty.into_any();
let crease = Crease::inline(
range,
placeholder.clone(),
fold_toggle("mention"),
render_trailer,
);
crease
}
fn render_fold_icon_button(
icon_path: SharedString,
label: SharedString,
@@ -805,19 +624,15 @@ fn fold_toggle(
pub enum MentionLink {
File(ProjectPath, Entry),
Symbol(ProjectPath, String),
Selection(ProjectPath, Range<usize>),
Fetch(String),
Thread(ThreadId),
Rules(UserPromptId),
}
impl MentionLink {
const FILE: &str = "@file";
const SYMBOL: &str = "@symbol";
const SELECTION: &str = "@selection";
const THREAD: &str = "@thread";
const FETCH: &str = "@fetch";
const RULES: &str = "@rules";
const SEPARATOR: &str = ":";
@@ -825,9 +640,7 @@ impl MentionLink {
url.starts_with(Self::FILE)
|| url.starts_with(Self::SYMBOL)
|| url.starts_with(Self::FETCH)
|| url.starts_with(Self::SELECTION)
|| url.starts_with(Self::THREAD)
|| url.starts_with(Self::RULES)
}
pub fn for_file(file_name: &str, full_path: &str) -> String {
@@ -844,29 +657,12 @@ impl MentionLink {
)
}
pub fn for_selection(file_name: &str, full_path: &str, line_range: Range<usize>) -> String {
format!(
"[@{} ({}-{})]({}:{}:{}-{})",
file_name,
line_range.start,
line_range.end,
Self::SELECTION,
full_path,
line_range.start,
line_range.end
)
}
pub fn for_thread(thread: &ThreadContextEntry) -> String {
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id)
}
pub fn for_fetch(url: &str) -> String {
format!("[@{}]({}:{})", url, Self::FETCH, url)
}
pub fn for_rules(rules: &RulesContextEntry) -> String {
format!("[@{}]({}:{})", rules.title, Self::RULES, rules.prompt_id.0)
pub fn for_thread(thread: &ThreadContextEntry) -> String {
format!("[@{}]({}:{})", thread.summary, Self::THREAD, thread.id)
}
pub fn try_parse(link: &str, workspace: &Entity<Workspace>, cx: &App) -> Option<Self> {
@@ -905,29 +701,11 @@ impl MentionLink {
let project_path = extract_project_path_from_link(path, workspace, cx)?;
Some(MentionLink::Symbol(project_path, symbol.to_string()))
}
Self::SELECTION => {
let (path, line_args) = argument.split_once(Self::SEPARATOR)?;
let project_path = extract_project_path_from_link(path, workspace, cx)?;
let line_range = {
let (start, end) = line_args
.trim_start_matches('(')
.trim_end_matches(')')
.split_once('-')?;
start.parse::<usize>().ok()?..end.parse::<usize>().ok()?
};
Some(MentionLink::Selection(project_path, line_range))
}
Self::THREAD => {
let thread_id = ThreadId::from(argument);
Some(MentionLink::Thread(thread_id))
}
Self::FETCH => Some(MentionLink::Fetch(argument.to_string())),
Self::RULES => {
let prompt_id = UserPromptId(Uuid::try_parse(argument).ok()?);
Some(MentionLink::Rules(prompt_id))
}
_ => None,
}
}

View File

@@ -1,27 +1,24 @@
use std::cell::RefCell;
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::Result;
use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _};
use editor::{CompletionProvider, Editor, ExcerptId};
use file_icons::FileIcons;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{App, Entity, Task, WeakEntity};
use http_client::HttpClientWithUrl;
use itertools::Itertools;
use language::{Buffer, CodeLabel, HighlightId};
use lsp::CompletionContext;
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
use prompt_store::PromptId;
use rope::Point;
use text::{Anchor, OffsetRangeExt, ToPoint};
use text::{Anchor, ToPoint};
use ui::prelude::*;
use workspace::Workspace;
use crate::context::RULES_ICON;
use crate::context_picker::file_context_picker::search_files;
use crate::context_picker::symbol_context_picker::search_symbols;
use crate::context_store::ContextStore;
@@ -29,12 +26,11 @@ use crate::thread_store::ThreadStore;
use super::fetch_context_picker::fetch_url_content;
use super::file_context_picker::FileMatch;
use super::rules_context_picker::{RulesContextEntry, search_rules};
use super::symbol_context_picker::SymbolMatch;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
available_context_picker_entries, recent_context_picker_entries, selection_ranges,
ContextPickerMode, MentionLink, RecentEntry, recent_context_picker_entries,
supported_context_picker_modes,
};
pub(crate) enum Match {
@@ -42,24 +38,22 @@ pub(crate) enum Match {
File(FileMatch),
Thread(ThreadMatch),
Fetch(SharedString),
Rules(RulesContextEntry),
Entry(EntryMatch),
Mode(ModeMatch),
}
pub struct EntryMatch {
pub struct ModeMatch {
mat: Option<StringMatch>,
entry: ContextPickerEntry,
mode: ContextPickerMode,
}
impl Match {
pub fn score(&self) -> f64 {
match self {
Match::File(file) => file.mat.score,
Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.),
Match::Mode(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.),
Match::Thread(_) => 1.,
Match::Symbol(_) => 1.,
Match::Fetch(_) => 1.,
Match::Rules(_) => 1.,
}
}
}
@@ -118,21 +112,6 @@ fn search(
Task::ready(Vec::new())
}
}
Some(ContextPickerMode::Rules) => {
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
let search_rules_task =
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
cx.background_spawn(async move {
search_rules_task
.await
.into_iter()
.map(Match::Rules)
.collect::<Vec<_>>()
})
} else {
Task::ready(Vec::new())
}
}
None => {
if query.is_empty() {
let mut matches = recent_entries
@@ -163,14 +142,9 @@ fn search(
.collect::<Vec<_>>();
matches.extend(
available_context_picker_entries(&thread_store, &workspace, cx)
supported_context_picker_modes(&thread_store)
.into_iter()
.map(|mode| {
Match::Entry(EntryMatch {
entry: mode,
mat: None,
})
}),
.map(|mode| Match::Mode(ModeMatch { mode, mat: None })),
);
Task::ready(matches)
@@ -180,11 +154,11 @@ fn search(
let search_files_task =
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
let entries = available_context_picker_entries(&thread_store, &workspace, cx);
let entry_candidates = entries
let modes = supported_context_picker_modes(&thread_store);
let mode_candidates = modes
.iter()
.enumerate()
.map(|(ix, entry)| StringMatchCandidate::new(ix, entry.keyword()))
.map(|(ix, mode)| StringMatchCandidate::new(ix, mode.mention_prefix()))
.collect::<Vec<_>>();
cx.background_spawn(async move {
@@ -194,8 +168,8 @@ fn search(
.map(Match::File)
.collect::<Vec<_>>();
let entry_matches = fuzzy::match_strings(
&entry_candidates,
let mode_matches = fuzzy::match_strings(
&mode_candidates,
&query,
false,
100,
@@ -204,9 +178,9 @@ fn search(
)
.await;
matches.extend(entry_matches.into_iter().map(|mat| {
Match::Entry(EntryMatch {
entry: entries[mat.candidate_id],
matches.extend(mode_matches.into_iter().map(|mat| {
Match::Mode(ModeMatch {
mode: modes[mat.candidate_id],
mat: Some(mat),
})
}));
@@ -246,137 +220,19 @@ impl ContextPickerCompletionProvider {
}
}
fn completion_for_entry(
entry: ContextPickerEntry,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Option<Completion> {
match entry {
ContextPickerEntry::Mode(mode) => Some(Completion {
replace_range: source_range.clone(),
new_text: format!("@{} ", mode.keyword()),
label: CodeLabel::plain(mode.label().to_string(), None),
icon_path: Some(mode.icon().path().into()),
documentation: None,
source: project::CompletionSource::Custom,
insert_text_mode: None,
// This ensures that when a user accepts this completion, the
// completion menu will still be shown after "@category " is
// inserted
confirm: Some(Arc::new(|_, _, _| true)),
}),
ContextPickerEntry::Action(action) => {
let (new_text, on_action) = match action {
ContextPickerAction::AddSelections => {
let selections = selection_ranges(workspace, cx);
let selection_infos = selections
.iter()
.map(|(buffer, range)| {
let full_path = buffer
.read(cx)
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| PathBuf::from("untitled"));
let file_name = full_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string();
let line_range = range.to_point(&buffer.read(cx).snapshot());
let link = MentionLink::for_selection(
&file_name,
&full_path.to_string_lossy(),
line_range.start.row as usize..line_range.end.row as usize,
);
(file_name, link, line_range)
})
.collect::<Vec<_>>();
let new_text = selection_infos.iter().map(|(_, link, _)| link).join(" ");
let callback = Arc::new({
let context_store = context_store.clone();
let selections = selections.clone();
let selection_infos = selection_infos.clone();
move |_, _: &mut Window, cx: &mut App| {
context_store.update(cx, |context_store, cx| {
for (buffer, range) in &selections {
context_store
.add_selection(buffer.clone(), range.clone(), cx)
.detach_and_log_err(cx)
}
});
let editor = editor.clone();
let selection_infos = selection_infos.clone();
cx.defer(move |cx| {
let mut current_offset = 0;
for (file_name, link, line_range) in selection_infos.iter() {
let snapshot =
editor.read(cx).buffer().read(cx).snapshot(cx);
let Some(start) = snapshot
.anchor_in_excerpt(excerpt_id, source_range.start)
else {
return;
};
let offset = start.to_offset(&snapshot) + current_offset;
let text_len = link.len();
let range = snapshot.anchor_after(offset)
..snapshot.anchor_after(offset + text_len);
let crease = super::crease_for_mention(
format!(
"{} ({}-{})",
file_name,
line_range.start.row + 1,
line_range.end.row + 1
)
.into(),
IconName::Context.path().into(),
range,
editor.downgrade(),
);
editor.update(cx, |editor, cx| {
editor.display_map.update(cx, |display_map, cx| {
display_map.fold(vec![crease], cx);
});
});
current_offset += text_len + 1;
}
});
false
}
});
(new_text, callback)
}
};
Some(Completion {
replace_range: source_range.clone(),
new_text,
label: CodeLabel::plain(action.label().to_string(), None),
icon_path: Some(action.icon().path().into()),
documentation: None,
source: project::CompletionSource::Custom,
insert_text_mode: None,
// This ensures that when a user accepts this completion, the
// completion menu will still be shown after "@category " is
// inserted
confirm: Some(on_action),
})
}
fn completion_for_mode(source_range: Range<Anchor>, mode: ContextPickerMode) -> Completion {
Completion {
replace_range: source_range.clone(),
new_text: format!("@{} ", mode.mention_prefix()),
label: CodeLabel::plain(mode.label().to_string(), None),
icon_path: Some(mode.icon().path().into()),
documentation: None,
source: project::CompletionSource::Custom,
insert_text_mode: None,
// This ensures that when a user accepts this completion, the
// completion menu will still be shown after "@category " is
// inserted
confirm: Some(Arc::new(|_, _, _| true)),
}
}
@@ -431,60 +287,6 @@ impl ContextPickerCompletionProvider {
}
}
fn completion_for_rules(
rules: RulesContextEntry,
excerpt_id: ExcerptId,
source_range: Range<Anchor>,
editor: Entity<Editor>,
context_store: Entity<ContextStore>,
thread_store: Entity<ThreadStore>,
) -> Completion {
let new_text = MentionLink::for_rules(&rules);
let new_text_len = new_text.len();
Completion {
replace_range: source_range.clone(),
new_text,
label: CodeLabel::plain(rules.title.to_string(), None),
documentation: None,
insert_text_mode: None,
source: project::CompletionSource::Custom,
icon_path: Some(RULES_ICON.path().into()),
confirm: Some(confirm_completion_callback(
RULES_ICON.path().into(),
rules.title.clone(),
excerpt_id,
source_range.start,
new_text_len,
editor.clone(),
move |cx| {
let prompt_uuid = rules.prompt_id;
let prompt_id = PromptId::User { uuid: prompt_uuid };
let context_store = context_store.clone();
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
log::error!("Can't add user rules as prompt store is missing.");
return;
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return;
};
let Some(title) = metadata.title else {
return;
};
let text_task = prompt_store.load(prompt_id, cx);
cx.spawn(async move |cx| {
let text = text_task.await?;
context_store.update(cx, |context_store, cx| {
context_store.add_rules(prompt_uuid, title, text, false, cx)
})
})
.detach_and_log_err(cx);
},
)),
}
}
fn completion_for_fetch(
source_range: Range<Anchor>,
url_to_fetch: SharedString,
@@ -791,17 +593,6 @@ impl CompletionProvider for ContextPickerCompletionProvider {
thread_store,
))
}
Match::Rules(user_rules) => {
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
Some(Self::completion_for_rules(
user_rules,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
thread_store,
))
}
Match::Fetch(url) => Some(Self::completion_for_fetch(
source_range.clone(),
url,
@@ -810,15 +601,9 @@ impl CompletionProvider for ContextPickerCompletionProvider {
context_store.clone(),
http_client.clone(),
)),
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
entry,
excerpt_id,
source_range.clone(),
editor.clone(),
context_store.clone(),
&workspace,
cx,
),
Match::Mode(ModeMatch { mode, .. }) => {
Some(Self::completion_for_mode(source_range.clone(), mode))
}
})
.collect()
})?))

View File

@@ -1,248 +0,0 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use anyhow::anyhow;
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
use picker::{Picker, PickerDelegate};
use prompt_store::{PromptId, UserPromptId};
use ui::{ListItem, prelude::*};
use crate::context::RULES_ICON;
use crate::context_picker::ContextPicker;
use crate::context_store::{self, ContextStore};
use crate::thread_store::ThreadStore;
pub struct RulesContextPicker {
picker: Entity<Picker<RulesContextPickerDelegate>>,
}
impl RulesContextPicker {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
RulesContextPicker { picker }
}
}
impl Focusable for RulesContextPicker {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.picker.focus_handle(cx)
}
}
impl Render for RulesContextPicker {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
self.picker.clone()
}
}
#[derive(Debug, Clone)]
pub struct RulesContextEntry {
pub prompt_id: UserPromptId,
pub title: SharedString,
}
pub struct RulesContextPickerDelegate {
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
matches: Vec<RulesContextEntry>,
selected_index: usize,
}
impl RulesContextPickerDelegate {
pub fn new(
thread_store: WeakEntity<ThreadStore>,
context_picker: WeakEntity<ContextPicker>,
context_store: WeakEntity<context_store::ContextStore>,
) -> Self {
RulesContextPickerDelegate {
thread_store,
context_picker,
context_store,
matches: Vec::new(),
selected_index: 0,
}
}
}
impl PickerDelegate for RulesContextPickerDelegate {
type ListItem = ListItem;
fn match_count(&self) -> usize {
self.matches.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) {
self.selected_index = ix;
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Search available rules…".into()
}
fn update_matches(
&mut self,
query: String,
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(thread_store) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
this.delegate.matches = matches;
this.delegate.selected_index = 0;
cx.notify();
})
.ok();
})
}
fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(self.selected_index) else {
return;
};
let Some(thread_store) = self.thread_store.upgrade() else {
return;
};
let prompt_id = entry.prompt_id;
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
});
cx.spawn(async move |this, cx| {
let (metadata, text) = load_rules_task.await?;
let Some(title) = metadata.title else {
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
};
this.update(cx, |this, cx| {
this.delegate
.context_store
.update(cx, |context_store, cx| {
context_store.add_rules(prompt_id, title, text, true, cx)
})
.ok();
})
})
.detach_and_log_err(cx);
}
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
self.context_picker
.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
}
fn render_match(
&self,
ix: usize,
selected: bool,
_window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
let thread = &self.matches[ix];
Some(ListItem::new(ix).inset(true).toggle_state(selected).child(
render_thread_context_entry(thread, self.context_store.clone(), cx),
))
}
}
pub fn render_thread_context_entry(
user_rules: &RulesContextEntry,
context_store: WeakEntity<ContextStore>,
cx: &mut App,
) -> Div {
let added = context_store.upgrade().map_or(false, |ctx_store| {
ctx_store
.read(cx)
.includes_user_rules(&user_rules.prompt_id)
.is_some()
});
h_flex()
.gap_1p5()
.w_full()
.justify_between()
.child(
h_flex()
.gap_1p5()
.max_w_72()
.child(
Icon::new(RULES_ICON)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(Label::new(user_rules.title.clone()).truncate()),
)
.when(added, |el| {
el.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::Check)
.size(IconSize::Small)
.color(Color::Success),
)
.child(Label::new("Added").size(LabelSize::Small)),
)
})
}
pub(crate) fn search_rules(
query: String,
cancellation_flag: Arc<AtomicBool>,
thread_store: Entity<ThreadStore>,
cx: &mut App,
) -> Task<Vec<RulesContextEntry>> {
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
return Task::ready(vec![]);
};
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
cx.background_spawn(async move {
search_task
.await
.into_iter()
.flat_map(|metadata| {
// Default prompts are filtered out as they are automatically included.
if metadata.default {
None
} else {
match metadata.id {
PromptId::EditWorkflow => None,
PromptId::User { uuid } => Some(RulesContextEntry {
prompt_id: uuid,
title: metadata.title?,
}),
}
}
})
.collect::<Vec<_>>()
})
}

View File

@@ -103,11 +103,11 @@ impl PickerDelegate for ThreadContextPickerDelegate {
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
let Some(thread_store) = self.thread_store.upgrade() else {
let Some(threads) = self.thread_store.upgrade() else {
return Task::ready(());
};
let search_task = search_threads(query, Arc::new(AtomicBool::default()), thread_store, cx);
let search_task = search_threads(query, Arc::new(AtomicBool::default()), threads, cx);
cx.spawn_in(window, async move |this, cx| {
let matches = search_task.await;
this.update(cx, |this, cx| {
@@ -217,15 +217,15 @@ pub(crate) fn search_threads(
thread_store: Entity<ThreadStore>,
cx: &mut App,
) -> Task<Vec<ThreadMatch>> {
let threads = thread_store
.read(cx)
.threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,
summary: thread.summary,
})
.collect::<Vec<_>>();
let threads = thread_store.update(cx, |this, _cx| {
this.threads()
.into_iter()
.map(|thread| ThreadContextEntry {
id: thread.id,
summary: thread.summary,
})
.collect::<Vec<_>>()
});
let executor = cx.background_executor().clone();
cx.background_spawn(async move {

View File

@@ -6,11 +6,9 @@ use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap, HashSet};
use futures::future::join_all;
use futures::{self, Future, FutureExt, future};
use gpui::{App, AppContext as _, Context, Entity, Image, SharedString, Task, WeakEntity};
use language::Buffer;
use language_model::LanguageModelImage;
use project::{Project, ProjectEntryId, ProjectItem, ProjectPath, Worktree};
use prompt_store::UserPromptId;
use gpui::{App, AppContext as _, Context, Entity, SharedString, Task, WeakEntity};
use language::{Buffer, File};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use rope::{Point, Rope};
use text::{Anchor, BufferId, OffsetRangeExt};
use util::{ResultExt as _, maybe};
@@ -18,8 +16,7 @@ use util::{ResultExt as _, maybe};
use crate::ThreadStore;
use crate::context::{
AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, DirectoryContext,
FetchedUrlContext, FileContext, ImageContext, RulesContext, SelectionContext, SymbolContext,
ThreadContext,
ExcerptContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext,
};
use crate::context_strip::SuggestedContext;
use crate::thread::{Thread, ThreadId};
@@ -28,6 +25,7 @@ pub struct ContextStore {
project: WeakEntity<Project>,
context: Vec<AssistantContext>,
thread_store: Option<WeakEntity<ThreadStore>>,
// TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId.
next_context_id: ContextId,
files: BTreeMap<BufferId, ContextId>,
directories: HashMap<ProjectPath, ContextId>,
@@ -37,7 +35,6 @@ pub struct ContextStore {
threads: HashMap<ThreadId, ContextId>,
thread_summary_tasks: Vec<Task<()>>,
fetched_urls: HashMap<String, ContextId>,
user_rules: HashMap<UserPromptId, ContextId>,
}
impl ContextStore {
@@ -58,7 +55,6 @@ impl ContextStore {
threads: HashMap::default(),
thread_summary_tasks: Vec::new(),
fetched_urls: HashMap::default(),
user_rules: HashMap::default(),
}
}
@@ -76,7 +72,6 @@ impl ContextStore {
self.directories.clear();
self.threads.clear();
self.fetched_urls.clear();
self.user_rules.clear();
}
pub fn add_file_from_path(
@@ -114,12 +109,13 @@ impl ContextStore {
return anyhow::Ok(());
}
let context_buffer = this
.update(cx, |_, cx| load_context_buffer(buffer, cx))??
.await;
let (buffer_info, text_task) =
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, cx))??;
let text = text_task.await;
this.update(cx, |this, cx| {
this.insert_file(context_buffer, cx);
this.insert_file(make_context_buffer(buffer_info, text), cx);
})?;
anyhow::Ok(())
@@ -132,11 +128,14 @@ impl ContextStore {
cx: &mut Context<Self>,
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| {
let context_buffer = this
.update(cx, |_, cx| load_context_buffer(buffer, cx))??
.await;
let (buffer_info, text_task) =
this.update(cx, |_, cx| collect_buffer_info_and_text(buffer, cx))??;
this.update(cx, |this, cx| this.insert_file(context_buffer, cx))?;
let text = text_task.await;
this.update(cx, |this, cx| {
this.insert_file(make_context_buffer(buffer_info, text), cx)
})?;
anyhow::Ok(())
})
@@ -160,14 +159,6 @@ impl ContextStore {
return Task::ready(Err(anyhow!("failed to read project")));
};
let Some(entry_id) = project
.read(cx)
.entry_for_path(&project_path, cx)
.map(|entry| entry.id)
else {
return Task::ready(Err(anyhow!("no entry found for directory context")));
};
let already_included = match self.includes_directory(&project_path) {
Some(FileInclusion::Direct(context_id)) => {
if remove_if_exists {
@@ -209,15 +200,27 @@ impl ContextStore {
let buffers = open_buffers_task.await;
let context_buffer_tasks = this.update(cx, |_, cx| {
buffers
.into_iter()
.flatten()
.flat_map(move |buffer| load_context_buffer(buffer, cx).log_err())
.collect::<Vec<_>>()
})?;
let mut buffer_infos = Vec::new();
let mut text_tasks = Vec::new();
this.update(cx, |_, cx| {
// Skip all binary files and other non-UTF8 files
for buffer in buffers.into_iter().flatten() {
if let Some((buffer_info, text_task)) =
collect_buffer_info_and_text(buffer, cx).log_err()
{
buffer_infos.push(buffer_info);
text_tasks.push(text_task);
}
}
anyhow::Ok(())
})??;
let context_buffers = future::join_all(context_buffer_tasks).await;
let buffer_texts = future::join_all(text_tasks).await;
let context_buffers = buffer_infos
.into_iter()
.zip(buffer_texts)
.map(|(info, text)| make_context_buffer(info, text))
.collect::<Vec<_>>();
if context_buffers.is_empty() {
let full_path = cx.update(|cx| worktree.read(cx).full_path(&project_path.path))?;
@@ -225,7 +228,7 @@ impl ContextStore {
}
this.update(cx, |this, cx| {
this.insert_directory(worktree, entry_id, project_path, context_buffers, cx);
this.insert_directory(worktree, project_path, context_buffers, cx);
})?;
anyhow::Ok(())
@@ -235,21 +238,19 @@ impl ContextStore {
fn insert_directory(
&mut self,
worktree: Entity<Worktree>,
entry_id: ProjectEntryId,
project_path: ProjectPath,
context_buffers: Vec<ContextBuffer>,
cx: &mut Context<Self>,
) {
let id = self.next_context_id.post_inc();
let last_path = project_path.path.clone();
let path = project_path.path.clone();
self.directories.insert(project_path, id);
self.context
.push(AssistantContext::Directory(DirectoryContext {
id,
worktree,
entry_id,
last_path,
path,
context_buffers,
}));
cx.notify();
@@ -289,23 +290,27 @@ impl ContextStore {
}
}
let context_buffer_task =
match load_context_buffer_range(buffer, symbol_enclosing_range.clone(), cx) {
Ok((_line_range, context_buffer_task)) => context_buffer_task,
Err(err) => return Task::ready(Err(err)),
};
let (buffer_info, collect_content_task) = match collect_buffer_info_and_text_for_range(
buffer,
symbol_enclosing_range.clone(),
cx,
) {
Ok((_, buffer_info, collect_context_task)) => (buffer_info, collect_context_task),
Err(err) => return Task::ready(Err(err)),
};
cx.spawn(async move |this, cx| {
let context_buffer = context_buffer_task.await;
let content = collect_content_task.await;
this.update(cx, |this, cx| {
this.insert_symbol(
make_context_symbol(
context_buffer,
buffer_info,
project_path,
symbol_name,
symbol_range,
symbol_enclosing_range,
content,
),
cx,
)
@@ -385,42 +390,6 @@ impl ContextStore {
cx.notify();
}
pub fn add_rules(
&mut self,
prompt_id: UserPromptId,
title: impl Into<SharedString>,
text: impl Into<SharedString>,
remove_if_exists: bool,
cx: &mut Context<ContextStore>,
) {
if let Some(context_id) = self.includes_user_rules(&prompt_id) {
if remove_if_exists {
self.remove_context(context_id, cx);
}
} else {
self.insert_user_rules(prompt_id, title, text, cx);
}
}
pub fn insert_user_rules(
&mut self,
prompt_id: UserPromptId,
title: impl Into<SharedString>,
text: impl Into<SharedString>,
cx: &mut Context<ContextStore>,
) {
let id = self.next_context_id.post_inc();
self.user_rules.insert(prompt_id, id);
self.context.push(AssistantContext::Rules(RulesContext {
id,
prompt_id,
title: title.into(),
text: text.into(),
}));
cx.notify();
}
pub fn add_fetched_url(
&mut self,
url: String,
@@ -450,54 +419,33 @@ impl ContextStore {
cx.notify();
}
pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
let id = self.next_context_id.post_inc();
self.context.push(AssistantContext::Image(ImageContext {
id,
original_image: image,
image_task,
}));
cx.notify();
}
pub fn wait_for_images(&self, cx: &App) -> Task<()> {
let tasks = self
.context
.iter()
.filter_map(|ctx| match ctx {
AssistantContext::Image(ctx) => Some(ctx.image_task.clone()),
_ => None,
})
.collect::<Vec<_>>();
cx.spawn(async move |_cx| {
join_all(tasks).await;
})
}
pub fn add_selection(
pub fn add_excerpt(
&mut self,
buffer: Entity<Buffer>,
range: Range<Anchor>,
buffer: Entity<Buffer>,
cx: &mut Context<ContextStore>,
) -> Task<Result<()>> {
cx.spawn(async move |this, cx| {
let (line_range, context_buffer_task) = this.update(cx, |_, cx| {
load_context_buffer_range(buffer, range.clone(), cx)
let (line_range, buffer_info, text_task) = this.update(cx, |_, cx| {
collect_buffer_info_and_text_for_range(buffer, range.clone(), cx)
})??;
let context_buffer = context_buffer_task.await;
let text = text_task.await;
this.update(cx, |this, cx| {
this.insert_selection(context_buffer, range, line_range, cx)
this.insert_excerpt(
make_context_buffer(buffer_info, text),
range,
line_range,
cx,
)
})?;
anyhow::Ok(())
})
}
fn insert_selection(
fn insert_excerpt(
&mut self,
context_buffer: ContextBuffer,
range: Range<Anchor>,
@@ -505,13 +453,12 @@ impl ContextStore {
cx: &mut Context<Self>,
) {
let id = self.next_context_id.post_inc();
self.context
.push(AssistantContext::Selection(SelectionContext {
id,
range,
line_range,
context_buffer,
}));
self.context.push(AssistantContext::Excerpt(ExcerptContext {
id,
range,
line_range,
context_buffer,
}));
cx.notify();
}
@@ -564,17 +511,13 @@ impl ContextStore {
self.symbol_buffers.remove(&symbol.context_symbol.id);
self.symbols.retain(|_, context_id| *context_id != id);
}
AssistantContext::Selection(_) => {}
AssistantContext::Excerpt(_) => {}
AssistantContext::FetchedUrl(_) => {
self.fetched_urls.retain(|_, context_id| *context_id != id);
}
AssistantContext::Thread(_) => {
self.threads.retain(|_, context_id| *context_id != id);
}
AssistantContext::Rules(RulesContext { prompt_id, .. }) => {
self.user_rules.remove(&prompt_id);
}
AssistantContext::Image(_) => {}
}
cx.notify();
@@ -671,10 +614,6 @@ impl ContextStore {
self.threads.get(thread_id).copied()
}
pub fn includes_user_rules(&self, prompt_id: &UserPromptId) -> Option<ContextId> {
self.user_rules.get(prompt_id).copied()
}
pub fn includes_url(&self, url: &str) -> Option<ContextId> {
self.fetched_urls.get(url).copied()
}
@@ -700,11 +639,9 @@ impl ContextStore {
}
AssistantContext::Directory(_)
| AssistantContext::Symbol(_)
| AssistantContext::Selection(_)
| AssistantContext::Excerpt(_)
| AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_)
| AssistantContext::Rules(_)
| AssistantContext::Image(_) => None,
| AssistantContext::Thread(_) => None,
})
.collect()
}
@@ -719,78 +656,92 @@ pub enum FileInclusion {
InDirectory(ProjectPath),
}
// ContextBuffer without text.
struct BufferInfo {
id: BufferId,
buffer: Entity<Buffer>,
file: Arc<dyn File>,
version: clock::Global,
}
fn make_context_buffer(info: BufferInfo, text: SharedString) -> ContextBuffer {
ContextBuffer {
id: info.id,
buffer: info.buffer,
file: info.file,
version: info.version,
text,
}
}
fn make_context_symbol(
context_buffer: ContextBuffer,
info: BufferInfo,
path: ProjectPath,
name: SharedString,
range: Range<Anchor>,
enclosing_range: Range<Anchor>,
text: SharedString,
) -> ContextSymbol {
ContextSymbol {
id: ContextSymbolId { name, range, path },
buffer_version: context_buffer.version,
buffer_version: info.version,
enclosing_range,
buffer: context_buffer.buffer,
text: context_buffer.text,
buffer: info.buffer,
text,
}
}
fn load_context_buffer_range(
fn collect_buffer_info_and_text_for_range(
buffer: Entity<Buffer>,
range: Range<Anchor>,
cx: &App,
) -> Result<(Range<Point>, Task<ContextBuffer>)> {
let buffer_ref = buffer.read(cx);
let id = buffer_ref.remote_id();
) -> Result<(Range<Point>, BufferInfo, Task<SharedString>)> {
let content = buffer
.read(cx)
.text_for_range(range.clone())
.collect::<Rope>();
let file = buffer_ref.file().context("context buffer missing path")?;
let full_path = file.full_path(cx);
let line_range = range.to_point(&buffer.read(cx).snapshot());
// Important to collect version at the same time as content so that staleness logic is correct.
let version = buffer_ref.version();
let content = buffer_ref.text_for_range(range.clone()).collect::<Rope>();
let line_range = range.to_point(&buffer_ref.snapshot());
let buffer_info = collect_buffer_info(buffer, cx)?;
let full_path = buffer_info.file.full_path(cx);
// Build the text on a background thread.
let task = cx.background_spawn({
let text_task = cx.background_spawn({
let line_range = line_range.clone();
async move {
let text = to_fenced_codeblock(&full_path, content, Some(line_range));
ContextBuffer {
id,
buffer,
last_full_path: full_path.into(),
version,
text,
}
}
async move { to_fenced_codeblock(&full_path, content, Some(line_range)) }
});
Ok((line_range, task))
Ok((line_range, buffer_info, text_task))
}
fn load_context_buffer(buffer: Entity<Buffer>, cx: &App) -> Result<Task<ContextBuffer>> {
let buffer_ref = buffer.read(cx);
let id = buffer_ref.remote_id();
fn collect_buffer_info_and_text(
buffer: Entity<Buffer>,
cx: &App,
) -> Result<(BufferInfo, Task<SharedString>)> {
let content = buffer.read(cx).as_rope().clone();
let file = buffer_ref.file().context("context buffer missing path")?;
let full_path = file.full_path(cx);
let buffer_info = collect_buffer_info(buffer, cx)?;
let full_path = buffer_info.file.full_path(cx);
let text_task =
cx.background_spawn(async move { to_fenced_codeblock(&full_path, content, None) });
Ok((buffer_info, text_task))
}
fn collect_buffer_info(buffer: Entity<Buffer>, cx: &App) -> Result<BufferInfo> {
let buffer_ref = buffer.read(cx);
let file = buffer_ref.file().context("file context must have a path")?;
// Important to collect version at the same time as content so that staleness logic is correct.
let version = buffer_ref.version();
let content = buffer_ref.as_rope().clone();
// Build the text on a background thread.
Ok(cx.background_spawn(async move {
let text = to_fenced_codeblock(&full_path, content, None);
ContextBuffer {
id,
buffer,
last_full_path: full_path.into(),
version,
text,
}
}))
Ok(BufferInfo {
buffer,
id: buffer_ref.remote_id(),
file: file.clone(),
version,
})
}
fn to_fenced_codeblock(
@@ -877,7 +828,6 @@ pub fn refresh_context_store_text(
let task = maybe!({
match context {
AssistantContext::File(file_context) => {
// TODO: Should refresh if the path has changed, as it's in the text.
if changed_buffers.is_empty()
|| changed_buffers.contains(&file_context.context_buffer.buffer)
{
@@ -886,9 +836,8 @@ pub fn refresh_context_store_text(
}
}
AssistantContext::Directory(directory_context) => {
let directory_path = directory_context.project_path(cx)?;
let should_refresh = directory_path.path != directory_context.last_path
|| changed_buffers.is_empty()
let directory_path = directory_context.project_path(cx);
let should_refresh = changed_buffers.is_empty()
|| changed_buffers.iter().any(|buffer| {
let Some(buffer_path) = buffer.read(cx).project_path(cx) else {
return false;
@@ -898,16 +847,10 @@ pub fn refresh_context_store_text(
if should_refresh {
let context_store = context_store.clone();
return refresh_directory_text(
context_store,
directory_context,
directory_path,
cx,
);
return refresh_directory_text(context_store, directory_context, cx);
}
}
AssistantContext::Symbol(symbol_context) => {
// TODO: Should refresh if the path has changed, as it's in the text.
if changed_buffers.is_empty()
|| changed_buffers.contains(&symbol_context.context_symbol.buffer)
{
@@ -915,13 +858,12 @@ pub fn refresh_context_store_text(
return refresh_symbol_text(context_store, symbol_context, cx);
}
}
AssistantContext::Selection(selection_context) => {
// TODO: Should refresh if the path has changed, as it's in the text.
AssistantContext::Excerpt(excerpt_context) => {
if changed_buffers.is_empty()
|| changed_buffers.contains(&selection_context.context_buffer.buffer)
|| changed_buffers.contains(&excerpt_context.context_buffer.buffer)
{
let context_store = context_store.clone();
return refresh_selection_text(context_store, selection_context, cx);
return refresh_excerpt_text(context_store, excerpt_context, cx);
}
}
AssistantContext::Thread(thread_context) => {
@@ -934,11 +876,6 @@ pub fn refresh_context_store_text(
// and doing the caching properly could be tricky (unless it's already handled by
// the HttpClient?).
AssistantContext::FetchedUrl(_) => {}
AssistantContext::Rules(user_rules_context) => {
let context_store = context_store.clone();
return Some(refresh_user_rules(context_store, user_rules_context, cx));
}
AssistantContext::Image(_) => {}
}
None
@@ -977,7 +914,6 @@ fn refresh_file_text(
fn refresh_directory_text(
context_store: Entity<ContextStore>,
directory_context: &DirectoryContext,
directory_path: ProjectPath,
cx: &App,
) -> Option<Task<()>> {
let mut stale = false;
@@ -1002,8 +938,7 @@ fn refresh_directory_text(
let id = directory_context.id;
let worktree = directory_context.worktree.clone();
let entry_id = directory_context.entry_id;
let last_path = directory_path.path;
let path = directory_context.path.clone();
Some(cx.spawn(async move |cx| {
let context_buffers = context_buffers.await;
context_store
@@ -1011,8 +946,7 @@ fn refresh_directory_text(
let new_directory_context = DirectoryContext {
id,
worktree,
entry_id,
last_path,
path,
context_buffers,
};
context_store.replace_context(AssistantContext::Directory(new_directory_context));
@@ -1043,27 +977,26 @@ fn refresh_symbol_text(
}
}
fn refresh_selection_text(
fn refresh_excerpt_text(
context_store: Entity<ContextStore>,
selection_context: &SelectionContext,
excerpt_context: &ExcerptContext,
cx: &App,
) -> Option<Task<()>> {
let id = selection_context.id;
let range = selection_context.range.clone();
let task = refresh_context_excerpt(&selection_context.context_buffer, range.clone(), cx);
let id = excerpt_context.id;
let range = excerpt_context.range.clone();
let task = refresh_context_excerpt(&excerpt_context.context_buffer, range.clone(), cx);
if let Some(task) = task {
Some(cx.spawn(async move |cx| {
let (line_range, context_buffer) = task.await;
context_store
.update(cx, |context_store, _| {
let new_selection_context = SelectionContext {
let new_excerpt_context = ExcerptContext {
id,
range,
line_range,
context_buffer,
};
context_store
.replace_context(AssistantContext::Selection(new_selection_context));
context_store.replace_context(AssistantContext::Excerpt(new_excerpt_context));
})
.ok();
}))
@@ -1093,49 +1026,15 @@ fn refresh_thread_text(
})
}
fn refresh_user_rules(
context_store: Entity<ContextStore>,
user_rules_context: &RulesContext,
fn refresh_context_buffer(
context_buffer: &ContextBuffer,
cx: &App,
) -> Task<()> {
let id = user_rules_context.id;
let prompt_id = user_rules_context.prompt_id;
let Some(thread_store) = context_store.read(cx).thread_store.as_ref() else {
return Task::ready(());
};
let Ok(load_task) = thread_store.read_with(cx, |thread_store, cx| {
thread_store.load_rules(prompt_id, cx)
}) else {
return Task::ready(());
};
cx.spawn(async move |cx| {
if let Ok((metadata, text)) = load_task.await {
if let Some(title) = metadata.title.clone() {
context_store
.update(cx, |context_store, _cx| {
context_store.replace_context(AssistantContext::Rules(RulesContext {
id,
prompt_id,
title,
text: text.into(),
}));
})
.ok();
return;
}
}
context_store
.update(cx, |context_store, cx| {
context_store.remove_context(id, cx);
})
.ok();
})
}
fn refresh_context_buffer(context_buffer: &ContextBuffer, cx: &App) -> Option<Task<ContextBuffer>> {
) -> Option<impl Future<Output = ContextBuffer> + use<>> {
let buffer = context_buffer.buffer.read(cx);
if buffer.version.changed_since(&context_buffer.version) {
load_context_buffer(context_buffer.buffer.clone(), cx).log_err()
let (buffer_info, text_task) =
collect_buffer_info_and_text(context_buffer.buffer.clone(), cx).log_err()?;
Some(text_task.map(move |text| make_context_buffer(buffer_info, text)))
} else {
None
}
@@ -1148,9 +1047,10 @@ fn refresh_context_excerpt(
) -> Option<impl Future<Output = (Range<Point>, ContextBuffer)> + use<>> {
let buffer = context_buffer.buffer.read(cx);
if buffer.version.changed_since(&context_buffer.version) {
let (line_range, context_buffer_task) =
load_context_buffer_range(context_buffer.buffer.clone(), range, cx).log_err()?;
Some(context_buffer_task.map(move |context_buffer| (line_range, context_buffer)))
let (line_range, buffer_info, text_task) =
collect_buffer_info_and_text_for_range(context_buffer.buffer.clone(), range, cx)
.log_err()?;
Some(text_task.map(move |text| (line_range, make_context_buffer(buffer_info, text))))
} else {
None
}
@@ -1163,7 +1063,7 @@ fn refresh_context_symbol(
let buffer = context_symbol.buffer.read(cx);
let project_path = buffer.project_path(cx)?;
if buffer.version.changed_since(&context_symbol.buffer_version) {
let (_line_range, context_buffer_task) = load_context_buffer_range(
let (_, buffer_info, text_task) = collect_buffer_info_and_text_for_range(
context_symbol.buffer.clone(),
context_symbol.enclosing_range.clone(),
cx,
@@ -1172,8 +1072,15 @@ fn refresh_context_symbol(
let name = context_symbol.id.name.clone();
let range = context_symbol.id.range.clone();
let enclosing_range = context_symbol.enclosing_range.clone();
Some(context_buffer_task.map(move |context_buffer| {
make_context_symbol(context_buffer, project_path, name, range, enclosing_range)
Some(text_task.map(move |text| {
make_context_symbol(
buffer_info,
project_path,
name,
range,
enclosing_range,
text,
)
}))
} else {
None

View File

@@ -24,7 +24,6 @@ use gpui::{
WeakEntity, Window, point,
};
use language::{Buffer, Point, Selection, TransactionId};
use language_model::ConfiguredModel;
use language_model::{LanguageModelRegistry, report_assistant_event};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
@@ -1222,15 +1221,9 @@ impl InlineAssistant {
self.prompt_history.pop_front();
}
let Some(ConfiguredModel { model, .. }) =
LanguageModelRegistry::read_global(cx).inline_assistant_model()
else {
return;
};
assist
.codegen
.update(cx, |codegen, cx| codegen.start(model, user_prompt, cx))
.update(cx, |codegen, cx| codegen.start(user_prompt, cx))
.log_err();
}

View File

@@ -1,4 +1,4 @@
use crate::assistant_model_selector::AssistantModelSelector;
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@@ -20,7 +20,7 @@ use gpui::{
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
};
use language_model::{LanguageModel, LanguageModelRegistry};
use language_model_selector::{ModelType, ToggleModelSelector};
use language_model_selector::ToggleModelSelector;
use parking_lot::Mutex;
use settings::Settings;
use std::cmp;

View File

@@ -6,7 +6,7 @@ use crate::context::{AssistantContext, format_context_as_string};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
use collections::HashSet;
use editor::actions::{MoveUp, Paste};
use editor::actions::MoveUp;
use editor::{
ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorEvent, EditorMode,
EditorStyle, MultiBuffer,
@@ -14,8 +14,8 @@ use editor::{
use file_icons::FileIcons;
use fs::Fs;
use gpui::{
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
Animation, AnimationExt, App, Entity, EventEmitter, Focusable, Subscription, Task, TextStyle,
WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language};
use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage};
@@ -271,7 +271,6 @@ impl MessageEditor {
let refresh_task =
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
let thread = self.thread.clone();
let context_store = self.context_store.clone();
@@ -281,7 +280,6 @@ impl MessageEditor {
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
refresh_task.await;
wait_for_images.await;
thread
.update(cx, |thread, cx| {
@@ -295,12 +293,7 @@ impl MessageEditor {
let excerpt_ids = context_store
.context()
.iter()
.filter(|ctx| {
matches!(
ctx,
AssistantContext::Selection(_) | AssistantContext::Image(_)
)
})
.filter(|ctx| matches!(ctx, AssistantContext::Excerpt(_)))
.map(|ctx| ctx.id())
.collect::<Vec<_>>();
@@ -377,38 +370,8 @@ impl MessageEditor {
}
}
fn paste(&mut self, _: &Paste, _: &mut Window, cx: &mut Context<Self>) {
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return;
}
cx.stop_propagation();
self.context_store.update(cx, |store, cx| {
for image in images {
store.add_image(Arc::new(image), cx);
}
});
}
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.edits_expanded = true;
fn handle_review_click(&self, window: &mut Window, cx: &mut Context<Self>) {
AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
cx.notify();
}
fn handle_file_click(
@@ -482,7 +445,6 @@ impl MessageEditor {
.on_action(cx.listener(Self::move_up))
.on_action(cx.listener(Self::toggle_chat_mode))
.on_action(cx.listener(Self::expand_message_editor))
.capture_action(cx.listener(Self::paste))
.gap_2()
.p_2()
.bg(editor_bg_color)

View File

@@ -16,7 +16,7 @@ use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
@@ -38,7 +38,7 @@ use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
@@ -97,7 +97,6 @@ pub struct Message {
pub role: Role,
pub segments: Vec<MessageSegment>,
pub context: String,
pub images: Vec<LanguageModelImage>,
}
impl Message {
@@ -168,9 +167,12 @@ pub enum MessageSegment {
impl MessageSegment {
pub fn should_display(&self) -> bool {
// We add USING_TOOL_MARKER when making a request that includes tool uses
// without non-whitespace text around them, and this can cause the model
// to mimic the pattern, so we consider those segments not displayable.
match self {
Self::Text(text) => text.is_empty(),
Self::Thinking { text, .. } => text.is_empty(),
Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
Self::Thinking { text, .. } => text.is_empty() || text.trim() == USING_TOOL_MARKER,
Self::RedactedThinking(_) => false,
}
}
@@ -315,7 +317,6 @@ pub struct Thread {
request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
remaining_turns: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -369,7 +370,6 @@ impl Thread {
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
}
}
@@ -418,7 +418,6 @@ impl Thread {
})
.collect(),
context: message.context,
images: Vec::new(),
})
.collect(),
next_message_id,
@@ -444,7 +443,6 @@ impl Thread {
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
}
}
@@ -525,7 +523,7 @@ impl Thread {
self.messages.iter().find(|message| message.id == id)
}
pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
pub fn messages(&self) -> impl Iterator<Item = &Message> {
self.messages.iter()
}
@@ -619,12 +617,24 @@ impl Thread {
.await
.unwrap_or(false);
if !equal {
if equal {
git_store
.update(cx, |store, cx| {
store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
})?
.detach();
} else {
this.update(cx, |this, cx| {
this.insert_checkpoint(pending_checkpoint, cx)
})?;
}
git_store
.update(cx, |store, cx| {
store.delete_checkpoint(final_checkpoint, cx)
})?
.detach();
Ok(())
}
Err(_) => this.update(cx, |this, cx| {
@@ -740,41 +750,31 @@ impl Thread {
}
}
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
message.images = new_context
.iter()
.filter_map(|context| {
if let AssistantContext::Image(image_context) = context {
image_context.image_task.clone().now_or_never().flatten()
} else {
None
}
})
.collect::<Vec<_>>();
}
self.action_log.update(cx, |log, cx| {
// Track all buffers added as context
for ctx in &new_context {
match ctx {
AssistantContext::File(file_ctx) => {
log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
}
AssistantContext::Directory(dir_ctx) => {
for context_buffer in &dir_ctx.context_buffers {
log.track_buffer(context_buffer.buffer.clone(), cx);
log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
}
}
AssistantContext::Symbol(symbol_ctx) => {
log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
log.buffer_added_as_context(
symbol_ctx.context_symbol.buffer.clone(),
cx,
);
}
AssistantContext::Selection(selection_context) => {
log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
AssistantContext::Excerpt(excerpt_context) => {
log.buffer_added_as_context(
excerpt_context.context_buffer.buffer.clone(),
cx,
);
}
AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_)
| AssistantContext::Rules(_)
| AssistantContext::Image(_) => {}
AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
}
}
});
@@ -815,7 +815,6 @@ impl Thread {
role,
segments,
context: String::new(),
images: Vec::new(),
});
self.touch_updated_at();
cx.emit(ThreadEvent::MessageAdded(id));
@@ -943,21 +942,7 @@ impl Thread {
})
}
pub fn remaining_turns(&self) -> u32 {
self.remaining_turns
}
pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
self.remaining_turns = remaining_turns;
}
pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
if self.remaining_turns == 0 {
return;
}
self.remaining_turns -= 1;
let mut request = self.to_completion_request(cx);
if model.supports_tools() {
request.tools = {
@@ -1053,21 +1038,6 @@ impl Thread {
.push(MessageContent::Text(message.context.to_string()));
}
if !message.images.is_empty() {
// Some providers only support image parts after an initial text part
if request_message.content.is_empty() {
request_message
.content
.push(MessageContent::Text("Images attached by user:".to_string()));
}
for image in &message.images {
request_message
.content
.push(MessageContent::Image(image.clone()))
}
}
for segment in &message.segments {
match segment {
MessageSegment::Text(text) => {
@@ -1210,12 +1180,6 @@ impl Thread {
None
};
let prompt_id = self.last_prompt_id.clone();
let tool_use_metadata = ToolUseMetadata {
model: model.clone(),
thread_id: self.id.clone(),
prompt_id: prompt_id.clone(),
};
let task = cx.spawn(async move |thread, cx| {
let stream_completion_future = model.stream_completion_with_usage(request, &cx);
let initial_token_usage =
@@ -1262,7 +1226,6 @@ impl Thread {
current_token_usage = token_usage;
}
LanguageModelCompletionEvent::Text(chunk) => {
cx.emit(ThreadEvent::ReceivedTextChunk);
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.push_text(&chunk);
@@ -1322,27 +1285,11 @@ impl Thread {
thread.insert_message(Role::Assistant, vec![], cx)
});
let tool_use_id = tool_use.id.clone();
let streamed_input = if tool_use.is_input_complete {
None
} else {
Some((&tool_use.input).clone())
};
let ui_text = thread.tool_use.request_tool_use(
thread.tool_use.request_tool_use(
last_assistant_message_id,
tool_use,
tool_use_metadata.clone(),
cx,
);
if let Some(input) = streamed_input {
cx.emit(ThreadEvent::StreamedToolUse {
tool_use_id,
ui_text,
input,
});
}
}
}
@@ -1812,7 +1759,7 @@ impl Thread {
thread_data,
final_project_snapshot
);
client.telemetry().flush_events().await;
client.telemetry().flush_events();
Ok(())
})
@@ -1857,7 +1804,7 @@ impl Thread {
thread_data,
final_project_snapshot
);
client.telemetry().flush_events().await;
client.telemetry().flush_events();
Ok(())
})
@@ -2113,7 +2060,7 @@ impl Thread {
github_login = github_login
);
client.telemetry().flush_events().await;
client.telemetry().flush_events();
}
}
})
@@ -2231,14 +2178,8 @@ pub enum ThreadEvent {
ShowError(ThreadError),
UsageUpdated(RequestUsage),
StreamedCompletion,
ReceivedTextChunk,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
StreamedToolUse {
tool_use_id: LanguageModelToolUseId,
ui_text: Arc<str>,
input: serde_json::Value,
},
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),

View File

@@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
use project::{Project, Worktree};
use prompt_store::{
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
DefaultUserRulesContext, ProjectContext, PromptBuilder, PromptId, PromptStore,
PromptsUpdatedEvent, RulesFileContext, WorktreeContext,
};
use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
@@ -62,7 +62,6 @@ pub struct ThreadStore {
project: Entity<Project>,
tools: Entity<ToolWorkingSet>,
prompt_builder: Arc<PromptBuilder>,
prompt_store: Option<Entity<PromptStore>>,
context_server_manager: Entity<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
threads: Vec<SerializedThreadMetadata>,
@@ -136,7 +135,6 @@ impl ThreadStore {
let (ready_tx, ready_rx) = oneshot::channel();
let mut ready_tx = Some(ready_tx);
let reload_system_prompt_task = cx.spawn({
let prompt_store = prompt_store.clone();
async move |thread_store, cx| {
loop {
let Some(reload_task) = thread_store
@@ -160,7 +158,6 @@ impl ThreadStore {
project,
tools,
prompt_builder,
prompt_store,
context_server_manager,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
@@ -248,7 +245,7 @@ impl ThreadStore {
let default_user_rules = default_user_rules
.into_iter()
.flat_map(|(contents, prompt_metadata)| match contents {
Ok(contents) => Some(UserRulesContext {
Ok(contents) => Some(DefaultUserRulesContext {
uuid: match prompt_metadata.id {
PromptId::User { uuid } => uuid,
PromptId::EditWorkflow => return None,
@@ -349,27 +346,6 @@ impl ThreadStore {
self.context_server_manager.clone()
}
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
self.prompt_store.clone()
}
pub fn load_rules(
&self,
prompt_id: UserPromptId,
cx: &App,
) -> Task<Result<(PromptMetadata, String)>> {
let prompt_id = PromptId::User { uuid: prompt_id };
let Some(prompt_store) = self.prompt_store.as_ref() else {
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
};
let prompt_store = prompt_store.read(cx);
let Some(metadata) = prompt_store.metadata(prompt_id) else {
return Task::ready(Err(anyhow!("User rules not found in library.")));
};
let text_task = prompt_store.load(prompt_id, cx);
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
}
pub fn tools(&self) -> Entity<ToolWorkingSet> {
self.tools.clone()
}

View File

@@ -7,13 +7,13 @@ use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task};
use language_model::{
LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
};
use ui::IconName;
use util::truncate_lines_to_byte_limit;
use crate::thread::{MessageId, PromptId, ThreadId};
use crate::thread::MessageId;
use crate::thread_store::SerializedMessage;
#[derive(Debug)]
@@ -27,6 +27,8 @@ pub struct ToolUse {
pub needs_confirmation: bool,
}
pub const USING_TOOL_MARKER: &str = "<using_tool>";
pub struct ToolUseState {
tools: Entity<ToolWorkingSet>,
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
@@ -34,7 +36,6 @@ pub struct ToolUseState {
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
tool_use_metadata_by_id: HashMap<LanguageModelToolUseId, ToolUseMetadata>,
}
impl ToolUseState {
@@ -46,7 +47,6 @@ impl ToolUseState {
tool_results: HashMap::default(),
pending_tool_uses_by_id: HashMap::default(),
tool_result_cards: HashMap::default(),
tool_use_metadata_by_id: HashMap::default(),
}
}
@@ -73,7 +73,6 @@ impl ToolUseState {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
input: tool_use.input.clone(),
is_input_complete: true,
})
.collect::<Vec<_>>();
@@ -175,9 +174,6 @@ impl ToolUseState {
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
PendingToolUseStatus::InputStillStreaming => {
ToolUseStatus::InputStillStreaming
}
}
} else {
ToolUseStatus::Pending
@@ -194,12 +190,7 @@ impl ToolUseState {
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
ui_text: self.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
),
ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
input: tool_use.input.clone(),
status,
icon,
@@ -214,15 +205,10 @@ impl ToolUseState {
&self,
tool_name: &str,
input: &serde_json::Value,
is_input_complete: bool,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
if is_input_complete {
tool.ui_text(input).into()
} else {
tool.still_streaming_ui_text(input).into()
}
tool.ui_text(input).into()
} else {
format!("Unknown tool {tool_name:?}").into()
}
@@ -268,52 +254,20 @@ impl ToolUseState {
&mut self,
assistant_message_id: MessageId,
tool_use: LanguageModelToolUse,
metadata: ToolUseMetadata,
cx: &App,
) -> Arc<str> {
let tool_uses = self
.tool_uses_by_assistant_message
) {
self.tool_uses_by_assistant_message
.entry(assistant_message_id)
.or_default();
.or_default()
.push(tool_use.clone());
let mut existing_tool_use_found = false;
for existing_tool_use in tool_uses.iter_mut() {
if existing_tool_use.id == tool_use.id {
*existing_tool_use = tool_use.clone();
existing_tool_use_found = true;
}
}
if !existing_tool_use_found {
tool_uses.push(tool_use.clone());
}
let status = if tool_use.is_input_complete {
self.tool_use_metadata_by_id
.insert(tool_use.id.clone(), metadata);
// The tool use is being requested by the Assistant, so we want to
// attach the tool results to the next user message.
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
self.tool_uses_by_user_message
.entry(next_user_message_id)
.or_default()
.push(tool_use.id.clone());
PendingToolUseStatus::Idle
} else {
PendingToolUseStatus::InputStillStreaming
};
let ui_text: Arc<str> = self
.tool_ui_label(
&tool_use.name,
&tool_use.input,
tool_use.is_input_complete,
cx,
)
.into();
// The tool use is being requested by the Assistant, so we want to
// attach the tool results to the next user message.
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
self.tool_uses_by_user_message
.entry(next_user_message_id)
.or_default()
.push(tool_use.id.clone());
self.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
@@ -321,13 +275,13 @@ impl ToolUseState {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
ui_text: ui_text.clone(),
ui_text: self
.tool_ui_label(&tool_use.name, &tool_use.input, cx)
.into(),
input: tool_use.input,
status,
status: PendingToolUseStatus::Idle,
},
);
ui_text
}
pub fn run_pending_tool(
@@ -373,21 +327,7 @@ impl ToolUseState {
output: Result<String>,
cx: &App,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
telemetry::event!(
"Agent Tool Finished",
model = metadata
.as_ref()
.map(|metadata| metadata.model.telemetry_id()),
model_provider = metadata
.as_ref()
.map(|metadata| metadata.model.provider_id().to_string()),
thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()),
prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()),
tool_name,
success = output.is_ok()
);
telemetry::event!("Agent Tool Finished", tool_name, success = output.is_ok());
match output {
Ok(tool_result) => {
@@ -450,8 +390,28 @@ impl ToolUseState {
request_message: &mut LanguageModelRequestMessage,
) {
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
let mut found_tool_use = false;
for tool_use in tool_uses {
if self.tool_results.contains_key(&tool_use.id) {
if !found_tool_use {
// The API fails if a message contains a tool use without any (non-whitespace) text around it
match request_message.content.last_mut() {
Some(MessageContent::Text(txt)) => {
if txt.is_empty() {
txt.push_str(USING_TOOL_MARKER);
}
}
None | Some(_) => {
request_message
.content
.push(MessageContent::Text(USING_TOOL_MARKER.into()));
}
};
}
found_tool_use = true;
// Do not send tool uses until they are completed
request_message
.content
@@ -517,7 +477,6 @@ pub struct Confirmation {
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
InputStillStreaming,
Idle,
NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },
@@ -537,10 +496,3 @@ impl PendingToolUseStatus {
matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
}
}
#[derive(Clone)]
pub struct ToolUseMetadata {
pub model: Arc<dyn LanguageModel>,
pub thread_id: ThreadId,
pub prompt_id: PromptId,
}

View File

@@ -1,9 +1,7 @@
mod agent_notification;
mod animated_label;
mod context_pill;
mod usage_banner;
pub use agent_notification::*;
pub use animated_label::*;
pub use context_pill::*;
pub use usage_banner::*;

View File

@@ -12,7 +12,6 @@ pub struct AgentNotification {
title: SharedString,
caption: SharedString,
icon: IconName,
project_name: Option<SharedString>,
}
impl AgentNotification {
@@ -20,13 +19,11 @@ impl AgentNotification {
title: impl Into<SharedString>,
caption: impl Into<SharedString>,
icon: IconName,
project_name: Option<impl Into<SharedString>>,
) -> Self {
Self {
title: title.into(),
caption: caption.into(),
icon,
project_name: project_name.map(|name| name.into()),
}
}
@@ -133,34 +130,11 @@ impl Render for AgentNotification {
.child(gradient_overflow()),
)
.child(
h_flex()
div()
.relative()
.gap_1p5()
.text_size(px(12.))
.text_color(cx.theme().colors().text_muted)
.truncate()
.when_some(
self.project_name.clone(),
|description, project_name| {
description.child(
h_flex()
.gap_1p5()
.child(
div()
.max_w_16()
.truncate()
.child(project_name),
)
.child(
div().size(px(3.)).rounded_full().bg(cx
.theme()
.colors()
.text
.opacity(0.5)),
),
)
},
)
.child(self.caption.clone())
.child(gradient_overflow()),
),

View File

@@ -1,116 +0,0 @@
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
use std::time::Duration;
use ui::prelude::*;
#[derive(IntoElement)]
pub struct AnimatedLabel {
base: Label,
text: SharedString,
}
impl AnimatedLabel {
pub fn new(text: impl Into<SharedString>) -> Self {
let text = text.into();
AnimatedLabel {
base: Label::new(text.clone()),
text,
}
}
}
impl LabelCommon for AnimatedLabel {
fn size(mut self, size: LabelSize) -> Self {
self.base = self.base.size(size);
self
}
fn weight(mut self, weight: FontWeight) -> Self {
self.base = self.base.weight(weight);
self
}
fn line_height_style(mut self, line_height_style: LineHeightStyle) -> Self {
self.base = self.base.line_height_style(line_height_style);
self
}
fn color(mut self, color: Color) -> Self {
self.base = self.base.color(color);
self
}
fn strikethrough(mut self) -> Self {
self.base = self.base.strikethrough();
self
}
fn italic(mut self) -> Self {
self.base = self.base.italic();
self
}
fn alpha(mut self, alpha: f32) -> Self {
self.base = self.base.alpha(alpha);
self
}
fn underline(mut self) -> Self {
self.base = self.base.underline();
self
}
fn truncate(mut self) -> Self {
self.base = self.base.truncate();
self
}
fn single_line(mut self) -> Self {
self.base = self.base.single_line();
self
}
fn buffer_font(mut self, cx: &App) -> Self {
self.base = self.base.buffer_font(cx);
self
}
}
impl RenderOnce for AnimatedLabel {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let text = self.text.clone();
self.base
.color(Color::Muted)
.with_animations(
"animated-label",
vec![
Animation::new(Duration::from_secs(1)),
Animation::new(Duration::from_secs(1)).repeat(),
],
move |mut label, animation_ix, delta| {
match animation_ix {
0 => {
let chars_to_show = (delta * text.len() as f32).ceil() as usize;
let text = SharedString::from(text[0..chars_to_show].to_string());
label.set_text(text);
}
1 => match delta {
d if d < 0.25 => label.set_text(text.clone()),
d if d < 0.5 => label.set_text(format!("{}.", text)),
d if d < 0.75 => label.set_text(format!("{}..", text)),
_ => label.set_text(format!("{}...", text)),
},
_ => {}
}
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.map_element(|label| label.alpha(delta)),
)
}
}

View File

@@ -1,14 +1,11 @@
use std::sync::Arc;
use std::{rc::Rc, time::Duration};
use file_icons::FileIcons;
use futures::FutureExt;
use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
use gpui::{ClickEvent, Task};
use language_model::LanguageModelImage;
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
use gpui::ClickEvent;
use gpui::{Animation, AnimationExt as _, pulsating_between};
use ui::{IconButtonShape, Tooltip, prelude::*};
use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
use crate::context::{AssistantContext, ContextId, ContextKind};
#[derive(IntoElement)]
pub enum ContextPill {
@@ -123,100 +120,74 @@ impl RenderOnce for ContextPill {
on_remove,
focused,
on_click,
} => {
let status_is_error = matches!(context.status, ContextStatus::Error { .. });
base_pill
.pr(if on_remove.is_some() { px(2.) } else { px(4.) })
.map(|pill| {
if status_is_error {
pill.bg(cx.theme().status().error_background)
.border_color(cx.theme().status().error_border)
} else if *focused {
pill.bg(color.element_background)
.border_color(color.border_focused)
} else {
pill.bg(color.element_background)
.border_color(color.border.opacity(0.5))
}
})
.child(
h_flex()
.id("context-data")
.gap_1()
.child(
div().max_w_64().child(
Label::new(context.name.clone())
.size(LabelSize::Small)
.truncate(),
),
)
.when_some(context.parent.as_ref(), |element, parent_name| {
if *dupe_name {
element.child(
Label::new(parent_name.clone())
.size(LabelSize::XSmall)
.color(Color::Muted),
)
} else {
element
}
})
.when_some(context.tooltip.as_ref(), |element, tooltip| {
element.tooltip(Tooltip::text(tooltip.clone()))
})
.map(|element| match &context.status {
ContextStatus::Ready => element
.when_some(
context.render_preview.as_ref(),
|element, render_preview| {
element.hoverable_tooltip({
let render_preview = render_preview.clone();
move |_, cx| {
cx.new(|_| ContextPillPreview {
render_preview: render_preview.clone(),
})
.into()
}
})
},
)
.into_any(),
ContextStatus::Loading { message } => element
.tooltip(ui::Tooltip::text(message.clone()))
.with_animation(
"pulsating-ctx-pill",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 0.8)),
|label, delta| label.opacity(delta),
)
.into_any_element(),
ContextStatus::Error { message } => element
.tooltip(ui::Tooltip::text(message.clone()))
.into_any_element(),
} => base_pill
.bg(color.element_background)
.border_color(if *focused {
color.border_focused
} else {
color.border.opacity(0.5)
})
.pr(if on_remove.is_some() { px(2.) } else { px(4.) })
.child(
h_flex()
.id("context-data")
.gap_1()
.child(
div().max_w_64().child(
Label::new(context.name.clone())
.size(LabelSize::Small)
.truncate(),
),
)
.when_some(context.parent.as_ref(), |element, parent_name| {
if *dupe_name {
element.child(
Label::new(parent_name.clone())
.size(LabelSize::XSmall)
.color(Color::Muted),
)
} else {
element
}
})
.when_some(context.tooltip.as_ref(), |element, tooltip| {
element.tooltip(Tooltip::text(tooltip.clone()))
}),
)
.when_some(on_remove.as_ref(), |element, on_remove| {
element.child(
IconButton::new(("remove", context.id.0), IconName::Close)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
)
.when_some(on_remove.as_ref(), |element, on_remove| {
element.child(
IconButton::new(("remove", context.id.0), IconName::Close)
.shape(IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.tooltip(Tooltip::text("Remove Context"))
.on_click({
let on_remove = on_remove.clone();
move |event, window, cx| on_remove(event, window, cx)
}),
)
})
.when_some(on_click.as_ref(), |element, on_click| {
let on_click = on_click.clone();
})
.when_some(on_click.as_ref(), |element, on_click| {
let on_click = on_click.clone();
element
.cursor_pointer()
.on_click(move |event, window, cx| on_click(event, window, cx))
})
.map(|element| {
if context.summarizing {
element
.cursor_pointer()
.on_click(move |event, window, cx| on_click(event, window, cx))
})
.into_any_element()
}
.tooltip(ui::Tooltip::text("Summarizing..."))
.with_animation(
"pulsating-ctx-pill",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.4, 0.8)),
|label, delta| label.opacity(delta),
)
.into_any_element()
} else {
element.into_any()
}
}),
ContextPill::Suggested {
name,
icon_path: _,
@@ -227,15 +198,15 @@ impl RenderOnce for ContextPill {
.cursor_pointer()
.pr_1()
.border_dashed()
.map(|pill| {
if *focused {
pill.border_color(color.border_focused)
.bg(color.element_background.opacity(0.5))
} else {
pill.border_color(color.border)
}
.border_color(if *focused {
color.border_focused
} else {
color.border
})
.hover(|style| style.bg(color.element_hover.opacity(0.5)))
.when(*focused, |this| {
this.bg(color.element_background.opacity(0.5))
})
.child(
div().max_w_64().child(
Label::new(name.clone())
@@ -256,13 +227,6 @@ impl RenderOnce for ContextPill {
}
}
pub enum ContextStatus {
Ready,
Loading { message: SharedString },
Error { message: SharedString },
}
#[derive(RegisterComponent)]
pub struct AddedContext {
pub id: ContextId,
pub kind: ContextKind,
@@ -270,15 +234,14 @@ pub struct AddedContext {
pub parent: Option<SharedString>,
pub tooltip: Option<SharedString>,
pub icon_path: Option<SharedString>,
pub status: ContextStatus,
pub render_preview: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement + 'static>>,
pub summarizing: bool,
}
impl AddedContext {
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
match context {
AssistantContext::File(file_context) => {
let full_path = file_context.context_buffer.full_path(cx);
let full_path = file_context.context_buffer.file.full_path(cx);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -296,20 +259,15 @@ impl AddedContext {
parent,
tooltip: Some(full_path_string),
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: None,
summarizing: false,
}
}
AssistantContext::Directory(directory_context) => {
let worktree = directory_context.worktree.read(cx);
// If the directory no longer exists, use its last known path.
let full_path = worktree
.entry_for_id(directory_context.entry_id)
.map_or_else(
|| directory_context.last_path.clone(),
|entry| worktree.full_path(&entry.path).into(),
);
let full_path = directory_context
.worktree
.read(cx)
.full_path(&directory_context.path);
let full_path_string: SharedString =
full_path.to_string_lossy().into_owned().into();
let name = full_path
@@ -327,8 +285,7 @@ impl AddedContext {
parent,
tooltip: Some(full_path_string),
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
summarizing: false,
}
}
@@ -339,12 +296,11 @@ impl AddedContext {
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
summarizing: false,
},
AssistantContext::Selection(selection_context) => {
let full_path = selection_context.context_buffer.full_path(cx);
AssistantContext::Excerpt(excerpt_context) => {
let full_path = excerpt_context.context_buffer.file.full_path(cx);
let mut full_path_string = full_path.to_string_lossy().into_owned();
let mut name = full_path
.file_name()
@@ -353,8 +309,8 @@ impl AddedContext {
let line_range_text = format!(
" ({}-{})",
selection_context.line_range.start.row + 1,
selection_context.line_range.end.row + 1
excerpt_context.line_range.start.row + 1,
excerpt_context.line_range.end.row + 1
);
full_path_string.push_str(&line_range_text);
@@ -366,25 +322,13 @@ impl AddedContext {
.map(|n| n.to_string_lossy().into_owned().into());
AddedContext {
id: selection_context.id,
kind: ContextKind::Selection,
id: excerpt_context.id,
kind: ContextKind::File, // Use File icon for excerpts
name: name.into(),
parent,
tooltip: None,
tooltip: Some(full_path_string.into()),
icon_path: FileIcons::get_icon(&full_path, cx),
status: ContextStatus::Ready,
render_preview: Some(Rc::new({
let content = selection_context.context_buffer.text.clone();
move |_, cx| {
div()
.id("context-pill-selection-preview")
.overflow_scroll()
.max_w_128()
.max_h_96()
.child(Label::new(content.clone()).buffer_font(cx))
.into_any_element()
}
})),
summarizing: false,
}
}
@@ -395,8 +339,7 @@ impl AddedContext {
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
summarizing: false,
},
AssistantContext::Thread(thread_context) => AddedContext {
@@ -406,143 +349,11 @@ impl AddedContext {
parent: None,
tooltip: None,
icon_path: None,
status: if thread_context
summarizing: thread_context
.thread
.read(cx)
.is_generating_detailed_summary()
{
ContextStatus::Loading {
message: "Summarizing…".into(),
}
} else {
ContextStatus::Ready
},
render_preview: None,
},
AssistantContext::Rules(user_rules_context) => AddedContext {
id: user_rules_context.id,
kind: ContextKind::Rules,
name: user_rules_context.title.clone(),
parent: None,
tooltip: None,
icon_path: None,
status: ContextStatus::Ready,
render_preview: None,
},
AssistantContext::Image(image_context) => AddedContext {
id: image_context.id,
kind: ContextKind::Image,
name: "Image".into(),
parent: None,
tooltip: None,
icon_path: None,
status: if image_context.is_loading() {
ContextStatus::Loading {
message: "Loading…".into(),
}
} else if image_context.is_error() {
ContextStatus::Error {
message: "Failed to load image".into(),
}
} else {
ContextStatus::Ready
},
render_preview: Some(Rc::new({
let image = image_context.original_image.clone();
move |_, _| {
gpui::img(image.clone())
.max_w_96()
.max_h_96()
.into_any_element()
}
})),
.is_generating_detailed_summary(),
},
}
}
}
struct ContextPillPreview {
render_preview: Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>,
}
impl Render for ContextPillPreview {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
tooltip_container(window, cx, move |this, window, cx| {
this.occlude()
.on_mouse_move(|_, _, cx| cx.stop_propagation())
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
.child((self.render_preview)(window, cx))
})
}
}
impl Component for AddedContext {
fn scope() -> ComponentScope {
ComponentScope::Agent
}
fn sort_name() -> &'static str {
"AddedContext"
}
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
let image_ready = (
"Ready",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(0),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
}),
cx,
),
);
let image_loading = (
"Loading",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(1),
original_image: Arc::new(Image::empty()),
image_task: cx
.background_spawn(async move {
smol::Timer::after(Duration::from_secs(60 * 5)).await;
Some(LanguageModelImage::empty())
})
.shared(),
}),
cx,
),
);
let image_error = (
"Error",
AddedContext::new(
&AssistantContext::Image(ImageContext {
id: ContextId(2),
original_image: Arc::new(Image::empty()),
image_task: Task::ready(None).shared(),
}),
cx,
),
);
Some(
v_flex()
.gap_6()
.children(
vec![image_ready, image_loading, image_error]
.into_iter()
.map(|(text, context)| {
single_example(
text,
ContextPill::added(context, false, false, None).into_any_element(),
)
}),
)
.into_any(),
)
}
}

View File

@@ -1,30 +1,31 @@
use client::zed_urls;
use language_model::RequestUsage;
use ui::{Banner, ProgressBar, Severity, prelude::*};
use zed_llm_client::{Plan, UsageLimit};
#[derive(IntoElement, RegisterComponent)]
pub struct UsageBanner {
plan: Plan,
usage: RequestUsage,
requests: i32,
}
impl UsageBanner {
pub fn new(plan: Plan, usage: RequestUsage) -> Self {
Self { plan, usage }
pub fn new(plan: Plan, requests: i32) -> Self {
Self { plan, requests }
}
}
impl RenderOnce for UsageBanner {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let used_percentage = match self.usage.limit {
UsageLimit::Limited(limit) => Some((self.usage.amount as f32 / limit as f32) * 100.),
let request_limit = self.plan.model_requests_limit();
let used_percentage = match request_limit {
UsageLimit::Limited(limit) => Some((self.requests as f32 / limit as f32) * 100.),
UsageLimit::Unlimited => None,
};
let (severity, message) = match self.usage.limit {
let (severity, message) = match request_limit {
UsageLimit::Limited(limit) => {
if self.usage.amount >= limit {
if self.requests >= limit {
let message = match self.plan {
Plan::ZedPro => "Monthly request limit reached",
Plan::ZedProTrial => "Trial request limit reached",
@@ -32,7 +33,7 @@ impl RenderOnce for UsageBanner {
};
(Severity::Error, message)
} else if (self.usage.amount as f32 / limit as f32) >= 0.9 {
} else if (self.requests as f32 / limit as f32) >= 0.9 {
(Severity::Warning, "Approaching request limit")
} else {
let message = match self.plan {
@@ -80,11 +81,11 @@ impl RenderOnce for UsageBanner {
.child(ProgressBar::new("usage", percent, 100., cx))
}))
.child(
Label::new(match self.usage.limit {
Label::new(match request_limit {
UsageLimit::Limited(limit) => {
format!("{} / {limit}", self.usage.amount)
format!("{} / {limit}", self.requests)
}
UsageLimit::Unlimited => format!("{} / ∞", self.usage.amount),
UsageLimit::Unlimited => format!("{} / ∞", self.requests),
})
.size(LabelSize::Small)
.color(Color::Muted),
@@ -103,131 +104,74 @@ impl Component for UsageBanner {
}
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
let trial_limit = Plan::ZedProTrial.model_requests_limit();
let trial_examples = vec![
single_example(
"Zed Pro Trial - New User",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedProTrial,
RequestUsage {
limit: trial_limit,
amount: 10,
},
))
.child(UsageBanner::new(Plan::ZedProTrial, 10))
.into_any_element(),
),
single_example(
"Zed Pro Trial - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedProTrial,
RequestUsage {
limit: trial_limit,
amount: 135,
},
))
.child(UsageBanner::new(Plan::ZedProTrial, 135))
.into_any_element(),
),
single_example(
"Zed Pro Trial - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedProTrial,
RequestUsage {
limit: trial_limit,
amount: 150,
},
))
.child(UsageBanner::new(Plan::ZedProTrial, 150))
.into_any_element(),
),
];
let free_limit = Plan::Free.model_requests_limit();
let free_examples = vec![
single_example(
"Free - Normal Usage",
div()
.size_full()
.child(UsageBanner::new(
Plan::Free,
RequestUsage {
limit: free_limit,
amount: 25,
},
))
.child(UsageBanner::new(Plan::Free, 25))
.into_any_element(),
),
single_example(
"Free - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(
Plan::Free,
RequestUsage {
limit: free_limit,
amount: 45,
},
))
.child(UsageBanner::new(Plan::Free, 45))
.into_any_element(),
),
single_example(
"Free - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(
Plan::Free,
RequestUsage {
limit: free_limit,
amount: 50,
},
))
.child(UsageBanner::new(Plan::Free, 50))
.into_any_element(),
),
];
let zed_pro_limit = Plan::ZedPro.model_requests_limit();
let zed_pro_examples = vec![
single_example(
"Zed Pro - Normal Usage",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedPro,
RequestUsage {
limit: zed_pro_limit,
amount: 250,
},
))
.child(UsageBanner::new(Plan::ZedPro, 250))
.into_any_element(),
),
single_example(
"Zed Pro - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedPro,
RequestUsage {
limit: zed_pro_limit,
amount: 450,
},
))
.child(UsageBanner::new(Plan::ZedPro, 450))
.into_any_element(),
),
single_example(
"Zed Pro - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedPro,
RequestUsage {
limit: zed_pro_limit,
amount: 500,
},
))
.child(UsageBanner::new(Plan::ZedPro, 500))
.into_any_element(),
),
];

45
crates/agent2/Cargo.toml Normal file
View File

@@ -0,0 +1,45 @@
[package]
name = "agent2"
version = "0.1.0"
edition = "2021"
license = "GPL-3.0-or-later"
publish = false
[lib]
path = "src/agent2.rs"
[lints]
workspace = true
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
chrono.workspace = true
collections.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
language_models.workspace = true
parking_lot.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
thiserror.workspace = true
util.workspace = true
[dev-dependencies]
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }
gpui = { workspace = true, "features" = ["test-support"] }
gpui_tokio.workspace = true
language_model = { workspace = true, "features" = ["test-support"] }
project = { workspace = true, "features" = ["test-support"] }
reqwest_client.workspace = true
settings = { workspace = true, "features" = ["test-support"] }

1
crates/agent2/LICENSE-GPL Symbolic link
View File

@@ -0,0 +1 @@
../../LICENSE-GPL

278
crates/agent2/src/agent2.rs Normal file
View File

@@ -0,0 +1,278 @@
#[cfg(test)]
mod tests;
use anyhow::Result;
use assistant_tool::{ActionLog, Tool};
use futures::{channel::mpsc, future};
use gpui::{Context, Entity, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolSchemaFormat,
LanguageModelToolUse, MessageContent, Role, StopReason,
};
use project::Project;
use smol::stream::StreamExt;
use std::{collections::BTreeMap, sync::Arc};
use util::ResultExt;
#[derive(Debug)]
pub struct AgentMessage {
pub role: Role,
pub content: Vec<MessageContent>,
}
impl AgentMessage {
fn to_request_message(&self) -> LanguageModelRequestMessage {
LanguageModelRequestMessage {
role: self.role,
content: self.content.clone(),
cache: false, // TODO: Figure out caching
}
}
}
pub type AgentResponseEvent = LanguageModelCompletionEvent;
pub struct Agent {
messages: Vec<AgentMessage>,
/// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and
/// we run tools, report their results.
running_turn: Option<Task<()>>,
tools: BTreeMap<Arc<str>, Arc<dyn Tool>>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
impl Agent {
pub fn new(project: Entity<Project>, action_log: Entity<ActionLog>) -> Self {
Self {
messages: Vec::new(),
running_turn: None,
tools: BTreeMap::default(),
project,
action_log,
}
}
pub fn add_tool(&mut self, tool: Arc<dyn Tool>) {
let name = Arc::from(tool.name());
self.tools.insert(name, tool);
}
pub fn remove_tool(&mut self, name: &str) -> bool {
self.tools.remove(name).is_some()
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send(
&mut self,
model: Arc<dyn LanguageModel>,
content: impl Into<MessageContent>,
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
cx.notify();
let (events_tx, events_rx) = mpsc::unbounded();
self.messages.push(AgentMessage {
role: Role::User,
content: vec![content.into()],
});
self.running_turn = Some(cx.spawn(async move |thread, cx| {
let turn_result = async {
// Perform one request, then keep looping if the model makes tool calls.
loop {
let request =
thread.update(cx, |thread, _cx| thread.build_completion_request())?;
println!(
"request: {}",
serde_json::to_string_pretty(&request).unwrap()
);
// Stream events, appending to messages and collecting up tool uses.
let mut events = model.stream_completion(request, cx).await?;
let mut tool_uses = Vec::new();
while let Some(event) = events.next().await {
match event {
Ok(event) => {
thread
.update(cx, |thread, cx| {
tool_uses.extend(thread.handle_response_event(
event,
events_tx.clone(),
cx,
));
})
.ok();
}
Err(error) => {
events_tx.unbounded_send(Err(error)).ok();
break;
}
}
}
// If there are no tool uses, the turn is done.
if tool_uses.is_empty() {
break;
}
// If there are tool uses, wait for their results to be
// computed, then send them together in a single message on
// the next loop iteration.
let tool_results = future::join_all(tool_uses).await;
thread
.update(cx, |thread, _cx| {
thread.messages.push(AgentMessage {
role: Role::User,
content: tool_results.into_iter().map(Into::into).collect(),
});
})
.ok();
}
anyhow::Ok(())
}
.await;
if let Err(error) = turn_result {
events_tx.unbounded_send(Err(error)).ok();
}
}));
events_rx
}
fn handle_response_event(
&mut self,
event: LanguageModelCompletionEvent,
events_tx: mpsc::UnboundedSender<Result<AgentResponseEvent>>,
cx: &mut Context<Self>,
) -> Option<Task<LanguageModelToolResult>> {
use LanguageModelCompletionEvent::*;
events_tx.unbounded_send(Ok(event.clone())).ok();
match event {
Text(new_text) => self.handle_text_event(new_text, cx),
Thinking { text, signature } => {}
ToolUse(tool_use) => {
return Some(self.handle_tool_use_event(tool_use, cx));
}
StartMessage { message_id, role } => {
self.messages.push(AgentMessage {
role,
content: Vec::new(),
});
}
UsageUpdate(token_usage) => {}
Stop(stop_reason) => self.handle_stop_event(stop_reason),
}
None
}
fn handle_stop_event(&mut self, stop_reason: StopReason) {
match stop_reason {
StopReason::EndTurn | StopReason::ToolUse => {}
StopReason::MaxTokens => todo!(),
}
}
fn handle_text_event(&mut self, new_text: String, cx: &mut Context<Self>) {
if let Some(last_message) = self.messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
text.push_str(&new_text);
} else {
last_message.content.push(MessageContent::Text(new_text));
}
cx.notify();
} else {
todo!("does this happen in practice?");
}
}
fn handle_tool_use_event(
&mut self,
tool_use: LanguageModelToolUse,
cx: &mut Context<Self>,
) -> Task<LanguageModelToolResult> {
if let Some(last_message) = self.messages.last_mut() {
debug_assert!(last_message.role == Role::Assistant);
last_message.content.push(tool_use.clone().into());
cx.notify();
} else {
todo!("does this happen in practice?");
}
if let Some(tool) = self.tools.get(&tool_use.name) {
let pending_tool_result = tool.clone().run(
tool_use.input,
&self.build_request_messages(),
self.project.clone(),
self.action_log.clone(),
cx,
);
cx.foreground_executor().spawn(async move {
match pending_tool_result.output.await {
Ok(tool_output) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: false,
content: Arc::from(tool_output),
},
Err(error) => LanguageModelToolResult {
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
content: Arc::from(error.to_string()),
},
}
})
} else {
Task::ready(LanguageModelToolResult {
content: Arc::from(format!("No tool named {} exists", tool_use.name)),
tool_use_id: tool_use.id,
tool_name: tool_use.name,
is_error: true,
})
}
}
fn build_completion_request(&self) -> LanguageModelRequest {
LanguageModelRequest {
thread_id: None,
prompt_id: None,
messages: self.build_request_messages(),
tools: self
.tools
.values()
.filter_map(|tool| {
Some(LanguageModelRequestTool {
name: tool.name(),
description: tool.description(),
input_schema: tool
.input_schema(LanguageModelToolSchemaFormat::JsonSchema)
.log_err()?,
})
})
.collect(),
stop: Vec::new(),
temperature: None,
}
}
fn build_request_messages(&self) -> Vec<LanguageModelRequestMessage> {
self.messages
.iter()
.map(|message| LanguageModelRequestMessage {
role: message.role,
content: message.content.clone(),
cache: false,
})
.collect()
}
}

191
crates/agent2/src/tests.rs Normal file
View File

@@ -0,0 +1,191 @@
use super::*;
use assistant_tool::{IconName, Project, ToolResult};
use client::{Client, UserStore};
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::time::Duration;
mod tools;
use tools::*;
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
let AgentTest { model, agent, .. } = agent_test(cx).await;
let events = agent
.update(cx, |agent, cx| {
agent.send(model.clone(), "Testing: Reply with 'Hello'", cx)
})
.collect()
.await;
agent.update(cx, |agent, _cx| {
assert_eq!(
agent.messages.last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![StopReason::EndTurn]);
}
#[gpui::test]
async fn test_tool_calls(cx: &mut TestAppContext) {
let AgentTest { model, agent, .. } = agent_test(cx).await;
// Test a tool calls that's likely to complete before streaming stops.
let events = agent
.update(cx, |agent, cx| {
agent.add_tool(Arc::new(EchoTool));
agent.send(
model.clone(),
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
cx,
)
})
.collect()
.await;
assert_eq!(
stop_events(events),
vec![StopReason::ToolUse, StopReason::EndTurn]
);
// Test a tool calls that's likely to complete after streaming stops.
let events = agent
.update(cx, |agent, cx| {
agent.remove_tool(&EchoTool.name());
agent.add_tool(Arc::new(DelayTool));
agent.send(
model.clone(),
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
cx,
)
})
.collect()
.await;
assert_eq!(
stop_events(events),
vec![StopReason::ToolUse, StopReason::EndTurn]
);
agent.update(cx, |agent, _cx| {
assert!(agent
.messages
.last()
.unwrap()
.content
.iter()
.any(|content| {
if let MessageContent::Text(text) = content {
text.contains("Ding")
} else {
false
}
}));
});
}
#[gpui::test]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
let AgentTest { model, agent, .. } = agent_test(cx).await;
// Test concurrent tool calls with different delay times
let events = agent
.update(cx, |agent, cx| {
agent.add_tool(Arc::new(DelayTool));
agent.send(
model.clone(),
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
cx,
)
})
.map(|event| dbg!(event))
.collect()
.await;
let stop_reasons = stop_events(events);
assert_eq!(stop_reasons, vec![StopReason::ToolUse, StopReason::EndTurn]);
agent.update(cx, |agent, _cx| {
let last_message = agent.messages.last().unwrap();
let text = last_message
.content
.iter()
.filter_map(|content| {
if let MessageContent::Text(text) = content {
Some(text.as_str())
} else {
None
}
})
.collect::<String>();
assert!(text.contains("Ding"));
});
}
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
LanguageModelCompletionEvent::Stop(stop_reason) => Some(stop_reason),
_ => None,
})
.collect()
}
struct AgentTest {
model: Arc<dyn LanguageModel>,
agent: Entity<Agent>,
}
async fn agent_test(cx: &mut TestAppContext) -> AgentTest {
cx.executor().allow_parking();
cx.update(settings::init);
let fs = FakeFs::new(cx.executor().clone());
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let agent = cx.new(|_| Agent::new(project.clone(), action_log.clone()));
let model = cx
.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
let models = LanguageModelRegistry::read_global(cx);
let model = models
.available_models(cx)
.find(|model| model.id().0 == "claude-3-7-sonnet-latest")
.unwrap();
let provider = models.provider(&model.provider_id()).unwrap();
let authenticated = provider.authenticate(cx);
cx.spawn(async move |cx| {
authenticated.await.unwrap();
model
})
})
.await;
AgentTest {
model,
agent: agent,
}
}
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
if std::env::var("RUST_LOG").is_ok() {
env_logger::init();
}
}

View File

@@ -0,0 +1,102 @@
use super::*;
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
text: String,
}
pub struct EchoTool;
impl Tool for EchoTool {
fn name(&self) -> String {
"echo".to_string()
}
fn description(&self) -> String {
"A tool that echoes its input".to_string()
}
fn icon(&self) -> IconName {
IconName::Ai
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &gpui::App) -> bool {
false
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Echo".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: gpui::Entity<Project>,
_action_log: gpui::Entity<assistant_tool::ActionLog>,
cx: &mut gpui::App,
) -> ToolResult {
ToolResult {
output: cx.foreground_executor().spawn(async move {
let input: EchoToolInput = serde_json::from_value(input)?;
Ok(input.text)
}),
card: None,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
assistant_tools::json_schema_for::<EchoToolInput>(format)
}
}
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct DelayToolInput {
ms: u64,
}
pub struct DelayTool;
impl Tool for DelayTool {
fn name(&self) -> String {
"delay".to_string()
}
fn description(&self) -> String {
"A tool that waits for a specified delay".to_string()
}
fn icon(&self) -> IconName {
IconName::Cog
}
fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &gpui::App) -> bool {
false
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Delay".to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
_project: gpui::Entity<Project>,
_action_log: gpui::Entity<assistant_tool::ActionLog>,
cx: &mut gpui::App,
) -> ToolResult {
ToolResult {
output: cx.foreground_executor().spawn(async move {
let input: DelayToolInput = serde_json::from_value(input)?;
smol::Timer::after(Duration::from_millis(input.ms)).await;
Ok("Ding".to_string())
}),
card: None,
}
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
assistant_tools::json_schema_for::<DelayToolInput>(format)
}
}

View File

@@ -416,6 +416,7 @@ pub async fn stream_completion_with_rate_limit_info(
let beta_headers = Model::from_id(&request.base.model)
.map(|model| model.beta_headers())
.unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
@@ -423,6 +424,7 @@ pub async fn stream_completion_with_rate_limit_info(
.header("Anthropic-Beta", beta_headers)
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
let serialized_request =
serde_json::to_string(&request).context("failed to serialize request")?;
let request = request_builder

View File

@@ -27,7 +27,7 @@ use language_model::{
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
use prompt_store::{PromptBuilder, PromptId};
use search::{BufferSearchBar, buffer_search::DivRegistrar};
use settings::{Settings, update_settings_file};
@@ -58,11 +58,11 @@ pub fn init(cx: &mut App) {
.register_action(AssistantPanel::show_configuration)
.register_action(AssistantPanel::create_new_context)
.register_action(AssistantPanel::restart_context_servers)
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
.register_action(|workspace, _: &OpenPromptLibrary, window, cx| {
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
workspace.focus_panel::<AssistantPanel>(window, cx);
panel.update(cx, |panel, cx| {
panel.deploy_prompt_library(action, window, cx)
panel.deploy_prompt_library(&OpenPromptLibrary::default(), window, cx)
});
}
});
@@ -1060,9 +1060,7 @@ impl AssistantPanel {
None,
))
}),
action.prompt_to_select.map(|uuid| PromptId::User {
uuid: UserPromptId(uuid),
}),
action.prompt_to_focus.map(|uuid| PromptId::User { uuid }),
cx,
)
.detach_and_log_err(cx);

View File

@@ -37,7 +37,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::{CodeAction, LspAction, ProjectTransaction};
@@ -1766,7 +1766,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View File

@@ -19,7 +19,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use prompt_store::PromptBuilder;
use settings::{Settings, update_settings_file};
use std::{
@@ -755,7 +755,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)

View File

@@ -39,7 +39,7 @@ use language_model::{
Role,
};
use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
use multi_buffer::MultiBufferRow;
use picker::Picker;
@@ -298,7 +298,6 @@ impl ContextEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
ModelType::Default,
window,
cx,
)
@@ -2089,7 +2088,7 @@ impl ContextEditor {
continue;
};
let image_id = image.id();
let image_task = LanguageModelImage::from_image(Arc::new(image), cx).shared();
let image_task = LanguageModelImage::from_image(image, cx).shared();
for image_position in image_positions.iter() {
context.insert_content(

View File

@@ -44,10 +44,9 @@ impl SlashCommand for PromptSlashCommand {
let store = PromptStore::global(cx);
let query = arguments.to_owned().join(" ");
cx.spawn(async move |cx| {
let cancellation_flag = Arc::new(AtomicBool::default());
let prompts: Vec<PromptMetadata> = store
.await?
.read_with(cx, |store, cx| store.search(query, cancellation_flag, cx))?
.read_with(cx, |store, cx| store.search(query, cx))?
.await;
Ok(prompts
.into_iter()

View File

@@ -39,9 +39,10 @@ impl ActionLog {
self.edited_since_project_diagnostics_check
}
fn track_buffer_internal(
fn track_buffer(
&mut self,
buffer: Entity<Buffer>,
created: bool,
cx: &mut Context<Self>,
) -> &mut TrackedBuffer {
let tracked_buffer = self
@@ -58,11 +59,7 @@ impl ActionLog {
let base_text;
let status;
let unreviewed_changes;
if buffer
.read(cx)
.file()
.map_or(true, |file| !file.disk_state().exists())
{
if created {
base_text = Rope::default();
status = TrackedBufferStatus::Created;
unreviewed_changes = Patch::new(vec![Edit {
@@ -149,7 +146,7 @@ impl ActionLog {
// resurrected externally, we want to clear the changes we
// were tracking and reset the buffer's state.
self.tracked_buffers.remove(&buffer);
self.track_buffer_internal(buffer, cx);
self.track_buffer(buffer, false, cx);
}
cx.notify();
}
@@ -263,15 +260,26 @@ impl ActionLog {
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn track_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer_internal(buffer, cx);
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer, false, cx);
}
/// Track a buffer that was added as context, so we can notify the model about user edits.
pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer, false, cx);
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer.clone(), true, cx);
self.buffer_edited(buffer, cx)
}
/// Mark a buffer as edited, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.edited_since_project_diagnostics_check = true;
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
tracked_buffer.status = TrackedBufferStatus::Modified;
}
@@ -279,7 +287,7 @@ impl ActionLog {
}
pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
match tracked_buffer.status {
TrackedBufferStatus::Created => {
self.tracked_buffers.remove(&buffer);
@@ -389,7 +397,7 @@ impl ActionLog {
// Clear all tracked changes for this buffer and start over as if we just read it.
self.tracked_buffers.remove(&buffer);
self.track_buffer_internal(buffer.clone(), cx);
self.track_buffer(buffer.clone(), false, cx);
cx.notify();
save
}
@@ -687,20 +695,12 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E")], None, cx)
@@ -765,23 +765,12 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/dir"),
json!({"file": "abc\ndef\nghi\njkl\nmno\npqr"}),
)
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 0)..Point::new(2, 0), "")], None, cx)
@@ -850,20 +839,12 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx)
@@ -947,21 +928,25 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
.unwrap();
// Simulate file2 being recreated by a tool.
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.set_text("lorem", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
@@ -1082,9 +1067,8 @@ mod tests {
.update(cx, |project, cx| project.open_buffer(file2_path, cx))
.await
.unwrap();
action_log.update(cx, |log, cx| log.track_buffer(buffer2.clone(), cx));
buffer2.update(cx, |buffer, cx| buffer.set_text("IPSUM", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer2.clone(), cx));
action_log.update(cx, |log, cx| log.will_create_buffer(buffer2.clone(), cx));
project
.update(cx, |project, cx| project.save_buffer(buffer2.clone(), cx))
.await
@@ -1129,7 +1113,7 @@ mod tests {
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
@@ -1264,7 +1248,7 @@ mod tests {
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
@@ -1397,9 +1381,8 @@ mod tests {
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.set_text("content", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
@@ -1455,7 +1438,7 @@ mod tests {
.await
.unwrap();
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
for _ in 0..operations {
match rng.gen_range(0..100) {
@@ -1507,7 +1490,7 @@ mod tests {
log::info!("quiescing...");
cx.run_until_parked();
action_log.update(cx, |log, cx| {
let tracked_buffer = log.track_buffer_internal(buffer.clone(), cx);
let tracked_buffer = log.track_buffer(buffer.clone(), false, cx);
let mut old_text = tracked_buffer.base_text.clone();
let new_text = buffer.read(cx).as_rope();
for edit in tracked_buffer.unreviewed_changes.edits() {

View File

@@ -14,10 +14,10 @@ use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
pub use icons::IconName;
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
pub use project::Project;
pub use crate::action_log::*;
pub use crate::tool_registry::*;
@@ -30,7 +30,6 @@ pub fn init(cx: &mut App) {
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
InputStillStreaming,
NeedsConfirmation,
Pending,
Running,
@@ -42,7 +41,6 @@ impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
ToolUseStatus::InputStillStreaming => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
@@ -150,12 +148,6 @@ pub trait Tool: 'static + Send + Sync {
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &serde_json::Value) -> String;
/// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
/// (so information may be missing).
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
self.ui_text(input)
}
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,

View File

@@ -22,7 +22,6 @@ mod schema;
mod symbol_info_tool;
mod terminal_tool;
mod thinking_tool;
mod ui;
mod web_search_tool;
use std::sync::Arc;
@@ -55,8 +54,7 @@ use crate::rename_tool::RenameTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub use path_search_tool::PathSearchToolInput;
pub use schema::json_schema_for;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);

View File

@@ -159,7 +159,7 @@ impl Tool for CodeActionTool {
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
let range = {

View File

@@ -174,7 +174,7 @@ pub async fn file_outline(
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
// Wait until the buffer has been fully parsed, so that we can read its outline.

View File

@@ -209,7 +209,7 @@ impl Tool for ContentsTool {
})?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)
@@ -221,7 +221,7 @@ impl Tool for ContentsTool {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)

View File

@@ -33,18 +33,8 @@ pub struct CreateFileToolInput {
pub contents: String,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
contents: String,
}
pub struct CreateFileTool;
const DEFAULT_UI_TEXT: &str = "Create file";
impl Tool for CreateFileTool {
fn name(&self) -> String {
"create_file".into()
@@ -72,14 +62,7 @@ impl Tool for CreateFileTool {
let path = MarkdownString::inline_code(&input.path);
format!("Create file {path}")
}
Err(_) => DEFAULT_UI_TEXT.to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<PartialInput>(input.clone()).ok() {
Some(input) if !input.path.is_empty() => input.path,
_ => DEFAULT_UI_TEXT.to_string(),
Err(_) => "Create file".to_string(),
}
}
@@ -112,12 +95,9 @@ impl Tool for CreateFileTool {
.await
.map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
cx.update(|cx| {
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx)
});
buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx)
action_log.will_create_buffer(buffer.clone(), cx)
});
})?;
@@ -131,60 +111,3 @@ impl Tool for CreateFileTool {
.into()
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let tool = CreateFileTool;
let input = json!({
"path": "src/main.rs",
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
});
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_without_path() {
let tool = CreateFileTool;
let input = json!({
"path": "",
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
});
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
#[test]
fn still_streaming_ui_text_with_null() {
let tool = CreateFileTool;
let input = serde_json::Value::Null;
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
#[test]
fn ui_text_with_valid_input() {
let tool = CreateFileTool;
let input = json!({
"path": "src/main.rs",
"contents": "fn main() {\n println!(\"Hello, world!\");\n}"
});
assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
}
#[test]
fn ui_text_with_invalid_input() {
let tool = CreateFileTool;
let input = json!({
"invalid": "field"
});
assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
}
}

View File

@@ -47,22 +47,8 @@ pub struct EditFileToolInput {
pub new_string: String,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
#[serde(default)]
path: String,
#[serde(default)]
display_description: String,
#[serde(default)]
old_string: String,
#[serde(default)]
new_string: String,
}
pub struct EditFileTool;
const DEFAULT_UI_TEXT: &str = "Editing file";
impl Tool for EditFileTool {
fn name(&self) -> String {
"edit_file".into()
@@ -87,26 +73,10 @@ impl Tool for EditFileTool {
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
Err(_) => "Edit file".to_string(),
}
}
fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
let description = input.display_description.trim();
if !description.is_empty() {
return description.to_string();
}
let path = input.path.trim();
if !path.is_empty() {
return path.to_string();
}
}
DEFAULT_UI_TEXT.to_string()
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
@@ -182,7 +152,7 @@ impl Tool for EditFileTool {
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| {
log.track_buffer(buffer.clone(), cx)
log.buffer_read(buffer.clone(), cx)
});
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
@@ -211,69 +181,3 @@ impl Tool for EditFileTool {
}).into()
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn still_streaming_ui_text_with_path() {
let tool = EditFileTool;
let input = json!({
"path": "src/main.rs",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
}
#[test]
fn still_streaming_ui_text_with_description() {
let tool = EditFileTool;
let input = json!({
"path": "",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
}
#[test]
fn still_streaming_ui_text_with_path_and_description() {
let tool = EditFileTool;
let input = json!({
"path": "src/main.rs",
"display_description": "Fix error handling",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
}
#[test]
fn still_streaming_ui_text_no_path_or_description() {
let tool = EditFileTool;
let input = json!({
"path": "",
"display_description": "",
"old_string": "old code",
"new_string": "new code"
});
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
#[test]
fn still_streaming_ui_text_with_null() {
let tool = EditFileTool;
let input = serde_json::Value::Null;
assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
}
}

View File

@@ -0,0 +1,268 @@
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use crate::replace::replace_exact;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindReplaceFileToolInput {
/// The path of the file to modify.
///
/// WARNING: When specifying which file path need changing, you MUST
/// start each path with one of the project's root directories.
///
/// The following examples assume we have two root directories in the project:
/// - backend
/// - frontend
///
/// <example>
/// `backend/src/main.rs`
///
/// Notice how the file path starts with root-1. Without that, the path
/// would be ambiguous and the call would fail!
/// </example>
///
/// <example>
/// `frontend/db.js`
/// </example>
pub path: PathBuf,
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
pub display_description: String,
/// The unique string to find in the file. This string cannot be empty;
/// if the string is empty, the tool call will fail. Remember, do not use this tool
/// to create new files from scratch, or to overwrite existing files! Use a different
/// approach if you want to do that.
///
/// If this string appears more than once in the file, this tool call will fail,
/// so it is absolutely critical that you verify ahead of time that the string
/// is unique. You can search within the file to verify this.
///
/// To make the string more likely to be unique, include a minimum of 3 lines of context
/// before the string you actually want to find, as well as a minimum of 3 lines of
/// context after the string you want to find. (These lines of context should appear
/// in the `replace` string as well.) If 3 lines of context is not enough to obtain
/// a string that appears only once in the file, then double the number of context lines
/// until the string becomes unique. (Start with 3 lines before and 3 lines after
/// though, because too much context is needlessly costly.)
///
/// Do not alter the context lines of code in any way, and make sure to preserve all
/// whitespace and indentation for all lines of code. This string must be exactly as
/// it appears in the file, because this tool will do a literal find/replace, and if
/// even one character in this string is different in any way from how it appears
/// in the file, then the tool call will fail.
///
/// If you get an error that the `find` string was not found, this means that either
/// you made a mistake, or that the file has changed since you last looked at it.
/// Either way, when this happens, you should retry doing this tool call until it
/// succeeds, up to 3 times. Each time you retry, you should take another look at
/// the exact text of the file in question, to make sure that you are searching for
/// exactly the right string. Regardless of whether it was because you made a mistake
/// or because the file changed since you last looked at it, you should be extra
/// careful when retrying in this way. It's a bad experience for the user if
/// this `find` string isn't found, so be super careful to get it exactly right!
///
/// <example>
/// If a file contains this code:
///
/// ```ignore
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
/// // Check if user exists first
/// let user = database.find_user(user_id)?;
///
/// // This is the part we want to modify
/// if user.role == "admin" {
/// return Ok(true);
/// }
///
/// // Check other permissions
/// check_custom_permissions(user_id)
/// }
/// ```
///
/// Your find string should include at least 3 lines of context before and after the part
/// you want to change:
///
/// ```ignore
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
/// // Check if user exists first
/// let user = database.find_user(user_id)?;
///
/// // This is the part we want to modify
/// if user.role == "admin" {
/// return Ok(true);
/// }
///
/// // Check other permissions
/// check_custom_permissions(user_id)
/// }
/// ```
///
/// And your replace string might look like:
///
/// ```ignore
/// fn check_user_permissions(user_id: &str) -> Result<bool> {
/// // Check if user exists first
/// let user = database.find_user(user_id)?;
///
/// // This is the part we want to modify
/// if user.role == "admin" || user.role == "superuser" {
/// return Ok(true);
/// }
///
/// // Check other permissions
/// check_custom_permissions(user_id)
/// }
/// ```
/// </example>
pub find: String,
/// The string to replace the one unique occurrence of the find string with.
pub replace: String,
}
pub struct FindReplaceFileTool;
impl Tool for FindReplaceFileTool {
fn name(&self) -> String {
"find_replace_file".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
false
}
fn description(&self) -> String {
include_str!("find_replace_tool/description.md").to_string()
}
fn icon(&self) -> IconName {
IconName::Pencil
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FindReplaceFileToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FindReplaceFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Edit file".to_string(),
}
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FindReplaceFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
cx.spawn(async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
.context("Path not found in project")
})??;
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.find.is_empty() {
return Err(anyhow!("`find` string cannot be empty. Use a different tool if you want to create a file."));
}
if input.find == input.replace {
return Err(anyhow!("The `find` and `replace` strings are identical, so no changes would be made."));
}
let result = cx
.background_spawn(async move {
// Try to match exactly
let diff = replace_exact(&input.find, &input.replace, &snapshot)
.await
// If that fails, try being flexible about indentation
.or_else(|| replace_with_flexible_indent(&input.find, &input.replace, &snapshot))?;
if diff.edits.is_empty() {
return None;
}
let old_text = snapshot.text();
Some((old_text, diff))
})
.await;
let Some((old_text, diff)) = result else {
let err = buffer.read_with(cx, |buffer, _cx| {
let file_exists = buffer
.file()
.map_or(false, |file| file.disk_state().exists());
if !file_exists {
anyhow!("{} does not exist", input.path.display())
} else if buffer.is_empty() {
anyhow!(
"{} is empty, so the provided `find` string wasn't found.",
input.path.display()
)
} else {
anyhow!("Failed to match the provided `find` string")
}
})?;
return Err(err)
};
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx)
});
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
buffer.finalize_last_transaction();
buffer.snapshot()
});
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
});
snapshot
})?;
project.update( cx, |project, cx| {
project.save_buffer(buffer, cx)
})?.await?;
let diff_str = cx.background_spawn(async move {
let new_text = snapshot.text();
language::unified_diff(&old_text, &new_text)
}).await;
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
}).into()
}
}

View File

@@ -134,7 +134,7 @@ impl Tool for ReadFileTool {
})?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)
@@ -147,7 +147,7 @@ impl Tool for ReadFileTool {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)

View File

@@ -106,7 +106,7 @@ impl Tool for RenameTool {
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
let position = {

View File

@@ -59,8 +59,10 @@ pub fn replace_with_flexible_indent(old: &str, new: &str, buffer: &BufferSnapsho
let max_row = buffer.max_point().row;
'windows: for start_row in 0..max_row + 1 {
let end_row = start_row + old_lines.len().saturating_sub(1) as u32;
'windows: for start_row in 0..max_row.saturating_sub(old_lines.len() as u32 - 1) {
let mut common_leading = None;
let end_row = start_row + old_lines.len() as u32 - 1;
if end_row > max_row {
// The buffer ends before fully matching the pattern
@@ -75,14 +77,6 @@ pub fn replace_with_flexible_indent(old: &str, new: &str, buffer: &BufferSnapsho
let mut window_lines = window_text.lines();
let mut old_lines_iter = old_lines.iter();
let mut common_mismatch = None;
#[derive(Eq, PartialEq)]
enum Mismatch {
OverIndented(String),
UnderIndented(String),
}
while let (Some(window_line), Some(old_line)) = (window_lines.next(), old_lines_iter.next())
{
let line_trimmed = window_line.trim_start();
@@ -95,24 +89,18 @@ pub fn replace_with_flexible_indent(old: &str, new: &str, buffer: &BufferSnapsho
continue;
}
let line_mismatch = if window_line.len() > old_line.len() {
let prefix = window_line[..window_line.len() - old_line.len()].to_string();
Mismatch::UnderIndented(prefix)
} else {
let prefix = old_line[..old_line.len() - window_line.len()].to_string();
Mismatch::OverIndented(prefix)
};
let line_leading = &window_line[..window_line.len() - old_line.len()];
match &common_mismatch {
Some(common_mismatch) if common_mismatch != &line_mismatch => {
match &common_leading {
Some(common_leading) if common_leading != line_leading => {
continue 'windows;
}
Some(_) => (),
None => common_mismatch = Some(line_mismatch),
None => common_leading = Some(line_leading.to_string()),
}
}
if let Some(common_mismatch) = &common_mismatch {
if let Some(common_leading) = common_leading {
let line_ending = buffer.line_ending();
let replacement = new_lines
.iter()
@@ -120,13 +108,7 @@ pub fn replace_with_flexible_indent(old: &str, new: &str, buffer: &BufferSnapsho
if new_line.trim().is_empty() {
new_line.to_string()
} else {
match common_mismatch {
Mismatch::UnderIndented(prefix) => prefix.to_string() + new_line,
Mismatch::OverIndented(prefix) => new_line
.strip_prefix(prefix)
.unwrap_or(new_line)
.to_string(),
}
common_leading.to_string() + new_line
}
})
.collect::<Vec<_>>()
@@ -168,123 +150,14 @@ fn lines_with_min_indent(input: &str) -> (Vec<&str>, usize) {
}
#[cfg(test)]
mod replace_exact_tests {
use super::*;
use gpui::TestAppContext;
use gpui::prelude::*;
#[gpui::test]
async fn basic(cx: &mut TestAppContext) {
let result = test_replace_exact(cx, "let x = 41;", "let x = 41;", "let x = 42;").await;
assert_eq!(result, Some("let x = 42;".to_string()));
}
#[gpui::test]
async fn no_match(cx: &mut TestAppContext) {
let result = test_replace_exact(cx, "let x = 41;", "let y = 42;", "let y = 43;").await;
assert_eq!(result, None);
}
#[gpui::test]
async fn multi_line(cx: &mut TestAppContext) {
let whole = "fn example() {\n let x = 41;\n println!(\"x = {}\", x);\n}";
let old_text = " let x = 41;\n println!(\"x = {}\", x);";
let new_text = " let x = 42;\n println!(\"x = {}\", x);";
let result = test_replace_exact(cx, whole, old_text, new_text).await;
assert_eq!(
result,
Some("fn example() {\n let x = 42;\n println!(\"x = {}\", x);\n}".to_string())
);
}
#[gpui::test]
async fn multiple_occurrences(cx: &mut TestAppContext) {
let whole = "let x = 41;\nlet y = 41;\nlet z = 41;";
let result = test_replace_exact(cx, whole, "let x = 41;", "let x = 42;").await;
assert_eq!(
result,
Some("let x = 42;\nlet y = 41;\nlet z = 41;".to_string())
);
}
#[gpui::test]
async fn empty_buffer(cx: &mut TestAppContext) {
let result = test_replace_exact(cx, "", "let x = 41;", "let x = 42;").await;
assert_eq!(result, None);
}
#[gpui::test]
async fn partial_match(cx: &mut TestAppContext) {
let whole = "let x = 41; let y = 42;";
let result = test_replace_exact(cx, whole, "let x = 41", "let x = 42").await;
assert_eq!(result, Some("let x = 42; let y = 42;".to_string()));
}
#[gpui::test]
async fn whitespace_sensitive(cx: &mut TestAppContext) {
let result = test_replace_exact(cx, "let x = 41;", " let x = 41;", "let x = 42;").await;
assert_eq!(result, None);
}
#[gpui::test]
async fn entire_buffer(cx: &mut TestAppContext) {
let result = test_replace_exact(cx, "let x = 41;", "let x = 41;", "let x = 42;").await;
assert_eq!(result, Some("let x = 42;".to_string()));
}
async fn test_replace_exact(
cx: &mut TestAppContext,
whole: &str,
old: &str,
new: &str,
) -> Option<String> {
let buffer = cx.new(|cx| language::Buffer::local(whole, cx));
let buffer_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let diff = replace_exact(old, new, &buffer_snapshot).await;
diff.map(|diff| {
buffer.update(cx, |buffer, cx| {
let _ = buffer.apply_diff(diff, cx);
buffer.text()
})
})
}
}
#[cfg(test)]
mod flexible_indent_tests {
mod tests {
use super::*;
use gpui::TestAppContext;
use gpui::prelude::*;
use unindent::Unindent;
#[gpui::test]
fn test_underindented_single_line(cx: &mut TestAppContext) {
let cur = " let a = 41;".to_string();
let old = " let a = 41;".to_string();
let new = " let a = 42;".to_string();
let exp = " let a = 42;".to_string();
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
assert_eq!(result, Some(exp.to_string()))
}
#[gpui::test]
fn test_overindented_single_line(cx: &mut TestAppContext) {
let cur = " let a = 41;".to_string();
let old = " let a = 41;".to_string();
let new = " let a = 42;".to_string();
let exp = " let a = 42;".to_string();
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
assert_eq!(result, Some(exp.to_string()))
}
#[gpui::test]
fn test_underindented_multi_line(cx: &mut TestAppContext) {
fn test_replace_consistent_indentation(cx: &mut TestAppContext) {
let whole = r#"
fn test() {
let x = 5;
@@ -321,33 +194,6 @@ mod flexible_indent_tests {
);
}
#[gpui::test]
fn test_overindented_multi_line(cx: &mut TestAppContext) {
let cur = r#"
fn foo() {
let a = 41;
let b = 3.13;
}
"#
.unindent();
// 6 space indent instead of 4
let old = " let a = 41;\n let b = 3.13;";
let new = " let a = 42;\n let b = 3.14;";
let expected = r#"
fn foo() {
let a = 42;
let b = 3.14;
}
"#
.unindent();
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
assert_eq!(result, Some(expected.to_string()))
}
#[gpui::test]
fn test_replace_inconsistent_indentation(cx: &mut TestAppContext) {
let whole = r#"
@@ -420,6 +266,7 @@ mod flexible_indent_tests {
#[gpui::test]
fn test_replace_no_match(cx: &mut TestAppContext) {
// Test with no match
let whole = r#"
fn test() {
let x = 5;
@@ -470,71 +317,6 @@ mod flexible_indent_tests {
);
}
#[gpui::test]
fn test_replace_whole_is_shorter_than_old(cx: &mut TestAppContext) {
let whole = r#"
let x = 5;
"#
.unindent();
let old = r#"
let x = 5;
let y = 10;
"#
.unindent();
let new = r#"
let x = 5;
let y = 20;
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
None
);
}
#[gpui::test]
fn test_replace_old_is_empty(cx: &mut TestAppContext) {
let whole = r#"
fn test() {
let x = 5;
}
"#
.unindent();
let old = "";
let new = r#"
let y = 10;
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
None
);
}
#[gpui::test]
fn test_replace_whole_is_empty(cx: &mut TestAppContext) {
let whole = "";
let old = r#"
let x = 5;
"#
.unindent();
let new = r#"
let x = 10;
"#
.unindent();
assert_eq!(
test_replace_with_flexible_indent(cx, &whole, &old, &new),
None
);
}
#[test]
fn test_lines_with_min_indent() {
// Empty string
@@ -722,133 +504,6 @@ mod flexible_indent_tests {
);
}
#[gpui::test]
async fn test_replace_exact_basic(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
assert!(diff.is_some());
let diff = diff.unwrap();
assert_eq!(diff.edits.len(), 1);
let result = buffer.update(cx, |buffer, cx| {
let _ = buffer.apply_diff(diff, cx);
buffer.text()
});
assert_eq!(result, "let x = 42;");
}
#[gpui::test]
async fn test_replace_exact_no_match(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let diff = replace_exact("let y = 42;", "let y = 43;", &snapshot).await;
assert!(diff.is_none());
}
#[gpui::test]
async fn test_replace_exact_multi_line(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| {
language::Buffer::local(
"fn example() {\n let x = 41;\n println!(\"x = {}\", x);\n}",
cx,
)
});
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let old_text = " let x = 41;\n println!(\"x = {}\", x);";
let new_text = " let x = 42;\n println!(\"x = {}\", x);";
let diff = replace_exact(old_text, new_text, &snapshot).await;
assert!(diff.is_some());
let diff = diff.unwrap();
let result = buffer.update(cx, |buffer, cx| {
let _ = buffer.apply_diff(diff, cx);
buffer.text()
});
assert_eq!(
result,
"fn example() {\n let x = 42;\n println!(\"x = {}\", x);\n}"
);
}
#[gpui::test]
async fn test_replace_exact_multiple_occurrences(cx: &mut TestAppContext) {
let buffer =
cx.new(|cx| language::Buffer::local("let x = 41;\nlet y = 41;\nlet z = 41;", cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
// Should replace only the first occurrence
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
assert!(diff.is_some());
let diff = diff.unwrap();
let result = buffer.update(cx, |buffer, cx| {
let _ = buffer.apply_diff(diff, cx);
buffer.text()
});
assert_eq!(result, "let x = 42;\nlet y = 41;\nlet z = 41;");
}
#[gpui::test]
async fn test_replace_exact_empty_buffer(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| language::Buffer::local("", cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
assert!(diff.is_none());
}
#[gpui::test]
async fn test_replace_exact_partial_match(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| language::Buffer::local("let x = 41; let y = 42;", cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
// Verify substring replacement actually works
let diff = replace_exact("let x = 41", "let x = 42", &snapshot).await;
assert!(diff.is_some());
let diff = diff.unwrap();
let result = buffer.update(cx, |buffer, cx| {
let _ = buffer.apply_diff(diff, cx);
buffer.text()
});
assert_eq!(result, "let x = 42; let y = 42;");
}
#[gpui::test]
async fn test_replace_exact_whitespace_sensitive(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let diff = replace_exact(" let x = 41;", "let x = 42;", &snapshot).await;
assert!(diff.is_none());
}
#[gpui::test]
async fn test_replace_exact_entire_buffer(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
assert!(diff.is_some());
let diff = diff.unwrap();
let result = buffer.update(cx, |buffer, cx| {
let _ = buffer.apply_diff(diff, cx);
buffer.text()
});
assert_eq!(result, "let x = 42;");
}
fn test_replace_with_flexible_indent(
cx: &mut TestAppContext,
whole: &str,

View File

@@ -140,7 +140,7 @@ impl Tool for SymbolInfoTool {
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
let position = {

View File

@@ -1,3 +0,0 @@
mod tool_call_card_header;
pub use tool_call_card_header::*;

View File

@@ -1,102 +0,0 @@
use gpui::{Animation, AnimationExt, App, IntoElement, pulsating_between};
use std::time::Duration;
use ui::{Tooltip, prelude::*};
/// A reusable header component for tool call cards.
#[derive(IntoElement)]
pub struct ToolCallCardHeader {
icon: IconName,
primary_text: SharedString,
secondary_text: Option<SharedString>,
is_loading: bool,
error: Option<String>,
}
impl ToolCallCardHeader {
pub fn new(icon: IconName, primary_text: impl Into<SharedString>) -> Self {
Self {
icon,
primary_text: primary_text.into(),
secondary_text: None,
is_loading: false,
error: None,
}
}
pub fn with_secondary_text(mut self, text: impl Into<SharedString>) -> Self {
self.secondary_text = Some(text.into());
self
}
pub fn loading(mut self) -> Self {
self.is_loading = true;
self
}
pub fn with_error(mut self, error: impl Into<String>) -> Self {
self.error = Some(error.into());
self
}
}
impl RenderOnce for ToolCallCardHeader {
fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement {
let font_size = rems(0.8125);
let secondary_text = self.secondary_text;
h_flex()
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.opacity(0.8)
.child(
h_flex().h(window.line_height()).justify_center().child(
Icon::new(self.icon)
.size(IconSize::XSmall)
.color(Color::Muted),
),
)
.child(
h_flex()
.h(window.line_height())
.gap_1p5()
.text_size(font_size)
.map(|this| {
if let Some(error) = &self.error {
this.child(format!("{} failed", self.primary_text)).child(
IconButton::new("error_info", IconName::Warning)
.shape(ui::IconButtonShape::Square)
.icon_size(IconSize::XSmall)
.icon_color(Color::Warning)
.tooltip(Tooltip::text(error.clone())),
)
} else {
this.child(self.primary_text.clone())
}
})
.when_some(secondary_text, |this, secondary_text| {
this.child(
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text),
)
.child(div().text_size(font_size).child(secondary_text.clone()))
})
.with_animation(
"loading-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
move |this, delta| {
if self.is_loading {
this.opacity(delta)
} else {
this
}
},
),
)
}
}

View File

@@ -1,11 +1,13 @@
use std::{sync::Arc, time::Duration};
use crate::schema::json_schema_for;
use crate::ui::ToolCallCardHeader;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{Future, FutureExt, TryFutureExt};
use gpui::{App, AppContext, Context, Entity, IntoElement, Task, Window};
use futures::{FutureExt, TryFutureExt};
use gpui::{
Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
pulsating_between,
};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -45,7 +47,7 @@ impl Tool for WebSearchTool {
}
fn ui_text(&self, _input: &serde_json::Value) -> String {
"Searching the Web".to_string()
"Web Search".to_string()
}
fn run(
@@ -113,30 +115,61 @@ impl ToolCard for WebSearchToolCard {
_window: &mut Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let header = match self.response.as_ref() {
Some(Ok(response)) => {
let text: SharedString = if response.citations.len() == 1 {
"1 result".into()
} else {
format!("{} results", response.citations.len()).into()
};
ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
.with_secondary_text(text)
}
Some(Err(error)) => {
ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
}
None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
};
let header = h_flex()
.id("tool-label-container")
.gap_1p5()
.max_w_full()
.overflow_x_scroll()
.child(
Icon::new(IconName::Globe)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(match self.response.as_ref() {
Some(Ok(response)) => {
let text: SharedString = if response.citations.len() == 1 {
"1 result".into()
} else {
format!("{} results", response.citations.len()).into()
};
h_flex()
.gap_1p5()
.child(Label::new("Searched the Web").size(LabelSize::Small))
.child(
div()
.size(px(3.))
.rounded_full()
.bg(cx.theme().colors().text),
)
.child(Label::new(text).size(LabelSize::Small))
.into_any_element()
}
Some(Err(error)) => div()
.id("web-search-error")
.child(Label::new("Web Search failed").size(LabelSize::Small))
.tooltip(Tooltip::text(error.to_string()))
.into_any_element(),
None => Label::new("Searching the Web…")
.size(LabelSize::Small)
.with_animation(
"web-search-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.alpha(delta),
)
.into_any_element(),
})
.into_any();
let content =
self.response.as_ref().and_then(|response| match response {
Ok(response) => {
Some(
v_flex()
.overflow_hidden()
.ml_1p5()
.pl(px(5.))
.pl_1p5()
.border_l_1()
.border_color(cx.theme().colors().border_variant)
.gap_1()
@@ -176,7 +209,7 @@ impl ToolCard for WebSearchToolCard {
Err(_) => None,
});
v_flex().mb_3().gap_1().child(header).children(content)
v_flex().my_2().gap_1().child(header).children(content)
}
}

View File

@@ -4,7 +4,7 @@ use crate::TelemetrySettings;
use anyhow::Result;
use clock::SystemClock;
use futures::channel::mpsc;
use futures::{Future, FutureExt, StreamExt};
use futures::{Future, StreamExt};
use gpui::{App, AppContext as _, BackgroundExecutor, Task};
use http_client::{self, AsyncBody, HttpClient, HttpClientWithUrl, Method, Request};
use parking_lot::Mutex;
@@ -290,10 +290,6 @@ impl Telemetry {
paths::logs_dir().join("telemetry.log")
}
pub fn has_checksum_seed(&self) -> bool {
ZED_CLIENT_CHECKSUM_SEED.is_some()
}
pub fn start(
self: &Arc<Self>,
system_id: Option<String>,
@@ -434,7 +430,7 @@ impl Telemetry {
let executor = self.executor.clone();
state.flush_events_task = Some(self.executor.spawn(async move {
executor.timer(FLUSH_INTERVAL).await;
this.flush_events().detach();
this.flush_events();
}));
}
@@ -460,7 +456,7 @@ impl Telemetry {
if state.installation_id.is_some() && state.events_queue.len() >= state.max_queue_size {
drop(state);
self.flush_events().detach();
self.flush_events();
}
}
@@ -503,59 +499,60 @@ impl Telemetry {
.body(json_bytes.into())?)
}
pub fn flush_events(self: &Arc<Self>) -> Task<()> {
pub fn flush_events(self: &Arc<Self>) {
let mut state = self.state.lock();
state.first_event_date_time = None;
let mut events = mem::take(&mut state.events_queue);
state.flush_events_task.take();
drop(state);
if events.is_empty() {
return Task::ready(());
return;
}
let this = self.clone();
self.executor.spawn(
async move {
let mut json_bytes = Vec::new();
self.executor
.spawn(
async move {
let mut json_bytes = Vec::new();
if let Some(file) = &mut this.state.lock().log_file {
for event in &mut events {
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, event)?;
file.write_all(&json_bytes)?;
file.write_all(b"\n")?;
if let Some(file) = &mut this.state.lock().log_file {
for event in &mut events {
json_bytes.clear();
serde_json::to_writer(&mut json_bytes, event)?;
file.write_all(&json_bytes)?;
file.write_all(b"\n")?;
}
}
}
let request_body = {
let state = this.state.lock();
let request_body = {
let state = this.state.lock();
EventRequestBody {
system_id: state.system_id.as_deref().map(Into::into),
installation_id: state.installation_id.as_deref().map(Into::into),
session_id: state.session_id.clone(),
metrics_id: state.metrics_id.as_deref().map(Into::into),
is_staff: state.is_staff,
app_version: state.app_version.clone(),
os_name: state.os_name.clone(),
os_version: state.os_version.clone(),
architecture: state.architecture.to_string(),
EventRequestBody {
system_id: state.system_id.as_deref().map(Into::into),
installation_id: state.installation_id.as_deref().map(Into::into),
session_id: state.session_id.clone(),
metrics_id: state.metrics_id.as_deref().map(Into::into),
is_staff: state.is_staff,
app_version: state.app_version.clone(),
os_name: state.os_name.clone(),
os_version: state.os_version.clone(),
architecture: state.architecture.to_string(),
release_channel: state.release_channel.map(Into::into),
events,
release_channel: state.release_channel.map(Into::into),
events,
}
};
let request = this.build_request(json_bytes, request_body)?;
let response = this.http_client.send(request).await?;
if response.status() != 200 {
log::error!("Failed to send events: HTTP {:?}", response.status());
}
};
let request = this.build_request(json_bytes, request_body)?;
let response = this.http_client.send(request).await?;
if response.status() != 200 {
log::error!("Failed to send events: HTTP {:?}", response.status());
anyhow::Ok(())
}
anyhow::Ok(())
}
.log_err()
.map(|_| ()),
)
.log_err(),
)
.detach();
}
}

View File

@@ -128,7 +128,6 @@ serde_json.workspace = true
session = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
sqlx = { version = "0.8", features = ["sqlite"] }
task.workspace = true
theme.workspace = true
unindent.workspace = true
util.workspace = true

View File

@@ -492,8 +492,7 @@ CREATE TABLE IF NOT EXISTS billing_customers (
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id INTEGER NOT NULL REFERENCES users (id),
has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE,
stripe_customer_id TEXT NOT NULL,
trial_started_at TIMESTAMP
stripe_customer_id TEXT NOT NULL
);
CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);

View File

@@ -1,2 +0,0 @@
alter table billing_customers
add column trial_started_at timestamp without time zone;

View File

@@ -287,7 +287,7 @@ async fn create_billing_subscription(
}
}
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
let customer_id = if let Some(existing_customer) = existing_billing_customer {
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
@@ -320,15 +320,6 @@ async fn create_billing_subscription(
.await?
}
Some(ProductCode::ZedProTrial) => {
if let Some(existing_billing_customer) = &existing_billing_customer {
if existing_billing_customer.trial_started_at.is_some() {
return Err(Error::http(
StatusCode::FORBIDDEN,
"user already used free trial".into(),
));
}
}
stripe_billing
.checkout_with_zed_pro_trial(
app.config.zed_pro_price_id()?,
@@ -826,24 +817,6 @@ async fn handle_customer_subscription_event(
.await?
.ok_or_else(|| anyhow!("billing customer not found"))?;
if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
if subscription.status == SubscriptionStatus::Trialing {
let current_period_start =
DateTime::from_timestamp(subscription.current_period_start, 0)
.ok_or_else(|| anyhow!("No trial subscription period start"))?;
app.db
.update_billing_customer(
billing_customer.id,
&UpdateBillingCustomerParams {
trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
..Default::default()
},
)
.await?;
}
}
let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
&& subscription
.cancellation_details
@@ -870,28 +843,6 @@ async fn handle_customer_subscription_event(
.get_billing_subscription_by_stripe_subscription_id(&subscription.id)
.await?
{
let llm_db = app
.llm_db
.clone()
.ok_or_else(|| anyhow!("LLM DB not initialized"))?;
let new_period_start_at =
chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
.ok_or_else(|| anyhow!("No subscription period start"))?;
let new_period_end_at =
chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
.ok_or_else(|| anyhow!("No subscription period end"))?;
llm_db
.transfer_existing_subscription_usage(
billing_customer.user_id,
&existing_subscription,
subscription_kind,
new_period_start_at,
new_period_end_at,
)
.await?;
app.db
.update_billing_subscription(
existing_subscription.id,

View File

@@ -516,7 +516,6 @@ pub async fn post_events(
if let Some(kinesis_client) = app.kinesis_client.clone() {
if let Some(stream) = app.config.kinesis_stream.clone() {
let mut request = kinesis_client.put_records().stream_name(stream);
let mut has_records = false;
for row in for_snowflake(
request_body.clone(),
first_event_at,
@@ -531,12 +530,9 @@ pub async fn post_events(
.build()
.unwrap(),
);
has_records = true;
}
}
if has_records {
request.send().await.log_err();
}
request.send().await.log_err();
}
};
@@ -559,7 +555,7 @@ fn for_snowflake(
country_code: Option<String>,
checksum_matched: bool,
) -> impl Iterator<Item = SnowflakeRow> {
body.events.into_iter().filter_map(move |event| {
body.events.into_iter().flat_map(move |event| {
let timestamp =
first_event_at + Duration::milliseconds(event.milliseconds_since_first_event);
// We will need to double check, but I believe all of the events that
@@ -748,11 +744,9 @@ fn for_snowflake(
// NOTE: most amplitude user properties are read out of our event_properties
// dictionary. See https://app.amplitude.com/data/zed/Zed/sources/detail/production/falcon%3A159998
// for how that is configured.
let user_properties = body.is_staff.map(|is_staff| {
serde_json::json!({
"is_staff": is_staff,
})
});
let user_properties = Some(serde_json::json!({
"is_staff": body.is_staff,
}));
Some(SnowflakeRow {
time: timestamp,

View File

@@ -11,7 +11,6 @@ pub struct UpdateBillingCustomerParams {
pub user_id: ActiveValue<UserId>,
pub stripe_customer_id: ActiveValue<String>,
pub has_overdue_invoices: ActiveValue<bool>,
pub trial_started_at: ActiveValue<Option<DateTime>>,
}
impl Database {
@@ -46,8 +45,7 @@ impl Database {
user_id: params.user_id.clone(),
stripe_customer_id: params.stripe_customer_id.clone(),
has_overdue_invoices: params.has_overdue_invoices.clone(),
trial_started_at: params.trial_started_at.clone(),
created_at: ActiveValue::not_set(),
..Default::default()
})
.exec(&*tx)
.await?;

View File

@@ -10,7 +10,6 @@ pub struct Model {
pub user_id: UserId,
pub stripe_customer_id: String,
pub has_overdue_invoices: bool,
pub trial_started_at: Option<DateTime>,
pub created_at: DateTime,
}

View File

@@ -1,87 +1,8 @@
use chrono::Timelike;
use time::PrimitiveDateTime;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{UserId, billing_subscription};
use crate::db::UserId;
use super::*;
fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
use chrono::{Datelike as _, Timelike as _};
let date = time::Date::from_calendar_date(
datetime.year(),
time::Month::try_from(datetime.month() as u8).unwrap(),
datetime.day() as u8,
)?;
let time = time::Time::from_hms_nano(
datetime.hour() as u8,
datetime.minute() as u8,
datetime.second() as u8,
datetime.nanosecond(),
)?;
Ok(PrimitiveDateTime::new(date, time))
}
impl LlmDatabase {
pub async fn create_subscription_usage(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
plan: SubscriptionKind,
model_requests: i32,
edit_predictions: i32,
) -> Result<subscription_usage::Model> {
self.transaction(|tx| async move {
self.create_subscription_usage_in_tx(
user_id,
period_start_at,
period_end_at,
plan,
model_requests,
edit_predictions,
&tx,
)
.await
})
.await
}
async fn create_subscription_usage_in_tx(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
plan: SubscriptionKind,
model_requests: i32,
edit_predictions: i32,
tx: &DatabaseTransaction,
) -> Result<subscription_usage::Model> {
// Clear out the nanoseconds so that these timestamps are comparable with Unix timestamps.
let period_start_at = period_start_at.with_nanosecond(0).unwrap();
let period_end_at = period_end_at.with_nanosecond(0).unwrap();
let period_start_at = convert_chrono_to_time(period_start_at)?;
let period_end_at = convert_chrono_to_time(period_end_at)?;
Ok(
subscription_usage::Entity::insert(subscription_usage::ActiveModel {
id: ActiveValue::not_set(),
user_id: ActiveValue::set(user_id),
period_start_at: ActiveValue::set(period_start_at),
period_end_at: ActiveValue::set(period_end_at),
plan: ActiveValue::set(plan),
model_requests: ActiveValue::set(model_requests),
edit_predictions: ActiveValue::set(edit_predictions),
})
.exec_with_returning(tx)
.await?,
)
}
pub async fn get_subscription_usage_for_period(
&self,
user_id: UserId,
@@ -89,77 +10,12 @@ impl LlmDatabase {
period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
self.get_subscription_usage_for_period_in_tx(
user_id,
period_start_at,
period_end_at,
&tx,
)
.await
})
.await
}
async fn get_subscription_usage_for_period_in_tx(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
tx: &DatabaseTransaction,
) -> Result<Option<subscription_usage::Model>> {
Ok(subscription_usage::Entity::find()
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
.one(tx)
.await?)
}
pub async fn transfer_existing_subscription_usage(
&self,
user_id: UserId,
existing_subscription: &billing_subscription::Model,
new_subscription_kind: Option<SubscriptionKind>,
new_period_start_at: DateTimeUtc,
new_period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
match existing_subscription.kind {
Some(SubscriptionKind::ZedProTrial) => {
let trial_period_start_at = existing_subscription
.current_period_start_at()
.ok_or_else(|| anyhow!("No trial subscription period start"))?;
let trial_period_end_at = existing_subscription
.current_period_end_at()
.ok_or_else(|| anyhow!("No trial subscription period end"))?;
let existing_usage = self
.get_subscription_usage_for_period_in_tx(
user_id,
trial_period_start_at,
trial_period_end_at,
&tx,
)
.await?;
if let Some(existing_usage) = existing_usage {
return Ok(Some(
self.create_subscription_usage_in_tx(
user_id,
new_period_start_at,
new_period_end_at,
new_subscription_kind.unwrap_or(existing_usage.plan),
existing_usage.model_requests,
existing_usage.edit_predictions,
&tx,
)
.await?,
));
}
}
_ => {}
}
Ok(None)
Ok(subscription_usage::Entity::find()
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
.one(&*tx)
.await?)
})
.await
}

View File

@@ -1,5 +1,4 @@
mod provider_tests;
mod subscription_usage_tests;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;

View File

@@ -1,69 +0,0 @@
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{UserId, billing_subscription};
use crate::llm::db::LlmDatabase;
use crate::test_llm_db;
test_llm_db!(
test_transfer_existing_subscription_usage,
test_transfer_existing_subscription_usage_postgres
);
async fn test_transfer_existing_subscription_usage(db: &mut LlmDatabase) {
let user_id = UserId(1);
let now = Utc::now();
let trial_period_start_at = now - Duration::days(14);
let trial_period_end_at = now;
let new_period_start_at = now;
let new_period_end_at = now + Duration::days(30);
let existing_subscription = billing_subscription::Model {
kind: Some(SubscriptionKind::ZedProTrial),
stripe_current_period_start: Some(trial_period_start_at.timestamp()),
stripe_current_period_end: Some(trial_period_end_at.timestamp()),
..Default::default()
};
let existing_usage = db
.create_subscription_usage(
user_id,
trial_period_start_at,
trial_period_end_at,
SubscriptionKind::ZedProTrial,
25,
1_000,
)
.await
.unwrap();
let transferred_usage = db
.transfer_existing_subscription_usage(
user_id,
&existing_subscription,
Some(SubscriptionKind::ZedPro),
new_period_start_at,
new_period_end_at,
)
.await
.unwrap();
assert!(
transferred_usage.is_some(),
"subscription usage not transferred successfully"
);
let transferred_usage = transferred_usage.unwrap();
assert_eq!(
transferred_usage.model_requests,
existing_usage.model_requests
);
assert_eq!(
transferred_usage.edit_predictions,
existing_usage.edit_predictions
);
}

View File

@@ -1,5 +1,4 @@
use crate::Cents;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{billing_subscription, user};
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::{Config, db::billing_preference};
@@ -33,8 +32,6 @@ pub struct LlmTokenClaims {
pub plan: Plan,
#[serde(default)]
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
#[serde(default)]
pub can_use_web_search_tool: bool,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
@@ -46,6 +43,7 @@ impl LlmTokenClaims {
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
has_legacy_llm_subscription: bool,
plan: rpc::proto::Plan,
subscription: Option<billing_subscription::Model>,
system_id: Option<String>,
config: &Config,
@@ -72,7 +70,6 @@ impl LlmTokenClaims {
bypass_account_age_check: feature_flags
.iter()
.any(|flag| flag == "bypass-account-age-check"),
can_use_web_search_tool: feature_flags.iter().any(|flag| flag == "assistant2"),
has_llm_subscription: has_legacy_llm_subscription,
max_monthly_spend_in_cents: billing_preferences
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
@@ -81,14 +78,11 @@ impl LlmTokenClaims {
custom_llm_monthly_allowance_in_cents: user
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
plan: subscription
.as_ref()
.and_then(|subscription| subscription.kind)
.map_or(Plan::Free, |kind| match kind {
SubscriptionKind::ZedFree => Plan::Free,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
}),
plan: match plan {
rpc::proto::Plan::Free => Plan::Free,
rpc::proto::Plan::ZedPro => Plan::ZedPro,
rpc::proto::Plan::ZedProTrial => Plan::ZedProTrial,
},
subscription_period: maybe!({
let subscription = subscription?;
let period_start_at = subscription.current_period_start_at()?;

View File

@@ -4147,6 +4147,7 @@ async fn get_llm_api_token(
billing_preferences,
&flags,
has_legacy_llm_subscription,
session.current_plan(&db).await?,
billing_subscription,
session.system_id.clone(),
&session.app_state.config,

View File

@@ -416,16 +416,9 @@ impl StripeBilling {
let mut params = stripe::CreateCheckoutSession::new();
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
trial_period_days: Some(14),
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
}
}),
..Default::default()
});
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {

View File

@@ -1,7 +1,7 @@
use call::ActiveCall;
use dap::DebugRequestType;
use dap::requests::{Initialize, Launch, StackTrace};
use dap::{SourceBreakpoint, requests::SetBreakpoints};
use dap::DebugRequestType;
use dap::{requests::SetBreakpoints, SourceBreakpoint};
use debugger_ui::debugger_panel::DebugPanel;
use debugger_ui::session::DebugSession;
use editor::Editor;
@@ -13,7 +13,7 @@ use std::{
path::Path,
sync::atomic::{AtomicBool, Ordering},
};
use workspace::{Workspace, dock::Panel};
use workspace::{dock::Panel, Workspace};
use super::{TestClient, TestServer};

View File

@@ -6,18 +6,17 @@ use collab_ui::{
channel_view::ChannelView,
notifications::project_shared_notification::ProjectSharedNotification,
};
use editor::{Editor, MultiBuffer, PathKey};
use editor::{Editor, ExcerptRange, MultiBuffer};
use gpui::{
AppContext as _, BackgroundExecutor, BorrowAppContext, Entity, SharedString, TestAppContext,
VisualContext, VisualTestContext, point,
VisualTestContext, point,
};
use language::Capability;
use project::WorktreeSettings;
use rpc::proto::PeerId;
use serde_json::json;
use settings::SettingsStore;
use text::{Point, ToPoint};
use util::{path, test::sample_text};
use util::path;
use workspace::{SplitDirection, Workspace, item::ItemHandle as _};
use super::TestClient;
@@ -296,20 +295,8 @@ async fn test_basic_following(
.unwrap()
});
let mut result = MultiBuffer::new(Capability::ReadWrite);
result.set_excerpts_for_path(
PathKey::for_buffer(&buffer_a1, cx),
buffer_a1,
[Point::row_range(1..2)],
1,
cx,
);
result.set_excerpts_for_path(
PathKey::for_buffer(&buffer_a2, cx),
buffer_a2,
[Point::row_range(5..6)],
1,
cx,
);
result.push_excerpts(buffer_a1, [ExcerptRange::new(0..3)], cx);
result.push_excerpts(buffer_a2, [ExcerptRange::new(4..7)], cx);
result
});
let multibuffer_editor_a = workspace_a.update_in(cx_a, |workspace, window, cx| {
@@ -2083,83 +2070,6 @@ async fn share_workspace(
.await
}
#[gpui::test]
async fn test_following_after_replacement(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
let (_server, client_a, client_b, channel) = TestServer::start2(cx_a, cx_b).await;
let (workspace, cx_a) = client_a.build_test_workspace(cx_a).await;
join_channel(channel, &client_a, cx_a).await.unwrap();
share_workspace(&workspace, cx_a).await.unwrap();
let buffer = workspace.update(cx_a, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
project.create_local_buffer(&sample_text(26, 5, 'a'), None, cx)
})
});
let multibuffer = cx_a.new(|cx| {
let mut mb = MultiBuffer::new(Capability::ReadWrite);
mb.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer.clone(),
[Point::row_range(1..1), Point::row_range(5..5)],
1,
cx,
);
mb
});
let snapshot = buffer.update(cx_a, |buffer, _| buffer.snapshot());
let editor: Entity<Editor> = cx_a.new_window_entity(|window, cx| {
Editor::for_multibuffer(
multibuffer.clone(),
Some(workspace.read(cx).project().clone()),
window,
cx,
)
});
workspace.update_in(cx_a, |workspace, window, cx| {
workspace.add_item_to_center(Box::new(editor.clone()) as _, window, cx)
});
editor.update_in(cx_a, |editor, window, cx| {
editor.change_selections(None, window, cx, |s| {
s.select_ranges([Point::row_range(4..4)]);
})
});
let positions = editor.update(cx_a, |editor, _| {
editor
.selections
.disjoint_anchor_ranges()
.map(|range| range.start.text_anchor.to_point(&snapshot))
.collect::<Vec<_>>()
});
multibuffer.update(cx_a, |multibuffer, cx| {
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
[Point::row_range(1..5)],
1,
cx,
);
});
let (workspace_b, cx_b) = client_b.join_workspace(channel, cx_b).await;
cx_b.run_until_parked();
let editor_b = workspace_b
.update(cx_b, |workspace, cx| {
workspace
.active_item(cx)
.and_then(|item| item.downcast::<Editor>())
})
.unwrap();
let new_positions = editor_b.update(cx_b, |editor, _| {
editor
.selections
.disjoint_anchor_ranges()
.map(|range| range.start.text_anchor.to_point(&snapshot))
.collect::<Vec<_>>()
});
assert_eq!(positions, new_positions);
}
#[gpui::test]
async fn test_following_to_channel_notes_other_workspace(
cx_a: &mut TestAppContext,

View File

@@ -1,14 +1,12 @@
use crate::tests::TestServer;
use call::ActiveCall;
use collections::{HashMap, HashSet};
use debugger_ui::debugger_panel::DebugPanel;
use dap::DapRegistry;
use extension::ExtensionHostProxy;
use fs::{FakeFs, Fs as _, RemoveOptions};
use futures::StreamExt as _;
use gpui::{
AppContext as _, BackgroundExecutor, SemanticVersion, TestAppContext, UpdateGlobal as _,
VisualContext,
};
use http_client::BlockedHttpClient;
use language::{
@@ -26,7 +24,6 @@ use project::{
};
use remote::SshRemoteClient;
use remote_server::{HeadlessAppState, HeadlessProject};
use rpc::proto;
use serde_json::json;
use settings::SettingsStore;
use std::{path::Path, sync::Arc};
@@ -89,6 +86,7 @@ async fn test_sharing_an_ssh_remote_project(
http_client: remote_http_client,
node_runtime: node,
languages,
debug_adapters: Arc::new(DapRegistry::fake()),
extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
},
cx,
@@ -256,6 +254,7 @@ async fn test_ssh_collaboration_git_branches(
http_client: remote_http_client,
node_runtime: node,
languages,
debug_adapters: Arc::new(DapRegistry::fake()),
extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
},
cx,
@@ -461,6 +460,7 @@ async fn test_ssh_collaboration_formatting_with_prettier(
http_client: remote_http_client,
node_runtime: NodeRuntime::unavailable(),
languages,
debug_adapters: Arc::new(DapRegistry::fake()),
extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
},
cx,
@@ -579,108 +579,3 @@ async fn test_ssh_collaboration_formatting_with_prettier(
"Prettier formatting was not applied to client buffer after host's request"
);
}
#[gpui::test]
async fn test_remote_server_debugger(cx_a: &mut TestAppContext, server_cx: &mut TestAppContext) {
cx_a.update(|cx| {
release_channel::init(SemanticVersion::default(), cx);
command_palette_hooks::init(cx);
if std::env::var("RUST_LOG").is_ok() {
env_logger::try_init().ok();
}
});
server_cx.update(|cx| {
release_channel::init(SemanticVersion::default(), cx);
});
let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
let remote_fs = FakeFs::new(server_cx.executor());
remote_fs
.insert_tree(
path!("/code"),
json!({
"lib.rs": "fn one() -> usize { 1 }"
}),
)
.await;
// User A connects to the remote project via SSH.
server_cx.update(HeadlessProject::init);
let remote_http_client = Arc::new(BlockedHttpClient);
let node = NodeRuntime::unavailable();
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
let _headless_project = server_cx.new(|cx| {
client::init_settings(cx);
HeadlessProject::new(
HeadlessAppState {
session: server_ssh,
fs: remote_fs.clone(),
http_client: remote_http_client,
node_runtime: node,
languages,
extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
},
cx,
)
});
let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await;
let mut server = TestServer::start(server_cx.executor()).await;
let client_a = server.create_client(cx_a, "user_a").await;
cx_a.update(|cx| {
debugger_ui::init(cx);
command_palette_hooks::init(cx);
});
let (project_a, _) = client_a
.build_ssh_project(path!("/code"), client_ssh.clone(), cx_a)
.await;
let (workspace, cx_a) = client_a.build_workspace(&project_a, cx_a);
let debugger_panel = workspace
.update_in(cx_a, |_workspace, window, cx| {
cx.spawn_in(window, DebugPanel::load)
})
.await
.unwrap();
workspace.update_in(cx_a, |workspace, window, cx| {
workspace.add_panel(debugger_panel, window, cx);
});
cx_a.run_until_parked();
let debug_panel = workspace
.update(cx_a, |workspace, cx| workspace.panel::<DebugPanel>(cx))
.unwrap();
let workspace_window = cx_a
.window_handle()
.downcast::<workspace::Workspace>()
.unwrap();
let session = debugger_ui::tests::start_debug_session(&workspace_window, cx_a, |_| {}).unwrap();
cx_a.run_until_parked();
debug_panel.update(cx_a, |debug_panel, cx| {
assert_eq!(
debug_panel.active_session().unwrap().read(cx).session(cx),
session
)
});
session.update(cx_a, |session, _| {
assert_eq!(session.binary().command, "ssh");
});
let shutdown_session = workspace.update(cx_a, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
project.dap_store().update(cx, |dap_store, cx| {
dap_store.shutdown_session(session.read(cx).session_id(), cx)
})
})
});
client_ssh.update(cx_a, |a, _| {
a.shutdown_processes(Some(proto::ShutdownRemoteServer {}))
});
shutdown_session.await.unwrap();
}

View File

@@ -14,7 +14,7 @@ use client::{
use clock::FakeSystemClock;
use collab_ui::channel_view::ChannelView;
use collections::{HashMap, HashSet};
use dap::DapRegistry;
use fs::FakeFs;
use futures::{StreamExt as _, channel::oneshot};
use git::GitHostingProviderRegistry;
@@ -275,12 +275,14 @@ impl TestServer {
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
let debug_adapters = Arc::new(DapRegistry::default());
let session = cx.new(|cx| AppSession::new(Session::test(), cx));
let app_state = Arc::new(workspace::AppState {
client: client.clone(),
user_store: user_store.clone(),
workspace_store,
languages: language_registry,
debug_adapters,
fs: fs.clone(),
build_window_options: |_, _| Default::default(),
node_runtime: NodeRuntime::unavailable(),
@@ -796,6 +798,7 @@ impl TestClient {
self.app_state.node_runtime.clone(),
self.app_state.user_store.clone(),
self.app_state.languages.clone(),
self.app_state.debug_adapters.clone(),
self.app_state.fs.clone(),
None,
cx,

View File

@@ -166,9 +166,6 @@ impl ComponentPreview {
component_preview.update_component_list(cx);
let focus_handle = component_preview.filter_editor.read(cx).focus_handle(cx);
window.focus(&focus_handle);
component_preview
}
@@ -782,13 +779,10 @@ impl Item for ComponentPreview {
fn added_to_workspace(
&mut self,
workspace: &mut Workspace,
window: &mut Window,
cx: &mut Context<Self>,
_window: &mut Window,
_cx: &mut Context<Self>,
) {
self.workspace_id = workspace.database_id();
let focus_handle = self.filter_editor.read(cx).focus_handle(cx);
window.focus(&focus_handle);
}
}

View File

@@ -39,7 +39,6 @@ log.workspace = true
node_runtime.workspace = true
parking_lot.workspace = true
paths.workspace = true
proto.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -3,8 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use async_trait::async_trait;
use collections::HashMap;
use dap_types::{StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest};
use dap_types::StartDebuggingRequestArguments;
use futures::io::BufReader;
use gpui::{AsyncApp, SharedString};
pub use http_client::{HttpClient, github::latest_github_release};
@@ -14,10 +13,16 @@ use serde::{Deserialize, Serialize};
use settings::WorktreeId;
use smol::{self, fs::File, lock::Mutex};
use std::{
borrow::Borrow, collections::HashSet, ffi::OsStr, fmt::Debug, net::Ipv4Addr, ops::Deref,
path::PathBuf, sync::Arc,
borrow::Borrow,
collections::{HashMap, HashSet},
ffi::{OsStr, OsString},
fmt::Debug,
net::Ipv4Addr,
ops::Deref,
path::PathBuf,
sync::Arc,
};
use task::{DebugTaskDefinition, TcpArgumentsTemplate};
use task::DebugTaskDefinition;
use util::ResultExt;
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -88,91 +93,17 @@ pub struct TcpArguments {
pub port: u16,
pub timeout: Option<u64>,
}
impl TcpArguments {
pub fn from_proto(proto: proto::TcpHost) -> anyhow::Result<Self> {
let host = TcpArgumentsTemplate::from_proto(proto)?;
Ok(TcpArguments {
host: host.host.ok_or_else(|| anyhow!("missing host"))?,
port: host.port.ok_or_else(|| anyhow!("missing port"))?,
timeout: host.timeout,
})
}
pub fn to_proto(&self) -> proto::TcpHost {
TcpArgumentsTemplate {
host: Some(self.host),
port: Some(self.port),
timeout: self.timeout,
}
.to_proto()
}
}
#[derive(Debug, Clone)]
pub struct DebugAdapterBinary {
pub adapter_name: DebugAdapterName,
pub command: String,
pub arguments: Vec<String>,
pub envs: HashMap<String, String>,
pub arguments: Option<Vec<OsString>>,
pub envs: Option<HashMap<String, String>>,
pub cwd: Option<PathBuf>,
pub connection: Option<TcpArguments>,
pub request_args: StartDebuggingRequestArguments,
}
impl DebugAdapterBinary {
pub fn from_proto(binary: proto::DebugAdapterBinary) -> anyhow::Result<Self> {
let request = match binary.launch_type() {
proto::debug_adapter_binary::LaunchType::Launch => {
StartDebuggingRequestArgumentsRequest::Launch
}
proto::debug_adapter_binary::LaunchType::Attach => {
StartDebuggingRequestArgumentsRequest::Attach
}
};
Ok(DebugAdapterBinary {
command: binary.command,
arguments: binary.arguments,
envs: binary.envs.into_iter().collect(),
connection: binary
.connection
.map(TcpArguments::from_proto)
.transpose()?,
request_args: StartDebuggingRequestArguments {
configuration: serde_json::from_str(&binary.configuration)?,
request,
},
cwd: binary.cwd.map(|cwd| cwd.into()),
})
}
pub fn to_proto(&self) -> proto::DebugAdapterBinary {
proto::DebugAdapterBinary {
command: self.command.clone(),
arguments: self.arguments.clone(),
envs: self
.envs
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
cwd: self
.cwd
.as_ref()
.map(|cwd| cwd.to_string_lossy().to_string()),
connection: self.connection.as_ref().map(|c| c.to_proto()),
launch_type: match self.request_args.request {
StartDebuggingRequestArgumentsRequest::Launch => {
proto::debug_adapter_binary::LaunchType::Launch.into()
}
StartDebuggingRequestArgumentsRequest::Attach => {
proto::debug_adapter_binary::LaunchType::Attach.into()
}
},
configuration: self.request_args.configuration.to_string(),
}
}
}
#[derive(Debug)]
pub struct AdapterVersion {
pub tag_name: String,
@@ -387,22 +318,22 @@ impl FakeAdapter {
fn request_args(&self, config: &DebugTaskDefinition) -> StartDebuggingRequestArguments {
use serde_json::json;
use task::DebugRequest;
use task::DebugRequestType;
let value = json!({
"request": match config.request {
DebugRequest::Launch(_) => "launch",
DebugRequest::Attach(_) => "attach",
DebugRequestType::Launch(_) => "launch",
DebugRequestType::Attach(_) => "attach",
},
"process_id": if let DebugRequest::Attach(attach_config) = &config.request {
"process_id": if let DebugRequestType::Attach(attach_config) = &config.request {
attach_config.process_id
} else {
None
},
});
let request = match config.request {
DebugRequest::Launch(_) => dap_types::StartDebuggingRequestArgumentsRequest::Launch,
DebugRequest::Attach(_) => dap_types::StartDebuggingRequestArgumentsRequest::Attach,
DebugRequestType::Launch(_) => dap_types::StartDebuggingRequestArgumentsRequest::Launch,
DebugRequestType::Attach(_) => dap_types::StartDebuggingRequestArgumentsRequest::Attach,
};
StartDebuggingRequestArguments {
configuration: value,
@@ -426,10 +357,11 @@ impl DebugAdapter for FakeAdapter {
_: &mut AsyncApp,
) -> Result<DebugAdapterBinary> {
Ok(DebugAdapterBinary {
adapter_name: Self::ADAPTER_NAME.into(),
command: "command".into(),
arguments: vec![],
arguments: None,
connection: None,
envs: HashMap::default(),
envs: None,
cwd: None,
request_args: self.request_args(config),
})

View File

@@ -1,5 +1,5 @@
use crate::{
adapters::DebugAdapterBinary,
adapters::{DebugAdapterBinary, DebugAdapterName},
transport::{IoKind, LogKind, TransportDelegate},
};
use anyhow::{Result, anyhow};
@@ -88,6 +88,7 @@ impl DebugAdapterClient {
) -> Result<Self> {
let binary = match self.transport_delegate.transport() {
crate::transport::Transport::Tcp(tcp_transport) => DebugAdapterBinary {
adapter_name: binary.adapter_name,
command: binary.command,
arguments: binary.arguments,
envs: binary.envs,
@@ -218,6 +219,9 @@ impl DebugAdapterClient {
self.id
}
pub fn name(&self) -> DebugAdapterName {
self.binary.adapter_name.clone()
}
pub fn binary(&self) -> &DebugAdapterBinary {
&self.binary
}
@@ -318,6 +322,7 @@ mod tests {
let client = DebugAdapterClient::start(
crate::client::SessionId(1),
DebugAdapterBinary {
adapter_name: "adapter".into(),
command: "command".into(),
arguments: Default::default(),
envs: Default::default(),
@@ -388,6 +393,7 @@ mod tests {
let client = DebugAdapterClient::start(
crate::client::SessionId(1),
DebugAdapterBinary {
adapter_name: "adapter".into(),
command: "command".into(),
arguments: Default::default(),
envs: Default::default(),
@@ -441,6 +447,7 @@ mod tests {
let client = DebugAdapterClient::start(
crate::client::SessionId(1),
DebugAdapterBinary {
adapter_name: "test-adapter".into(),
command: "command".into(),
arguments: Default::default(),
envs: Default::default(),

View File

@@ -7,7 +7,7 @@ pub mod transport;
pub use dap_types::*;
pub use registry::DapRegistry;
pub use task::DebugRequest;
pub use task::DebugRequestType;
pub type ScopeId = u64;
pub type VariableReference = u64;

View File

@@ -1,4 +1,3 @@
use gpui::{App, Global};
use parking_lot::RwLock;
use crate::adapters::{DebugAdapter, DebugAdapterName};
@@ -12,20 +11,8 @@ struct DapRegistryState {
#[derive(Clone, Default)]
/// Stores available debug adapters.
pub struct DapRegistry(Arc<RwLock<DapRegistryState>>);
impl Global for DapRegistry {}
impl DapRegistry {
pub fn global(cx: &mut App) -> &mut Self {
let ret = cx.default_global::<Self>();
#[cfg(any(test, feature = "test-support"))]
if ret.adapter(crate::FakeAdapter::ADAPTER_NAME).is_none() {
ret.add_adapter(Arc::new(crate::FakeAdapter::new()));
}
ret
}
pub fn add_adapter(&self, adapter: Arc<dyn DebugAdapter>) {
let name = adapter.name();
let _previous_value = self.0.write().adapters.insert(name, adapter);
@@ -34,12 +21,19 @@ impl DapRegistry {
"Attempted to insert a new debug adapter when one is already registered"
);
}
pub fn adapter(&self, name: &str) -> Option<Arc<dyn DebugAdapter>> {
self.0.read().adapters.get(name).cloned()
}
pub fn enumerate_adapters(&self) -> Vec<DebugAdapterName> {
self.0.read().adapters.keys().cloned().collect()
}
#[cfg(any(test, feature = "test-support"))]
pub fn fake() -> Self {
use crate::FakeAdapter;
let register = Self::default();
register.add_adapter(Arc::new(FakeAdapter::new()));
register
}
}

View File

@@ -21,8 +21,8 @@ use std::{
sync::Arc,
time::Duration,
};
use task::TcpArgumentsTemplate;
use util::{ResultExt as _, TryFutureExt};
use task::TCPHost;
use util::ResultExt as _;
use crate::{adapters::DebugAdapterBinary, debugger_settings::DebuggerSettings};
@@ -74,14 +74,16 @@ pub enum Transport {
}
impl Transport {
async fn start(binary: &DebugAdapterBinary, cx: AsyncApp) -> Result<(TransportPipe, Self)> {
#[cfg(any(test, feature = "test-support"))]
async fn start(_: &DebugAdapterBinary, cx: AsyncApp) -> Result<(TransportPipe, Self)> {
#[cfg(any(test, feature = "test-support"))]
if cfg!(any(test, feature = "test-support")) {
return FakeTransport::start(cx)
.await
.map(|(transports, fake)| (transports, Self::Fake(fake)));
}
return FakeTransport::start(cx)
.await
.map(|(transports, fake)| (transports, Self::Fake(fake)));
}
#[cfg(not(any(test, feature = "test-support")))]
async fn start(binary: &DebugAdapterBinary, cx: AsyncApp) -> Result<(TransportPipe, Self)> {
if binary.connection.is_some() {
TcpTransport::start(binary, cx)
.await
@@ -126,7 +128,6 @@ pub(crate) struct TransportDelegate {
pending_requests: Requests,
transport: Transport,
server_tx: Arc<Mutex<Option<Sender<Message>>>>,
_tasks: Vec<gpui::Task<Option<()>>>,
}
impl TransportDelegate {
@@ -141,7 +142,6 @@ impl TransportDelegate {
log_handlers: Default::default(),
current_requests: Default::default(),
pending_requests: Default::default(),
_tasks: Default::default(),
};
let messages = this.start_handlers(transport_pipes, cx).await?;
Ok((messages, this))
@@ -168,43 +168,35 @@ impl TransportDelegate {
cx.update(|cx| {
if let Some(stdout) = params.stdout.take() {
self._tasks.push(
cx.background_executor()
.spawn(Self::handle_adapter_log(stdout, log_handler.clone()).log_err()),
);
cx.background_executor()
.spawn(Self::handle_adapter_log(stdout, log_handler.clone()))
.detach_and_log_err(cx);
}
self._tasks.push(
cx.background_executor().spawn(
Self::handle_output(
params.output,
client_tx,
self.pending_requests.clone(),
log_handler.clone(),
)
.log_err(),
),
);
cx.background_executor()
.spawn(Self::handle_output(
params.output,
client_tx,
self.pending_requests.clone(),
log_handler.clone(),
))
.detach_and_log_err(cx);
if let Some(stderr) = params.stderr.take() {
self._tasks.push(
cx.background_executor()
.spawn(Self::handle_error(stderr, self.log_handlers.clone()).log_err()),
);
cx.background_executor()
.spawn(Self::handle_error(stderr, self.log_handlers.clone()))
.detach_and_log_err(cx);
}
self._tasks.push(
cx.background_executor().spawn(
Self::handle_input(
params.input,
client_rx,
self.current_requests.clone(),
self.pending_requests.clone(),
log_handler.clone(),
)
.log_err(),
),
);
cx.background_executor()
.spawn(Self::handle_input(
params.input,
client_rx,
self.current_requests.clone(),
self.pending_requests.clone(),
log_handler.clone(),
))
.detach_and_log_err(cx);
})?;
{
@@ -377,7 +369,6 @@ impl TransportDelegate {
where
Stderr: AsyncRead + Unpin + Send + 'static,
{
log::debug!("Handle error started");
let mut buffer = String::new();
let mut reader = BufReader::new(stderr);
@@ -529,21 +520,18 @@ pub struct TcpTransport {
impl TcpTransport {
/// Get an open port to use with the tcp client when not supplied by debug config
pub async fn port(host: &TcpArgumentsTemplate) -> Result<u16> {
pub async fn port(host: &TCPHost) -> Result<u16> {
if let Some(port) = host.port {
Ok(port)
} else {
Self::unused_port(host.host()).await
Ok(TcpListener::bind(SocketAddrV4::new(host.host(), 0))
.await?
.local_addr()?
.port())
}
}
pub async fn unused_port(host: Ipv4Addr) -> Result<u16> {
Ok(TcpListener::bind(SocketAddrV4::new(host, 0))
.await?
.local_addr()?
.port())
}
#[allow(dead_code, reason = "This is used in non test builds of Zed")]
async fn start(binary: &DebugAdapterBinary, cx: AsyncApp) -> Result<(TransportPipe, Self)> {
let Some(connection_args) = binary.connection.as_ref() else {
return Err(anyhow!("No connection arguments provided"));
@@ -558,8 +546,13 @@ impl TcpTransport {
command.current_dir(cwd);
}
command.args(&binary.arguments);
command.envs(&binary.envs);
if let Some(args) = &binary.arguments {
command.args(args);
}
if let Some(envs) = &binary.envs {
command.envs(envs);
}
command
.stdin(Stdio::null())
@@ -642,8 +635,13 @@ impl StdioTransport {
command.current_dir(cwd);
}
command.args(&binary.arguments);
command.envs(&binary.envs);
if let Some(args) = &binary.arguments {
command.args(args);
}
if let Some(envs) = &binary.envs {
command.envs(envs);
}
command
.stdin(Stdio::piped())

View File

@@ -1,10 +1,10 @@
use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
use std::{path::PathBuf, sync::OnceLock};
use anyhow::{Result, bail};
use async_trait::async_trait;
use dap::adapters::latest_github_release;
use gpui::AsyncApp;
use task::{DebugRequest, DebugTaskDefinition};
use task::{DebugRequestType, DebugTaskDefinition};
use crate::*;
@@ -19,8 +19,8 @@ impl CodeLldbDebugAdapter {
fn request_args(&self, config: &DebugTaskDefinition) -> dap::StartDebuggingRequestArguments {
let mut configuration = json!({
"request": match config.request {
DebugRequest::Launch(_) => "launch",
DebugRequest::Attach(_) => "attach",
DebugRequestType::Launch(_) => "launch",
DebugRequestType::Attach(_) => "attach",
},
});
let map = configuration.as_object_mut().unwrap();
@@ -28,10 +28,10 @@ impl CodeLldbDebugAdapter {
map.insert("name".into(), Value::String(config.label.clone()));
let request = config.request.to_dap();
match &config.request {
DebugRequest::Attach(attach) => {
DebugRequestType::Attach(attach) => {
map.insert("pid".into(), attach.process_id.into());
}
DebugRequest::Launch(launch) => {
DebugRequestType::Launch(launch) => {
map.insert("program".into(), launch.program.clone().into());
if !launch.args.is_empty() {
@@ -140,13 +140,16 @@ impl DebugAdapter for CodeLldbDebugAdapter {
.ok_or_else(|| anyhow!("Adapter path is expected to be valid UTF-8"))?;
Ok(DebugAdapterBinary {
command,
cwd: None,
arguments: vec![
cwd: Some(adapter_dir),
arguments: Some(vec![
"--settings".into(),
json!({"sourceLanguages": ["cpp", "rust"]}).to_string(),
],
json!({"sourceLanguages": ["cpp", "rust"]})
.to_string()
.into(),
]),
request_args: self.request_args(config),
envs: HashMap::default(),
adapter_name: "test".into(),
envs: None,
connection: None,
})
}

View File

@@ -11,7 +11,7 @@ use anyhow::{Result, anyhow};
use async_trait::async_trait;
use codelldb::CodeLldbDebugAdapter;
use dap::{
DapRegistry, DebugRequest,
DapRegistry, DebugRequestType,
adapters::{
self, AdapterVersion, DapDelegate, DebugAdapter, DebugAdapterBinary, DebugAdapterName,
GithubRepo,
@@ -19,26 +19,23 @@ use dap::{
};
use gdb::GdbDebugAdapter;
use go::GoDebugAdapter;
use gpui::{App, BorrowAppContext};
use javascript::JsDebugAdapter;
use php::PhpDebugAdapter;
use python::PythonDebugAdapter;
use serde_json::{Value, json};
use task::TcpArgumentsTemplate;
use task::TCPHost;
pub fn init(cx: &mut App) {
cx.update_default_global(|registry: &mut DapRegistry, _cx| {
registry.add_adapter(Arc::from(CodeLldbDebugAdapter::default()));
registry.add_adapter(Arc::from(PythonDebugAdapter));
registry.add_adapter(Arc::from(PhpDebugAdapter));
registry.add_adapter(Arc::from(JsDebugAdapter));
registry.add_adapter(Arc::from(GoDebugAdapter));
registry.add_adapter(Arc::from(GdbDebugAdapter));
})
pub fn init(registry: Arc<DapRegistry>) {
registry.add_adapter(Arc::from(CodeLldbDebugAdapter::default()));
registry.add_adapter(Arc::from(PythonDebugAdapter));
registry.add_adapter(Arc::from(PhpDebugAdapter));
registry.add_adapter(Arc::from(JsDebugAdapter));
registry.add_adapter(Arc::from(GoDebugAdapter));
registry.add_adapter(Arc::from(GdbDebugAdapter));
}
pub(crate) async fn configure_tcp_connection(
tcp_connection: TcpArgumentsTemplate,
tcp_connection: TCPHost,
) -> Result<(Ipv4Addr, u16, Option<u64>)> {
let host = tcp_connection.host();
let timeout = tcp_connection.timeout;
@@ -56,7 +53,7 @@ trait ToDap {
fn to_dap(&self) -> dap::StartDebuggingRequestArgumentsRequest;
}
impl ToDap for DebugRequest {
impl ToDap for DebugRequestType {
fn to_dap(&self) -> dap::StartDebuggingRequestArgumentsRequest {
match self {
Self::Launch(_) => dap::StartDebuggingRequestArgumentsRequest::Launch,

View File

@@ -1,10 +1,10 @@
use std::{collections::HashMap, ffi::OsStr};
use std::ffi::OsStr;
use anyhow::{Result, bail};
use async_trait::async_trait;
use dap::StartDebuggingRequestArguments;
use gpui::AsyncApp;
use task::{DebugRequest, DebugTaskDefinition};
use task::{DebugRequestType, DebugTaskDefinition};
use crate::*;
@@ -17,18 +17,18 @@ impl GdbDebugAdapter {
fn request_args(&self, config: &DebugTaskDefinition) -> StartDebuggingRequestArguments {
let mut args = json!({
"request": match config.request {
DebugRequest::Launch(_) => "launch",
DebugRequest::Attach(_) => "attach",
DebugRequestType::Launch(_) => "launch",
DebugRequestType::Attach(_) => "attach",
},
});
let map = args.as_object_mut().unwrap();
match &config.request {
DebugRequest::Attach(attach) => {
DebugRequestType::Attach(attach) => {
map.insert("pid".into(), attach.process_id.into());
}
DebugRequest::Launch(launch) => {
DebugRequestType::Launch(launch) => {
map.insert("program".into(), launch.program.clone().into());
if !launch.args.is_empty() {
@@ -82,9 +82,10 @@ impl DebugAdapter for GdbDebugAdapter {
let gdb_path = user_setting_path.unwrap_or(gdb_path?);
Ok(DebugAdapterBinary {
adapter_name: Self::ADAPTER_NAME.into(),
command: gdb_path,
arguments: vec!["-i=dap".into()],
envs: HashMap::default(),
arguments: Some(vec!["-i=dap".into()]),
envs: None,
cwd: None,
connection: None,
request_args: self.request_args(config),

View File

@@ -1,6 +1,6 @@
use dap::StartDebuggingRequestArguments;
use gpui::AsyncApp;
use std::{collections::HashMap, ffi::OsStr, path::PathBuf};
use std::{ffi::OsStr, path::PathBuf};
use task::DebugTaskDefinition;
use crate::*;
@@ -12,12 +12,12 @@ impl GoDebugAdapter {
const ADAPTER_NAME: &'static str = "Delve";
fn request_args(&self, config: &DebugTaskDefinition) -> StartDebuggingRequestArguments {
let mut args = match &config.request {
dap::DebugRequest::Attach(attach_config) => {
dap::DebugRequestType::Attach(attach_config) => {
json!({
"processId": attach_config.process_id,
})
}
dap::DebugRequest::Launch(launch_config) => json!({
dap::DebugRequestType::Launch(launch_config) => json!({
"program": launch_config.program,
"cwd": launch_config.cwd,
"args": launch_config.args
@@ -92,14 +92,15 @@ impl DebugAdapter for GoDebugAdapter {
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
Ok(DebugAdapterBinary {
adapter_name: self.name(),
command: delve_path,
arguments: vec![
arguments: Some(vec![
"dap".into(),
"--listen".into(),
format!("{}:{}", host, port),
],
format!("{}:{}", host, port).into(),
]),
cwd: None,
envs: HashMap::default(),
envs: None,
connection: Some(adapters::TcpArguments {
host,
port,

View File

@@ -1,8 +1,8 @@
use adapters::latest_github_release;
use dap::StartDebuggingRequestArguments;
use gpui::AsyncApp;
use std::{collections::HashMap, path::PathBuf};
use task::{DebugRequest, DebugTaskDefinition};
use std::path::PathBuf;
use task::{DebugRequestType, DebugTaskDefinition};
use crate::*;
@@ -18,16 +18,16 @@ impl JsDebugAdapter {
let mut args = json!({
"type": "pwa-node",
"request": match config.request {
DebugRequest::Launch(_) => "launch",
DebugRequest::Attach(_) => "attach",
DebugRequestType::Launch(_) => "launch",
DebugRequestType::Attach(_) => "attach",
},
});
let map = args.as_object_mut().unwrap();
match &config.request {
DebugRequest::Attach(attach) => {
DebugRequestType::Attach(attach) => {
map.insert("processId".into(), attach.process_id.into());
}
DebugRequest::Launch(launch) => {
DebugRequestType::Launch(launch) => {
map.insert("program".into(), launch.program.clone().into());
if !launch.args.is_empty() {
@@ -106,22 +106,20 @@ impl DebugAdapter for JsDebugAdapter {
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
Ok(DebugAdapterBinary {
adapter_name: self.name(),
command: delegate
.node_runtime()
.binary_path()
.await?
.to_string_lossy()
.into_owned(),
arguments: vec![
adapter_path
.join(Self::ADAPTER_PATH)
.to_string_lossy()
.to_string(),
port.to_string(),
host.to_string(),
],
arguments: Some(vec![
adapter_path.join(Self::ADAPTER_PATH).into(),
port.to_string().into(),
host.to_string().into(),
]),
cwd: None,
envs: HashMap::default(),
envs: None,
connection: Some(adapters::TcpArguments {
host,
port,

View File

@@ -1,7 +1,7 @@
use adapters::latest_github_release;
use dap::adapters::TcpArguments;
use gpui::AsyncApp;
use std::{collections::HashMap, path::PathBuf};
use std::path::PathBuf;
use task::DebugTaskDefinition;
use crate::*;
@@ -19,18 +19,20 @@ impl PhpDebugAdapter {
config: &DebugTaskDefinition,
) -> Result<dap::StartDebuggingRequestArguments> {
match &config.request {
dap::DebugRequest::Attach(_) => {
dap::DebugRequestType::Attach(_) => {
anyhow::bail!("php adapter does not support attaching")
}
dap::DebugRequest::Launch(launch_config) => Ok(dap::StartDebuggingRequestArguments {
configuration: json!({
"program": launch_config.program,
"cwd": launch_config.cwd,
"args": launch_config.args,
"stopOnEntry": config.stop_on_entry.unwrap_or_default(),
}),
request: config.request.to_dap(),
}),
dap::DebugRequestType::Launch(launch_config) => {
Ok(dap::StartDebuggingRequestArguments {
configuration: json!({
"program": launch_config.program,
"cwd": launch_config.cwd,
"args": launch_config.args,
"stopOnEntry": config.stop_on_entry.unwrap_or_default(),
}),
request: config.request.to_dap(),
})
}
}
}
}
@@ -92,26 +94,24 @@ impl DebugAdapter for PhpDebugAdapter {
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
Ok(DebugAdapterBinary {
adapter_name: self.name(),
command: delegate
.node_runtime()
.binary_path()
.await?
.to_string_lossy()
.into_owned(),
arguments: vec![
adapter_path
.join(Self::ADAPTER_PATH)
.to_string_lossy()
.to_string(),
format!("--server={}", port),
],
arguments: Some(vec![
adapter_path.join(Self::ADAPTER_PATH).into(),
format!("--server={}", port).into(),
]),
connection: Some(TcpArguments {
port,
host,
timeout,
}),
cwd: None,
envs: HashMap::default(),
envs: None,
request_args: self.request_args(config)?,
})
}

View File

@@ -1,7 +1,7 @@
use crate::*;
use dap::{DebugRequest, StartDebuggingRequestArguments};
use dap::{DebugRequestType, StartDebuggingRequestArguments};
use gpui::AsyncApp;
use std::{collections::HashMap, ffi::OsStr, path::PathBuf};
use std::{ffi::OsStr, path::PathBuf};
use task::DebugTaskDefinition;
#[derive(Default)]
@@ -16,18 +16,18 @@ impl PythonDebugAdapter {
fn request_args(&self, config: &DebugTaskDefinition) -> StartDebuggingRequestArguments {
let mut args = json!({
"request": match config.request {
DebugRequest::Launch(_) => "launch",
DebugRequest::Attach(_) => "attach",
DebugRequestType::Launch(_) => "launch",
DebugRequestType::Attach(_) => "attach",
},
"subProcess": true,
"redirectOutput": true,
});
let map = args.as_object_mut().unwrap();
match &config.request {
DebugRequest::Attach(attach) => {
DebugRequestType::Attach(attach) => {
map.insert("processId".into(), attach.process_id.into());
}
DebugRequest::Launch(launch) => {
DebugRequestType::Launch(launch) => {
map.insert("program".into(), launch.program.clone().into());
map.insert("args".into(), launch.args.clone().into());
@@ -141,22 +141,20 @@ impl DebugAdapter for PythonDebugAdapter {
};
Ok(DebugAdapterBinary {
adapter_name: self.name(),
command: python_path.ok_or(anyhow!("failed to find binary path for python"))?,
arguments: vec![
debugpy_dir
.join(Self::ADAPTER_PATH)
.to_string_lossy()
.to_string(),
format!("--port={}", port),
format!("--host={}", host),
],
arguments: Some(vec![
debugpy_dir.join(Self::ADAPTER_PATH).into(),
format!("--port={}", port).into(),
format!("--host={}", host).into(),
]),
connection: Some(adapters::TcpArguments {
host,
port,
timeout,
}),
cwd: None,
envs: HashMap::default(),
envs: None,
request_args: self.request_args(config),
})
}

View File

@@ -12,9 +12,6 @@ workspace = true
path = "src/debugger_tools.rs"
doctest = false
[features]
test-support = []
[dependencies]
anyhow.workspace = true
dap.workspace = true

View File

@@ -41,7 +41,7 @@ struct DapLogView {
_subscriptions: Vec<Subscription>,
}
pub struct LogStore {
struct LogStore {
projects: HashMap<WeakEntity<Project>, ProjectState>,
debug_clients: HashMap<SessionId, DebugAdapterState>,
rpc_tx: UnboundedSender<(SessionId, IoKind, String)>,
@@ -101,7 +101,7 @@ impl DebugAdapterState {
}
impl LogStore {
pub fn new(cx: &Context<Self>) -> Self {
fn new(cx: &Context<Self>) -> Self {
let (rpc_tx, mut rpc_rx) = unbounded::<(SessionId, IoKind, String)>();
cx.spawn(async move |this, cx| {
while let Some((client_id, io_kind, message)) = rpc_rx.next().await {
@@ -566,13 +566,11 @@ impl DapLogView {
.dap_store()
.read(cx)
.sessions()
.filter_map(|session| {
let session = session.read(cx);
session.adapter_name();
let client = session.adapter_client()?;
.filter_map(|client| {
let client = client.read(cx).adapter_client()?;
Some(DapMenuItem {
client_id: client.id(),
client_name: session.adapter_name().to_string(),
client_name: client.name().0.as_ref().into(),
has_adapter_logs: client.has_adapter_logs(),
selected_entry: self.current_view.map_or(LogKind::Adapter, |(_, kind)| kind),
})
@@ -845,29 +843,3 @@ impl EventEmitter<Event> for LogStore {}
impl EventEmitter<Event> for DapLogView {}
impl EventEmitter<EditorEvent> for DapLogView {}
impl EventEmitter<SearchEvent> for DapLogView {}
#[cfg(any(test, feature = "test-support"))]
impl LogStore {
pub fn contained_session_ids(&self) -> Vec<SessionId> {
self.debug_clients.keys().cloned().collect()
}
pub fn rpc_messages_for_session_id(&self, session_id: SessionId) -> Vec<String> {
self.debug_clients
.get(&session_id)
.expect("This session should exist if a test is calling")
.rpc_messages
.messages
.clone()
.into()
}
pub fn log_messages_for_session_id(&self, session_id: SessionId) -> Vec<String> {
self.debug_clients
.get(&session_id)
.expect("This session should exist if a test is calling")
.log_messages
.clone()
.into()
}
}

View File

@@ -20,9 +20,6 @@ test-support = [
"project/test-support",
"util/test-support",
"workspace/test-support",
"env_logger",
"unindent",
"debugger_tools"
]
[dependencies]
@@ -40,7 +37,6 @@ gpui.workspace = true
language.workspace = true
log.workspace = true
menu.workspace = true
parking_lot.workspace = true
picker.workspace = true
pretty_assertions.workspace = true
project.workspace = true
@@ -57,13 +53,9 @@ ui.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
env_logger = { workspace = true, optional = true }
debugger_tools = { workspace = true, optional = true }
unindent = { workspace = true, optional = true }
[dev-dependencies]
dap = { workspace = true, features = ["test-support"] }
debugger_tools = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -1,7 +1,7 @@
use dap::DebugRequest;
use dap::DebugRequestType;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::Subscription;
use gpui::{DismissEvent, Entity, EventEmitter, Focusable, Render};
use gpui::{Subscription, WeakEntity};
use picker::{Picker, PickerDelegate};
use std::sync::Arc;
@@ -9,9 +9,7 @@ use sysinfo::System;
use ui::{Context, Tooltip, prelude::*};
use ui::{ListItem, ListItemSpacing};
use util::debug_panic;
use workspace::{ModalView, Workspace};
use crate::debugger_panel::DebugPanel;
use workspace::ModalView;
#[derive(Debug, Clone)]
pub(super) struct Candidate {
@@ -24,19 +22,19 @@ pub(crate) struct AttachModalDelegate {
selected_index: usize,
matches: Vec<StringMatch>,
placeholder_text: Arc<str>,
workspace: WeakEntity<Workspace>,
project: Entity<project::Project>,
pub(crate) debug_config: task::DebugTaskDefinition,
candidates: Arc<[Candidate]>,
}
impl AttachModalDelegate {
fn new(
workspace: Entity<Workspace>,
project: Entity<project::Project>,
debug_config: task::DebugTaskDefinition,
candidates: Arc<[Candidate]>,
) -> Self {
Self {
workspace: workspace.downgrade(),
project,
debug_config,
candidates,
selected_index: 0,
@@ -53,7 +51,7 @@ pub struct AttachModal {
impl AttachModal {
pub fn new(
workspace: Entity<Workspace>,
project: Entity<project::Project>,
debug_config: task::DebugTaskDefinition,
modal: bool,
window: &mut Window,
@@ -77,11 +75,11 @@ impl AttachModal {
.collect();
processes.sort_by_key(|k| k.name.clone());
let processes = processes.into_iter().collect();
Self::with_processes(workspace, debug_config, processes, modal, window, cx)
Self::with_processes(project, debug_config, processes, modal, window, cx)
}
pub(super) fn with_processes(
workspace: Entity<Workspace>,
project: Entity<project::Project>,
debug_config: task::DebugTaskDefinition,
processes: Arc<[Candidate]>,
modal: bool,
@@ -90,7 +88,7 @@ impl AttachModal {
) -> Self {
let picker = cx.new(|cx| {
Picker::uniform_list(
AttachModalDelegate::new(workspace, debug_config, processes),
AttachModalDelegate::new(project, debug_config, processes),
window,
cx,
)
@@ -204,7 +202,7 @@ impl PickerDelegate for AttachModalDelegate {
})
}
fn confirm(&mut self, _: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let candidate = self
.matches
.get(self.selected_index())
@@ -218,26 +216,23 @@ impl PickerDelegate for AttachModalDelegate {
};
match &mut self.debug_config.request {
DebugRequest::Attach(config) => {
DebugRequestType::Attach(config) => {
config.process_id = Some(candidate.pid);
}
DebugRequest::Launch(_) => {
DebugRequestType::Launch(_) => {
debug_panic!("Debugger attach modal used on launch debug config");
return;
}
}
let definition = self.debug_config.clone();
let panel = self
.workspace
.update(cx, |workspace, cx| workspace.panel::<DebugPanel>(cx))
.ok()
.flatten();
if let Some(panel) = panel {
panel.update(cx, |panel, cx| {
panel.start_session(definition, window, cx);
});
}
let config = self.debug_config.clone();
self.project
.update(cx, |project, cx| {
let ret = project.start_debug_session(config, cx);
ret
})
.detach_and_log_err(cx);
cx.emit(DismissEvent);
}

View File

@@ -6,7 +6,6 @@ use crate::{new_session_modal::NewSessionModal, session::DebugSession};
use anyhow::{Result, anyhow};
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use dap::StartDebuggingRequestArguments;
use dap::{
ContinuedEvent, LoadedSourceEvent, ModuleEvent, OutputEvent, StoppedEvent, ThreadEvent,
client::SessionId, debugger_settings::DebuggerSettings,
@@ -18,7 +17,6 @@ use gpui::{
actions, anchored, deferred,
};
use project::debugger::session::{Session, SessionStateEvent};
use project::{
Project,
debugger::{
@@ -32,9 +30,10 @@ use settings::Settings;
use std::any::TypeId;
use std::path::Path;
use std::sync::Arc;
use task::{DebugTaskDefinition, DebugTaskTemplate};
use task::DebugTaskDefinition;
use terminal_view::terminal_panel::TerminalPanel;
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
use util::debug_panic;
use workspace::{
Workspace,
dock::{DockPosition, Panel, PanelEvent},
@@ -64,7 +63,7 @@ pub struct DebugPanel {
active_session: Option<Entity<DebugSession>>,
/// This represents the last debug definition that was created in the new session modal
pub(crate) past_debug_definition: Option<DebugTaskDefinition>,
project: Entity<Project>,
project: WeakEntity<Project>,
workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle,
context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>,
@@ -98,10 +97,10 @@ impl DebugPanel {
window,
|panel, _, event: &tasks_ui::ShowAttachModal, window, cx| {
panel.workspace.update(cx, |workspace, cx| {
let workspace_handle = cx.entity().clone();
let project = workspace.project().clone();
workspace.toggle_modal(window, cx, |window, cx| {
crate::attach_modal::AttachModal::new(
workspace_handle,
project,
event.debug_config.clone(),
true,
window,
@@ -128,7 +127,7 @@ impl DebugPanel {
_subscriptions,
past_debug_definition: None,
focus_handle: cx.focus_handle(),
project,
project: project.downgrade(),
workspace: workspace.weak_handle(),
context_menu: None,
};
@@ -220,7 +219,7 @@ impl DebugPanel {
pub fn load(
workspace: WeakEntity<Workspace>,
cx: &mut AsyncWindowContext,
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
workspace.update_in(cx, |workspace, window, cx| {
@@ -246,226 +245,114 @@ impl DebugPanel {
});
})
.detach();
workspace.set_debugger_provider(DebuggerProvider(debug_panel.clone()));
debug_panel
})
})
}
pub fn start_session(
&mut self,
definition: DebugTaskDefinition,
window: &mut Window,
cx: &mut Context<Self>,
) {
let task_contexts = self
.workspace
.update(cx, |workspace, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})
.ok();
let dap_store = self.project.read(cx).dap_store().clone();
cx.spawn_in(window, async move |this, cx| {
let task_context = if let Some(task) = task_contexts {
task.await
.active_worktree_context
.map_or(task::TaskContext::default(), |context| context.1)
} else {
task::TaskContext::default()
};
let (session, task) = dap_store.update(cx, |dap_store, cx| {
let template = DebugTaskTemplate {
locator: None,
definition: definition.clone(),
};
let session = if let Some(debug_config) = template
.to_zed_format()
.resolve_task("debug_task", &task_context)
.and_then(|resolved_task| resolved_task.resolved_debug_adapter_config())
{
dap_store.new_session(debug_config.definition, None, cx)
} else {
dap_store.new_session(definition.clone(), None, cx)
};
(session.clone(), dap_store.boot_session(session, cx))
})?;
match task.await {
Err(e) => {
this.update(cx, |this, cx| {
this.workspace
.update(cx, |workspace, cx| {
workspace.show_error(&e, cx);
})
.ok();
})
.ok();
session
.update(cx, |session, cx| session.shutdown(cx))?
.await;
}
Ok(_) => Self::register_session(this, session, cx).await?,
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
async fn register_session(
this: WeakEntity<Self>,
session: Entity<Session>,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let adapter_name = session.update(cx, |session, _| session.adapter_name())?;
this.update_in(cx, |_, window, cx| {
cx.subscribe_in(
&session,
window,
move |_, session, event: &SessionStateEvent, window, cx| match event {
SessionStateEvent::Restart => {
let mut curr_session = session.clone();
while let Some(parent_session) = curr_session
.read_with(cx, |session, _| session.parent_session().cloned())
{
curr_session = parent_session;
}
let definition = curr_session.update(cx, |session, _| session.definition());
let task = curr_session.update(cx, |session, cx| session.shutdown(cx));
let definition = definition.clone();
cx.spawn_in(window, async move |this, cx| {
task.await;
this.update_in(cx, |this, window, cx| {
this.start_session(definition, window, cx)
})
})
.detach_and_log_err(cx);
}
_ => {}
},
)
.detach();
})
.ok();
let serialized_layout = persistence::get_serialized_pane_layout(adapter_name).await;
let workspace = this.update_in(cx, |this, window, cx| {
this.sessions.retain(|session| {
session
.read(cx)
.mode()
.as_running()
.map_or(false, |running_state| {
!running_state.read(cx).session().read(cx).is_terminated()
})
});
let session_item = DebugSession::running(
this.project.clone(),
this.workspace.clone(),
session,
cx.weak_entity(),
serialized_layout,
window,
cx,
);
if let Some(running) = session_item.read(cx).mode().as_running().cloned() {
// We might want to make this an event subscription and only notify when a new thread is selected
// This is used to filter the command menu correctly
cx.observe(&running, |_, _, cx| cx.notify()).detach();
}
this.sessions.push(session_item.clone());
this.activate_session(session_item, window, cx);
this.workspace.clone()
})?;
workspace.update_in(cx, |workspace, window, cx| {
workspace.focus_panel::<Self>(window, cx);
})?;
Ok(())
}
pub fn start_child_session(
&mut self,
request: &StartDebuggingRequestArguments,
parent_session: Entity<Session>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(worktree) = parent_session.read(cx).worktree() else {
log::error!("Attempted to start a child session from non local debug session");
return;
};
let dap_store_handle = self.project.read(cx).dap_store().clone();
let breakpoint_store = self.project.read(cx).breakpoint_store();
let definition = parent_session.read(cx).definition().clone();
let mut binary = parent_session.read(cx).binary().clone();
binary.request_args = request.clone();
cx.spawn_in(window, async move |this, cx| {
let (session, task) = dap_store_handle.update(cx, |dap_store, cx| {
let session =
dap_store.new_session(definition.clone(), Some(parent_session.clone()), cx);
let task = session.update(cx, |session, cx| {
session.boot(
binary,
worktree,
breakpoint_store,
dap_store_handle.downgrade(),
cx,
)
});
(session, task)
})?;
match task.await {
Err(e) => {
this.update(cx, |this, cx| {
this.workspace
.update(cx, |workspace, cx| {
workspace.show_error(&e, cx);
})
.ok();
})
.ok();
session
.update(cx, |session, cx| session.shutdown(cx))?
.await;
}
Ok(_) => Self::register_session(this, session, cx).await?,
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
pub fn active_session(&self) -> Option<Entity<DebugSession>> {
self.active_session.clone()
}
pub fn debug_panel_items_by_client(
&self,
client_id: &SessionId,
cx: &Context<Self>,
) -> Vec<Entity<DebugSession>> {
self.sessions
.iter()
.filter(|item| item.read(cx).session_id(cx) == *client_id)
.map(|item| item.clone())
.collect()
}
pub fn debug_panel_item_by_client(
&self,
client_id: SessionId,
cx: &mut Context<Self>,
) -> Option<Entity<DebugSession>> {
self.sessions
.iter()
.find(|item| {
let item = item.read(cx);
item.session_id(cx) == client_id
})
.cloned()
}
fn handle_dap_store_event(
&mut self,
_dap_store: &Entity<DapStore>,
dap_store: &Entity<DapStore>,
event: &dap_store::DapStoreEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
match event {
dap_store::DapStoreEvent::DebugSessionInitialized(session_id) => {
let Some(session) = dap_store.read(cx).session_by_id(session_id) else {
return log::error!(
"Couldn't get session with id: {session_id:?} from DebugClientStarted event"
);
};
let adapter_name = session.read(cx).adapter_name();
let session_id = *session_id;
cx.spawn_in(window, async move |this, cx| {
let serialized_layout =
persistence::get_serialized_pane_layout(adapter_name).await;
this.update_in(cx, |this, window, cx| {
let Some(project) = this.project.upgrade() else {
return log::error!(
"Debug Panel out lived it's weak reference to Project"
);
};
if this
.sessions
.iter()
.any(|item| item.read(cx).session_id(cx) == session_id)
{
// We already have an item for this session.
debug_panic!("We should never reuse session ids");
return;
}
this.sessions.retain(|session| {
session
.read(cx)
.mode()
.as_running()
.map_or(false, |running_state| {
!running_state.read(cx).session().read(cx).is_terminated()
})
});
let session_item = DebugSession::running(
project,
this.workspace.clone(),
session,
cx.weak_entity(),
serialized_layout,
window,
cx,
);
if let Some(running) = session_item.read(cx).mode().as_running().cloned() {
// We might want to make this an event subscription and only notify when a new thread is selected
// This is used to filter the command menu correctly
cx.observe(&running, |_, _, cx| cx.notify()).detach();
}
this.sessions.push(session_item.clone());
this.activate_session(session_item, window, cx);
})
})
.detach();
}
dap_store::DapStoreEvent::RunInTerminal {
title,
cwd,
@@ -487,12 +374,6 @@ impl DebugPanel {
)
.detach_and_log_err(cx);
}
dap_store::DapStoreEvent::SpawnChildSession {
request,
parent_session,
} => {
self.start_child_session(request, parent_session.clone(), window, cx);
}
_ => {}
}
}
@@ -527,7 +408,7 @@ impl DebugPanel {
cwd,
title,
},
task::RevealStrategy::Never,
task::RevealStrategy::Always,
window,
cx,
);
@@ -587,6 +468,8 @@ impl DebugPanel {
let session = this.dap_store().read(cx).session_by_id(session_id);
session.map(|session| !session.read(cx).is_terminated())
})
.ok()
.flatten()
.unwrap_or_default();
cx.spawn_in(window, async move |this, cx| {
@@ -1010,6 +893,7 @@ impl DebugPanel {
impl EventEmitter<PanelEvent> for DebugPanel {}
impl EventEmitter<DebugPanelEvent> for DebugPanel {}
impl EventEmitter<project::Event> for DebugPanel {}
impl Focusable for DebugPanel {
fn focus_handle(&self, _: &App) -> FocusHandle {
@@ -1155,15 +1039,3 @@ impl Render for DebugPanel {
.into_any()
}
}
struct DebuggerProvider(Entity<DebugPanel>);
impl workspace::DebuggerProvider for DebuggerProvider {
fn start_session(&self, definition: DebugTaskDefinition, window: &mut Window, cx: &mut App) {
self.0.update(cx, |_, cx| {
cx.defer_in(window, |this, window, cx| {
this.start_session(definition, window, cx);
})
})
}
}

View File

@@ -16,7 +16,7 @@ mod new_session_modal;
mod persistence;
pub(crate) mod session;
#[cfg(any(test, feature = "test-support"))]
#[cfg(test)]
pub mod tests;
actions!(

View File

@@ -4,14 +4,16 @@ use std::{
path::{Path, PathBuf},
};
use dap::{DapRegistry, DebugRequest};
use anyhow::{Result, anyhow};
use dap::DebugRequestType;
use editor::{Editor, EditorElement, EditorStyle};
use gpui::{
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, TextStyle,
WeakEntity,
};
use project::Project;
use settings::Settings;
use task::{DebugTaskDefinition, DebugTaskTemplate, LaunchRequest};
use task::{DebugTaskDefinition, LaunchConfig};
use theme::ThemeSettings;
use ui::{
ActiveTheme, Button, ButtonCommon, ButtonSize, CheckboxWithLabel, Clickable, Color, Context,
@@ -19,6 +21,7 @@ use ui::{
LabelCommon as _, ParentElement, RenderOnce, SharedString, Styled, StyledExt, ToggleButton,
ToggleState, Toggleable, Window, div, h_flex, relative, rems, v_flex,
};
use util::ResultExt;
use workspace::{ModalView, Workspace};
use crate::{attach_modal::AttachModal, debugger_panel::DebugPanel};
@@ -34,9 +37,9 @@ pub(super) struct NewSessionModal {
last_selected_profile_name: Option<SharedString>,
}
fn suggested_label(request: &DebugRequest, debugger: &str) -> String {
fn suggested_label(request: &DebugRequestType, debugger: &str) -> String {
match request {
DebugRequest::Launch(config) => {
DebugRequestType::Launch(config) => {
let last_path_component = Path::new(&config.program)
.file_name()
.map(|name| name.to_string_lossy())
@@ -44,7 +47,7 @@ fn suggested_label(request: &DebugRequest, debugger: &str) -> String {
format!("{} ({debugger})", last_path_component)
}
DebugRequest::Attach(config) => format!(
DebugRequestType::Attach(config) => format!(
"pid: {} ({debugger})",
config.process_id.unwrap_or(u32::MAX)
),
@@ -68,7 +71,7 @@ impl NewSessionModal {
.and_then(|def| def.stop_on_entry);
let launch_config = match past_debug_definition.map(|def| def.request) {
Some(DebugRequest::Launch(launch_config)) => Some(launch_config),
Some(DebugRequestType::Launch(launch_config)) => Some(launch_config),
_ => None,
};
@@ -85,38 +88,39 @@ impl NewSessionModal {
}
}
fn debug_config(&self, cx: &App, debugger: &str) -> DebugTaskDefinition {
fn debug_config(&self, cx: &App) -> Option<DebugTaskDefinition> {
let request = self.mode.debug_task(cx);
DebugTaskDefinition {
adapter: debugger.to_owned(),
label: suggested_label(&request, debugger),
Some(DebugTaskDefinition {
adapter: self.debugger.clone()?.to_string(),
label: suggested_label(&request, self.debugger.as_deref()?),
request,
initialize_args: self.initialize_args.clone(),
tcp_connection: None,
locator: None,
stop_on_entry: match self.stop_on_entry {
ToggleState::Selected => Some(true),
_ => None,
},
}
})
}
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) {
let Some(debugger) = self.debugger.as_ref() else {
// todo: show in UI.
log::error!("No debugger selected");
return;
};
let config = self.debug_config(cx, debugger);
let debug_panel = self.debug_panel.clone();
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) -> Result<()> {
let workspace = self.workspace.clone();
let config = self
.debug_config(cx)
.ok_or_else(|| anyhow!("Failed to create a debug config"))?;
let task_contexts = self
.workspace
let _ = self.debug_panel.update(cx, |panel, _| {
panel.past_debug_definition = Some(config.clone());
});
let task_contexts = workspace
.update(cx, |workspace, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})
.ok();
cx.spawn_in(window, async move |this, cx| {
cx.spawn(async move |this, cx| {
let task_context = if let Some(task) = task_contexts {
task.await
.active_worktree_context
@@ -124,29 +128,39 @@ impl NewSessionModal {
} else {
task::TaskContext::default()
};
let project = workspace.update(cx, |workspace, _| workspace.project().clone())?;
debug_panel.update_in(cx, |debug_panel, window, cx| {
let template = DebugTaskTemplate {
locator: None,
definition: config.clone(),
};
if let Some(debug_config) = template
.to_zed_format()
.resolve_task("debug_task", &task_context)
.and_then(|resolved_task| resolved_task.resolved_debug_adapter_config())
let task = project.update(cx, |this, cx| {
if let Some(debug_config) =
config
.clone()
.to_zed_format()
.ok()
.and_then(|task_template| {
task_template
.resolve_task("debug_task", &task_context)
.and_then(|resolved_task| {
resolved_task.resolved_debug_adapter_config()
})
})
{
debug_panel.start_session(debug_config.definition, window, cx)
this.start_debug_session(debug_config, cx)
} else {
debug_panel.start_session(config, window, cx)
this.start_debug_session(config, cx)
}
})?;
this.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
let spawn_result = task.await;
if spawn_result.is_ok() {
this.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
}
spawn_result?;
anyhow::Result::<_, anyhow::Error>::Ok(())
})
.detach_and_log_err(cx);
Ok(())
}
fn update_attach_picker(
@@ -200,7 +214,12 @@ impl NewSessionModal {
};
let available_adapters = workspace
.update(cx, |_, cx| DapRegistry::global(cx).enumerate_adapters())
.update(cx, |this, cx| {
this.project()
.read(cx)
.debug_adapters()
.enumerate_adapters()
})
.ok()
.unwrap_or_default();
@@ -232,20 +251,23 @@ impl NewSessionModal {
this.debugger = Some(task.adapter.clone().into());
this.initialize_args = task.initialize_args.clone();
match &task.request {
DebugRequest::Launch(launch_config) => {
DebugRequestType::Launch(launch_config) => {
this.mode = NewSessionMode::launch(
Some(launch_config.clone()),
window,
cx,
);
}
DebugRequest::Attach(_) => {
let Some(workspace) = this.workspace.upgrade() else {
DebugRequestType::Attach(_) => {
let Ok(project) = this
.workspace
.read_with(cx, |this, _| this.project().clone())
else {
return;
};
this.mode = NewSessionMode::attach(
this.debugger.clone(),
workspace,
project,
window,
cx,
);
@@ -263,7 +285,7 @@ impl NewSessionModal {
}
};
let available_adapters: Vec<DebugTaskTemplate> = workspace
let available_adapters: Vec<DebugTaskDefinition> = workspace
.update(cx, |this, cx| {
this.project()
.read(cx)
@@ -281,9 +303,9 @@ impl NewSessionModal {
for debug_definition in available_adapters {
menu = menu.entry(
debug_definition.definition.label.clone(),
debug_definition.label.clone(),
None,
setter_for_name(debug_definition.definition),
setter_for_name(debug_definition),
);
}
menu
@@ -300,7 +322,7 @@ struct LaunchMode {
impl LaunchMode {
fn new(
past_launch_config: Option<LaunchRequest>,
past_launch_config: Option<LaunchConfig>,
window: &mut Window,
cx: &mut App,
) -> Entity<Self> {
@@ -326,9 +348,9 @@ impl LaunchMode {
cx.new(|_| Self { program, cwd })
}
fn debug_task(&self, cx: &App) -> task::LaunchRequest {
fn debug_task(&self, cx: &App) -> task::LaunchConfig {
let path = self.cwd.read(cx).text(cx);
task::LaunchRequest {
task::LaunchConfig {
program: self.program.read(cx).text(cx),
cwd: path.is_empty().not().then(|| PathBuf::from(path)),
args: Default::default(),
@@ -345,20 +367,21 @@ struct AttachMode {
impl AttachMode {
fn new(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
project: Entity<Project>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Entity<Self> {
let debug_definition = DebugTaskDefinition {
label: "Attach New Session Setup".into(),
request: dap::DebugRequest::Attach(task::AttachRequest { process_id: None }),
request: dap::DebugRequestType::Attach(task::AttachConfig { process_id: None }),
tcp_connection: None,
adapter: debugger.clone().unwrap_or_default().into(),
locator: None,
initialize_args: None,
stop_on_entry: Some(false),
};
let attach_picker = cx.new(|cx| {
let modal = AttachModal::new(workspace, debug_definition.clone(), false, window, cx);
let modal = AttachModal::new(project, debug_definition.clone(), false, window, cx);
window.focus(&modal.focus_handle(cx));
modal
@@ -368,8 +391,8 @@ impl AttachMode {
attach_picker,
})
}
fn debug_task(&self) -> task::AttachRequest {
task::AttachRequest { process_id: None }
fn debug_task(&self) -> task::AttachConfig {
task::AttachConfig { process_id: None }
}
}
@@ -383,7 +406,7 @@ enum NewSessionMode {
}
impl NewSessionMode {
fn debug_task(&self, cx: &App) -> DebugRequest {
fn debug_task(&self, cx: &App) -> DebugRequestType {
match self {
NewSessionMode::Launch(entity) => entity.read(cx).debug_task(cx).into(),
NewSessionMode::Attach(entity) => entity.read(cx).debug_task().into(),
@@ -458,14 +481,14 @@ impl RenderOnce for NewSessionMode {
impl NewSessionMode {
fn attach(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
project: Entity<Project>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
Self::Attach(AttachMode::new(debugger, project, window, cx))
}
fn launch(
past_launch_config: Option<LaunchRequest>,
past_launch_config: Option<LaunchConfig>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
@@ -557,12 +580,15 @@ impl Render for NewSessionModal {
.toggle_state(matches!(self.mode, NewSessionMode::Attach(_)))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
let Some(workspace) = this.workspace.upgrade() else {
let Ok(project) = this
.workspace
.read_with(cx, |this, _| this.project().clone())
else {
return;
};
this.mode = NewSessionMode::attach(
this.debugger.clone(),
workspace,
project,
window,
cx,
);
@@ -616,7 +642,7 @@ impl Render for NewSessionModal {
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| {
this.start_new_session(window, cx);
this.start_new_session(window, cx).log_err();
}))
.disabled(self.debugger.is_none()),
),

View File

@@ -88,12 +88,6 @@ impl DebugSession {
}
}
pub fn session(&self, cx: &App) -> Entity<Session> {
match &self.mode {
DebugSessionState::Running(entity) => entity.read(cx).session().clone(),
}
}
pub(crate) fn shutdown(&mut self, cx: &mut Context<Self>) {
match &self.mode {
DebugSessionState::Running(state) => state.update(cx, |state, cx| state.shutdown(cx)),
@@ -121,7 +115,13 @@ impl DebugSession {
};
self.label
.get_or_init(|| session.read(cx).label())
.get_or_init(|| {
session
.read(cx)
.as_local()
.expect("Remote Debug Sessions are not implemented yet")
.label()
})
.to_owned()
}

Some files were not shown because too many files have changed in this diff Show More