Compare commits

..

2 Commits

Author SHA1 Message Date
Thomas Mickley-Doyle
fcb71cda2a Update repo for eval data 2025-04-22 16:08:58 -05:00
Thomas Mickley-Doyle
5145a8dfca Add eval example for cli todo app 2025-04-22 11:58:26 -05:00
271 changed files with 4700 additions and 13137 deletions

View File

@@ -3,6 +3,12 @@ name: Run Agent Eval
on:
schedule:
- cron: "0 * * * *"
push:
branches:
- main
- "v[0-9]+.[0-9]+.x"
tags:
- "v*"
pull_request:
branches:

428
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -296,7 +296,6 @@ livekit_api = { path = "crates/livekit_api" }
livekit_client = { path = "crates/livekit_client" }
lmstudio = { path = "crates/lmstudio" }
lsp = { path = "crates/lsp" }
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" }
markdown = { path = "crates/markdown" }
markdown_preview = { path = "crates/markdown_preview" }
media = { path = "crates/media" }
@@ -606,7 +605,7 @@ wasmtime-wasi = "29"
which = "6.0.0"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "0.7.1"
zed_llm_client = "0.6.1"
zstd = "0.11"
metal = "0.29"

View File

@@ -31,9 +31,6 @@ If appropriate, use tool calls to explore the current project, which contains th
- When looking for symbols in the project, prefer the `grep` tool.
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
- Bias towards not asking the user for help if you can find the answer yourself.
{{! TODO: Only mention tools if they are enabled }}
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
## Fixing Diagnostics

View File

@@ -646,7 +646,7 @@
"fetch": true,
"list_directory": false,
"now": true,
"find_path": true,
"path_search": true,
"read_file": true,
"grep": true,
"thinking": true,
@@ -670,7 +670,7 @@
"list_directory": true,
"move_path": false,
"now": false,
"find_path": true,
"path_search": true,
"read_file": true,
"grep": true,
"rename": false,

View File

@@ -6,9 +6,7 @@ use crate::thread::{
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
use crate::ui::{
AddedContext, AgentNotification, AgentNotificationEvent, AnimatedLabel, ContextPill,
};
use crate::ui::{AddedContext, AgentNotification, AgentNotificationEvent, ContextPill};
use crate::{AssistantPanel, OpenActiveThreadAsMarkdown};
use anyhow::Context as _;
use assistant_settings::{AssistantSettings, NotifyWhenAgentWaiting};
@@ -433,39 +431,47 @@ fn render_markdown_code_block(
workspace
.update(cx, {
|workspace, cx| {
let Some(project_path) = workspace
if let Some(project_path) = workspace
.project()
.read(cx)
.find_project_path(&path_range.path, cx)
else {
return;
};
let Some(target) = path_range.range.as_ref().map(|range| {
Point::new(
// Line number is 1-based
range.start.line.saturating_sub(1),
range.start.col.unwrap_or(0),
)
}) else {
return;
};
let open_task =
workspace.open_path(project_path, None, true, window, cx);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) = item.downcast::<Editor>() {
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(
target, window, cx,
);
})
.ok();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
{
let target = path_range.range.as_ref().map(|range| {
Point::new(
// Line number is 1-based
range.start.line.saturating_sub(1),
range.start.col.unwrap_or(0),
)
});
let open_task = workspace.open_path(
project_path,
None,
true,
window,
cx,
);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(target) = target {
if let Some(active_editor) =
item.downcast::<Editor>()
{
active_editor
.downgrade()
.update_in(cx, |editor, window, cx| {
editor
.go_to_singleton_buffer_point(
target, window, cx,
);
})
.log_err();
}
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
}
})
.ok();
@@ -799,11 +805,10 @@ impl ActiveThread {
self.thread.read(cx).summary_or_default()
}
pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool {
pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
self.last_error.take();
self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
})
self.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx))
}
pub fn last_error(&self) -> Option<ThreadError> {
@@ -1307,7 +1312,7 @@ impl ActiveThread {
fn confirm_editing_message(
&mut self,
_: &menu::Confirm,
window: &mut Window,
_: &mut Window,
cx: &mut Context<Self>,
) {
let Some((message_id, state)) = self.editing_message.take() else {
@@ -1337,7 +1342,7 @@ impl ActiveThread {
self.thread.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model.model, Some(window.window_handle()), cx);
thread.send_to_model(model.model, cx)
});
cx.notify();
}
@@ -1499,8 +1504,45 @@ impl ActiveThread {
let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);
let generating_label = (is_generating && is_last_message)
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
let generating_label = (is_generating && is_last_message).then(|| {
Label::new("Generating")
.color(Color::Muted)
.size(LabelSize::Small)
.with_animations(
"generating-label",
vec![
Animation::new(Duration::from_secs(1)),
Animation::new(Duration::from_secs(1)).repeat(),
],
|mut label, animation_ix, delta| {
match animation_ix {
0 => {
let chars_to_show = (delta * 10.).ceil() as usize;
let text = &"Generating"[0..chars_to_show];
label.set_text(text);
}
1 => {
let text = match delta {
d if d < 0.25 => "Generating",
d if d < 0.5 => "Generating.",
d if d < 0.75 => "Generating..",
_ => "Generating...",
};
label.set_text(text);
}
_ => {}
}
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.map_element(|label| label.alpha(delta)),
)
});
// Don't render user messages that are just there for returning tool results.
if message.role == Role::User && thread.message_has_tool_results(message_id) {
@@ -1538,14 +1580,13 @@ impl ActiveThread {
window.dispatch_action(Box::new(OpenActiveThreadAsMarkdown), cx)
});
// For all items that should be aligned with the LLM's response.
// For all items that should be aligned with the Assistant's response.
const RESPONSE_PADDING_X: Pixels = px(18.);
let feedback_container = h_flex()
.py_2()
.px(RESPONSE_PADDING_X)
.gap_1()
.flex_wrap()
.justify_between();
let feedback_items = match self.thread.read(cx).message_feedback(message_id) {
Some(feedback) => feedback_container
@@ -1557,8 +1598,7 @@ impl ActiveThread {
}
})
.color(Color::Muted)
.size(LabelSize::XSmall)
.truncate(),
.size(LabelSize::XSmall),
)
.child(
h_flex()
@@ -1609,8 +1649,7 @@ impl ActiveThread {
"Rating the thread sends all of your current conversation to the Zed team.",
)
.color(Color::Muted)
.size(LabelSize::XSmall)
.truncate(),
.size(LabelSize::XSmall),
)
.child(
h_flex()
@@ -1657,7 +1696,7 @@ impl ActiveThread {
let message_content =
has_content.then(|| {
v_flex()
.gap_1()
.gap_1p5()
.when(!message_is_empty, |parent| {
parent.child(
if let Some(edit_message_editor) = edit_message_editor.clone() {
@@ -1680,6 +1719,7 @@ impl ActiveThread {
.on_action(cx.listener(Self::cancel_editing_message))
.on_action(cx.listener(Self::confirm_editing_message))
.min_h_6()
.pt_1()
.child(EditorElement::new(
&edit_message_editor,
EditorStyle {
@@ -1757,7 +1797,7 @@ impl ActiveThread {
.border_1()
.border_color(colors.border)
.shadow_md()
.child(div().p_2().children(message_content))
.child(div().py_2().px_2p5().children(message_content))
.child(
h_flex()
.p_1()
@@ -1846,9 +1886,11 @@ impl ActiveThread {
.gap_2()
.children(message_content)
.when(has_tool_uses, |parent| {
parent.children(tool_uses.into_iter().map(|tool_use| {
self.render_tool_use(tool_use, window, workspace.clone(), cx)
}))
parent.children(
tool_uses
.into_iter()
.map(|tool_use| self.render_tool_use(tool_use, window, cx)),
)
}),
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
@@ -2262,7 +2304,34 @@ impl ActiveThread {
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(AnimatedLabel::new("Thinking").size(LabelSize::Small)),
.child({
Label::new("Thinking")
.color(Color::Muted)
.size(LabelSize::Small)
.with_animation(
"generating-label",
Animation::new(Duration::from_secs(1)).repeat(),
|mut label, delta| {
let text = match delta {
d if d < 0.25 => "Thinking",
d if d < 0.5 => "Thinking.",
d if d < 0.75 => "Thinking..",
_ => "Thinking...",
};
label.set_text(text);
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| {
label.map_element(|label| label.alpha(delta))
},
)
}),
)
.child(
h_flex()
@@ -2441,11 +2510,10 @@ impl ActiveThread {
&self,
tool_use: ToolUse,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement + use<> {
if let Some(card) = self.thread.read(cx).card_for_tool(&tool_use.id) {
return card.render(&tool_use.status, window, workspace, cx);
return card.render(&tool_use.status, window, cx);
}
let is_open = self
@@ -2461,7 +2529,7 @@ impl ActiveThread {
.upgrade()
.map(|workspace| workspace.read(cx).app_state().fs.clone());
let needs_confirmation = matches!(&tool_use.status, ToolUseStatus::NeedsConfirmation);
let needs_confirmation_tools = tool_use.needs_confirmation;
let edit_tools = tool_use.needs_confirmation;
let status_icons = div().child(match &tool_use.status {
ToolUseStatus::NeedsConfirmation => {
@@ -2559,33 +2627,33 @@ impl ActiveThread {
)),
),
ToolUseStatus::InputStillStreaming | ToolUseStatus::Running => container.child(
results_content_container()
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
h_flex()
.gap_1()
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Accent)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(
delta,
)))
},
),
)
.child(
Label::new("Running…")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
),
),
results_content_container().child(
h_flex()
.gap_1()
.pb_1()
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
Icon::new(IconName::ArrowCircle)
.size(IconSize::Small)
.color(Color::Accent)
.with_animation(
"arrow-circle",
Animation::new(Duration::from_secs(2)).repeat(),
|icon, delta| {
icon.transform(Transformation::rotate(percentage(
delta,
)))
},
),
)
.child(
Label::new("Running…")
.size(LabelSize::XSmall)
.color(Color::Muted)
.buffer_font(cx),
),
),
),
ToolUseStatus::Error(_) => container.child(
results_content_container()
@@ -2649,8 +2717,8 @@ impl ActiveThread {
))
};
v_flex().gap_1().mb_2().map(|element| {
if !needs_confirmation_tools {
v_flex().gap_1().mb_3().map(|element| {
if !edit_tools {
element.child(
v_flex()
.child(
@@ -2828,7 +2896,30 @@ impl ActiveThread {
.border_color(self.tool_card_border_color(cx))
.rounded_b_lg()
.child(
AnimatedLabel::new("Waiting for Confirmation").size(LabelSize::Small)
Label::new("Waiting for Confirmation")
.color(Color::Muted)
.size(LabelSize::Small)
.with_animation(
"generating-label",
Animation::new(Duration::from_secs(1)).repeat(),
|mut label, delta| {
let text = match delta {
d if d < 0.25 => "Waiting for Confirmation",
d if d < 0.5 => "Waiting for Confirmation.",
d if d < 0.75 => "Waiting for Confirmation..",
_ => "Waiting for Confirmation...",
};
label.set_text(text);
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.map_element(|label| label.alpha(delta)),
),
)
.child(
h_flex()
@@ -3042,7 +3133,7 @@ impl ActiveThread {
&mut self,
tool_use_id: LanguageModelToolUseId,
_: &ClickEvent,
window: &mut Window,
_window: &mut Window,
cx: &mut Context<Self>,
) {
if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self
@@ -3058,7 +3149,6 @@ impl ActiveThread {
c.input.clone(),
&c.messages,
c.tool.clone(),
Some(window.window_handle()),
cx,
);
});
@@ -3070,12 +3160,11 @@ impl ActiveThread {
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
_: &ClickEvent,
window: &mut Window,
_window: &mut Window,
cx: &mut Context<Self>,
) {
let window_handle = window.window_handle();
self.thread.update(cx, |thread, cx| {
thread.deny_tool_use(tool_use_id, tool_name, Some(window_handle), cx);
thread.deny_tool_use(tool_use_id, tool_name, cx);
});
}

View File

@@ -1,4 +1,4 @@
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent, ui::AnimatedLabel};
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
use anyhow::Result;
use buffer_diff::DiffHunkStatus;
use collections::{HashMap, HashSet};
@@ -8,8 +8,8 @@ use editor::{
scroll::Autoscroll,
};
use gpui::{
Action, AnyElement, AnyView, App, Empty, Entity, EventEmitter, FocusHandle, Focusable,
SharedString, Subscription, Task, WeakEntity, Window, prelude::*,
Action, AnyElement, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, SharedString,
Subscription, Task, WeakEntity, Window, prelude::*,
};
use language::{Capability, DiskState, OffsetRangeExt, Point};
use multi_buffer::PathKey;
@@ -307,10 +307,6 @@ impl AgentDiff {
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.thread.read(cx).is_generating() {
return;
}
let snapshot = self.multibuffer.read(cx).snapshot(cx);
let diff_hunks_in_ranges = self
.editor
@@ -343,10 +339,6 @@ impl AgentDiff {
window: &mut Window,
cx: &mut Context<Self>,
) {
if self.thread.read(cx).is_generating() {
return;
}
let snapshot = self.multibuffer.read(cx).snapshot(cx);
let diff_hunks_in_ranges = self
.editor
@@ -658,11 +650,6 @@ fn render_diff_hunk_controls(
cx: &mut App,
) -> AnyElement {
let editor = editor.clone();
if agent_diff.read(cx).thread.read(cx).is_generating() {
return Empty.into_any();
}
h_flex()
.h(line_height)
.mr_0p5()
@@ -870,14 +857,8 @@ impl Render for AgentDiffToolbar {
None => return div(),
};
let is_generating = agent_diff.read(cx).thread.read(cx).is_generating();
if is_generating {
return div()
.w(rems(6.5625)) // Arbitrary 105px size—so the label doesn't dance around
.child(AnimatedLabel::new("Generating"));
}
let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
if is_empty {
return div();
}
@@ -988,7 +969,7 @@ mod tests {
.await
.unwrap();
cx.update(|_, cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit(

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use assistant_settings::{
AgentProfile, AgentProfileContent, AgentProfileId, AssistantSettings, AssistantSettingsContent,
ContextServerPresetContent,
ContextServerPresetContent, VersionedAssistantSettingsContent,
};
use assistant_tool::{ToolSource, ToolWorkingSet};
use fs::Fs;
@@ -201,10 +201,10 @@ impl PickerDelegate for ToolPickerDelegate {
let profile_id = self.profile_id.clone();
let default_profile = self.profile.clone();
let tool = tool.clone();
move |settings: &mut AssistantSettingsContent, _cx| {
settings
.v2_setting(|v2_settings| {
let profiles = v2_settings.profiles.get_or_insert_default();
move |settings, _cx| match settings {
AssistantSettingsContent::Versioned(boxed) => {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
let profiles = settings.profiles.get_or_insert_default();
let profile =
profiles
.entry(profile_id)
@@ -240,10 +240,9 @@ impl PickerDelegate for ToolPickerDelegate {
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
}
}
Ok(())
})
.ok();
}
}
_ => {}
}
});
}

View File

@@ -12,13 +12,13 @@ use assistant_settings::{AssistantDockPosition, AssistantSettings};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet;
use client::{UserStore, zed_urls};
use client::zed_urls;
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
Corner, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels,
Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, Corner, Entity,
EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels, Subscription, Task,
UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
};
use language::LanguageRegistry;
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
@@ -180,7 +180,6 @@ impl ActiveView {
pub struct AssistantPanel {
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
project: Entity<Project>,
fs: Arc<dyn Fs>,
language_registry: Arc<LanguageRegistry>,
@@ -244,7 +243,6 @@ impl AssistantPanel {
) -> Self {
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let fs = workspace.app_state().fs.clone();
let user_store = workspace.app_state().user_store.clone();
let project = workspace.project();
let language_registry = project.read(cx).languages().clone();
let workspace = workspace.weak_handle();
@@ -309,7 +307,6 @@ impl AssistantPanel {
Self {
active_view,
workspace,
user_store,
project: project.clone(),
fs: fs.clone(),
language_registry,
@@ -359,9 +356,14 @@ impl AssistantPanel {
&self.thread_store
}
fn cancel(&mut self, _: &editor::actions::Cancel, window: &mut Window, cx: &mut Context<Self>) {
fn cancel(
&mut self,
_: &editor::actions::Cancel,
_window: &mut Window,
cx: &mut Context<Self>,
) {
self.thread
.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
}
fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
@@ -1546,19 +1548,9 @@ impl AssistantPanel {
}
fn render_usage_banner(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
let plan = self
.user_store
.read(cx)
.current_plan()
.map(|plan| match plan {
Plan::Free => zed_llm_client::Plan::Free,
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
})
.unwrap_or(zed_llm_client::Plan::Free);
let usage = self.thread.read(cx).last_usage()?;
Some(UsageBanner::new(plan, usage).into_any_element())
Some(UsageBanner::new(zed_llm_client::Plan::ZedProTrial, usage.amount).into_any_element())
}
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
@@ -1613,8 +1605,6 @@ impl AssistantPanel {
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(ERROR_MESSAGE))
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
@@ -1661,8 +1651,6 @@ impl AssistantPanel {
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(ERROR_MESSAGE))
.child(
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
cx.listener(|this, _, _, cx| {
@@ -1728,8 +1716,6 @@ impl AssistantPanel {
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(error_message))
.child(
Button::new("subscribe", call_to_action).on_click(cx.listener(
|this, _, _, cx| {
@@ -1761,7 +1747,6 @@ impl AssistantPanel {
message: SharedString,
cx: &mut Context<Self>,
) -> AnyElement {
let message_with_header = format!("{}\n{}", header, message);
v_flex()
.gap_0p5()
.child(
@@ -1776,14 +1761,12 @@ impl AssistantPanel {
.id("error-message")
.max_h_32()
.overflow_y_scroll()
.child(Label::new(message.clone())),
.child(Label::new(message)),
)
.child(
h_flex()
.justify_end()
.mt_1()
.gap_1()
.child(self.create_copy_button(message_with_header))
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|this, _, _, cx| {
this.thread.update(cx, |this, _cx| {
@@ -1797,15 +1780,6 @@ impl AssistantPanel {
.into_any()
}
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
let message = message.into();
IconButton::new("copy", IconName::Copy)
.on_click(move |_, _, cx| {
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
})
.tooltip(Tooltip::text("Copy Error Message"))
}
fn key_context(&self) -> KeyContext {
let mut key_context = KeyContext::new_with_defaults();
key_context.add("AgentPanel");

View File

@@ -1328,7 +1328,7 @@ impl InlineAssistant {
editor.highlight_rows::<InlineAssist>(
row_range,
cx.theme().status().info_background,
Default::default(),
false,
cx,
);
}
@@ -1393,7 +1393,7 @@ impl InlineAssistant {
editor.highlight_rows::<DeletedLines>(
Anchor::min()..Anchor::max(),
cx.theme().status().deleted_background,
Default::default(),
false,
cx,
);
editor

View File

@@ -195,7 +195,6 @@ impl MessageEditor {
editor.set_mode(EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: false,
})
} else {
editor.set_mode(EditorMode::AutoHeight {
@@ -278,7 +277,6 @@ impl MessageEditor {
let context_store = self.context_store.clone();
let git_store = self.project.read(cx).git_store().clone();
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
let window_handle = window.window_handle();
cx.spawn(async move |this, cx| {
let checkpoint = checkpoint.await.ok();
@@ -335,7 +333,7 @@ impl MessageEditor {
thread
.update(cx, |thread, cx| {
thread.advance_prompt_id();
thread.send_to_model(model, Some(window_handle), cx);
thread.send_to_model(model, cx);
})
.log_err();
})
@@ -343,9 +341,9 @@ impl MessageEditor {
}
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let cancelled = self.thread.update(cx, |thread, cx| {
thread.cancel_last_completion(Some(window.window_handle()), cx)
});
let cancelled = self
.thread
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
if cancelled {
self.set_editor_is_expanded(false, cx);
@@ -407,10 +405,8 @@ impl MessageEditor {
});
}
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.edits_expanded = true;
fn handle_review_click(&self, window: &mut Window, cx: &mut Context<Self>) {
AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
cx.notify();
}
fn handle_file_click(

View File

@@ -13,9 +13,7 @@ use feature_flags::{self, FeatureFlagAppExt};
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
};
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
@@ -317,7 +315,6 @@ pub struct Thread {
request_callback: Option<
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
remaining_turns: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -371,7 +368,6 @@ impl Thread {
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
}
}
@@ -446,7 +442,6 @@ impl Thread {
message_feedback: HashMap::default(),
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
}
}
@@ -527,7 +522,7 @@ impl Thread {
self.messages.iter().find(|message| message.id == id)
}
pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
pub fn messages(&self) -> impl Iterator<Item = &Message> {
self.messages.iter()
}
@@ -621,12 +616,24 @@ impl Thread {
.await
.unwrap_or(false);
if !equal {
if equal {
git_store
.update(cx, |store, cx| {
store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
})?
.detach();
} else {
this.update(cx, |this, cx| {
this.insert_checkpoint(pending_checkpoint, cx)
})?;
}
git_store
.update(cx, |store, cx| {
store.delete_checkpoint(final_checkpoint, cx)
})?
.detach();
Ok(())
}
Err(_) => this.update(cx, |this, cx| {
@@ -760,18 +767,24 @@ impl Thread {
for ctx in &new_context {
match ctx {
AssistantContext::File(file_ctx) => {
log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
}
AssistantContext::Directory(dir_ctx) => {
for context_buffer in &dir_ctx.context_buffers {
log.track_buffer(context_buffer.buffer.clone(), cx);
log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
}
}
AssistantContext::Symbol(symbol_ctx) => {
log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
log.buffer_added_as_context(
symbol_ctx.context_symbol.buffer.clone(),
cx,
);
}
AssistantContext::Selection(selection_context) => {
log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
log.buffer_added_as_context(
selection_context.context_buffer.buffer.clone(),
cx,
);
}
AssistantContext::FetchedUrl(_)
| AssistantContext::Thread(_)
@@ -945,26 +958,7 @@ impl Thread {
})
}
pub fn remaining_turns(&self) -> u32 {
self.remaining_turns
}
pub fn set_remaining_turns(&mut self, remaining_turns: u32) {
self.remaining_turns = remaining_turns;
}
pub fn send_to_model(
&mut self,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
if self.remaining_turns == 0 {
return;
}
self.remaining_turns -= 1;
pub fn send_to_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
let mut request = self.to_completion_request(cx);
if model.supports_tools() {
request.tools = {
@@ -989,7 +983,7 @@ impl Thread {
};
}
self.stream_completion(request, model, window, cx);
self.stream_completion(request, model, cx);
}
pub fn used_tools_since_last_user_message(&self) -> bool {
@@ -1103,10 +1097,7 @@ impl Thread {
self.tool_use
.attach_tool_uses(message.id, &mut request_message);
// Skip empty messages to avoid sending them to the LLM model
// if !request_message.content.is_empty() {
request.messages.push(request_message);
// }
}
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
@@ -1211,7 +1202,6 @@ impl Thread {
&mut self,
request: LanguageModelRequest,
model: Arc<dyn LanguageModel>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
let pending_completion_id = post_inc(&mut self.completion_count);
@@ -1393,7 +1383,7 @@ impl Thread {
match result.as_ref() {
Ok(stop_reason) => match stop_reason {
StopReason::ToolUse => {
let tool_uses = thread.use_pending_tools(window, cx);
let tool_uses = thread.use_pending_tools(cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
StopReason::EndTurn => {}
@@ -1438,7 +1428,7 @@ impl Thread {
}));
}
thread.cancel_last_completion(window, cx);
thread.cancel_last_completion(cx);
}
}
cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
@@ -1607,11 +1597,7 @@ impl Thread {
)
}
pub fn use_pending_tools(
&mut self,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> Vec<PendingToolUse> {
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
self.auto_capture_telemetry(cx);
let request = self.to_completion_request(cx);
let messages = Arc::new(request.messages);
@@ -1643,7 +1629,6 @@ impl Thread {
tool_use.input.clone(),
&messages,
tool,
window,
cx,
);
}
@@ -1660,10 +1645,9 @@ impl Thread {
input: serde_json::Value,
messages: &[LanguageModelRequestMessage],
tool: Arc<dyn Tool>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) {
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, window, cx);
let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
self.tool_use
.run_pending_tool(tool_use_id, ui_text.into(), task);
}
@@ -1674,7 +1658,6 @@ impl Thread {
messages: &[LanguageModelRequestMessage],
input: serde_json::Value,
tool: Arc<dyn Tool>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Thread>,
) -> Task<()> {
let tool_name: Arc<str> = tool.name().into();
@@ -1687,7 +1670,6 @@ impl Thread {
messages,
self.project.clone(),
self.action_log.clone(),
window,
cx,
)
};
@@ -1710,7 +1692,7 @@ impl Thread {
output,
cx,
);
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
})
.ok();
}
@@ -1722,22 +1704,18 @@ impl Thread {
tool_use_id: LanguageModelToolUseId,
pending_tool_use: Option<PendingToolUse>,
canceled: bool,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
dbg!("All tools finished");
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
dbg!("Attach tool results");
self.attach_tool_results(cx);
if !canceled {
self.send_to_model(model, window, cx);
self.send_to_model(model, cx);
}
}
}
println!("ToolFinished {}", tool_use_id);
cx.emit(ThreadEvent::ToolFinished {
tool_use_id,
pending_tool_use,
@@ -1755,11 +1733,7 @@ impl Thread {
/// Cancels the last pending completion, if there are any pending.
///
/// Returns whether a completion was canceled.
pub fn cancel_last_completion(
&mut self,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
let canceled = if self.pending_completions.pop().is_some() {
true
} else {
@@ -1770,7 +1744,6 @@ impl Thread {
pending_tool_use.id.clone(),
Some(pending_tool_use),
true,
window,
cx,
);
}
@@ -2227,7 +2200,6 @@ impl Thread {
&mut self,
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) {
let err = Err(anyhow::anyhow!(
@@ -2236,7 +2208,7 @@ impl Thread {
self.tool_use
.insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
self.tool_finished(tool_use_id.clone(), None, true, cx);
}
}

View File

@@ -72,7 +72,6 @@ impl ToolUseState {
.map(|tool_use| LanguageModelToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
raw_input: tool_use.input.to_string(),
input: tool_use.input.clone(),
is_input_complete: true,
})
@@ -450,7 +449,6 @@ impl ToolUseState {
message_id: MessageId,
request_message: &mut LanguageModelRequestMessage,
) {
dbg!(&self.tool_uses_by_assistant_message, &message_id);
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
for tool_use in tool_uses {
if self.tool_results.contains_key(&tool_use.id) {
@@ -473,7 +471,6 @@ impl ToolUseState {
message_id: MessageId,
request_message: &mut LanguageModelRequestMessage,
) {
dbg!(&self.tool_uses_by_user_message, &message_id);
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
for tool_use_id in tool_uses {
if let Some(tool_result) = self.tool_results.get(tool_use_id) {

View File

@@ -1,9 +1,7 @@
mod agent_notification;
mod animated_label;
mod context_pill;
mod usage_banner;
pub use agent_notification::*;
pub use animated_label::*;
pub use context_pill::*;
pub use usage_banner::*;

View File

@@ -1,121 +0,0 @@
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
use std::time::Duration;
use ui::prelude::*;
#[derive(IntoElement)]
pub struct AnimatedLabel {
base: Label,
text: SharedString,
}
impl AnimatedLabel {
pub fn new(text: impl Into<SharedString>) -> Self {
let text = text.into();
AnimatedLabel {
base: Label::new(text.clone()),
text,
}
}
}
impl LabelCommon for AnimatedLabel {
fn size(mut self, size: LabelSize) -> Self {
self.base = self.base.size(size);
self
}
fn weight(mut self, weight: FontWeight) -> Self {
self.base = self.base.weight(weight);
self
}
fn line_height_style(mut self, line_height_style: LineHeightStyle) -> Self {
self.base = self.base.line_height_style(line_height_style);
self
}
fn color(mut self, color: Color) -> Self {
self.base = self.base.color(color);
self
}
fn strikethrough(mut self) -> Self {
self.base = self.base.strikethrough();
self
}
fn italic(mut self) -> Self {
self.base = self.base.italic();
self
}
fn alpha(mut self, alpha: f32) -> Self {
self.base = self.base.alpha(alpha);
self
}
fn underline(mut self) -> Self {
self.base = self.base.underline();
self
}
fn truncate(mut self) -> Self {
self.base = self.base.truncate();
self
}
fn single_line(mut self) -> Self {
self.base = self.base.single_line();
self
}
fn buffer_font(mut self, cx: &App) -> Self {
self.base = self.base.buffer_font(cx);
self
}
fn inline_code(mut self, cx: &App) -> Self {
self.base = self.base.inline_code(cx);
self
}
}
impl RenderOnce for AnimatedLabel {
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
let text = self.text.clone();
self.base
.color(Color::Muted)
.with_animations(
"animated-label",
vec![
Animation::new(Duration::from_secs(1)),
Animation::new(Duration::from_secs(1)).repeat(),
],
move |mut label, animation_ix, delta| {
match animation_ix {
0 => {
let chars_to_show = (delta * text.len() as f32).ceil() as usize;
let text = SharedString::from(text[0..chars_to_show].to_string());
label.set_text(text);
}
1 => match delta {
d if d < 0.25 => label.set_text(text.clone()),
d if d < 0.5 => label.set_text(format!("{}.", text)),
d if d < 0.75 => label.set_text(format!("{}..", text)),
_ => label.set_text(format!("{}...", text)),
},
_ => {}
}
label
},
)
.with_animation(
"pulsating-label",
Animation::new(Duration::from_secs(2))
.repeat()
.with_easing(pulsating_between(0.6, 1.)),
|label, delta| label.map_element(|label| label.alpha(delta)),
)
}
}

View File

@@ -1,30 +1,31 @@
use client::zed_urls;
use language_model::RequestUsage;
use ui::{Banner, ProgressBar, Severity, prelude::*};
use zed_llm_client::{Plan, UsageLimit};
#[derive(IntoElement, RegisterComponent)]
pub struct UsageBanner {
plan: Plan,
usage: RequestUsage,
requests: i32,
}
impl UsageBanner {
pub fn new(plan: Plan, usage: RequestUsage) -> Self {
Self { plan, usage }
pub fn new(plan: Plan, requests: i32) -> Self {
Self { plan, requests }
}
}
impl RenderOnce for UsageBanner {
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
let used_percentage = match self.usage.limit {
UsageLimit::Limited(limit) => Some((self.usage.amount as f32 / limit as f32) * 100.),
let request_limit = self.plan.model_requests_limit();
let used_percentage = match request_limit {
UsageLimit::Limited(limit) => Some((self.requests as f32 / limit as f32) * 100.),
UsageLimit::Unlimited => None,
};
let (severity, message) = match self.usage.limit {
let (severity, message) = match request_limit {
UsageLimit::Limited(limit) => {
if self.usage.amount >= limit {
if self.requests >= limit {
let message = match self.plan {
Plan::ZedPro => "Monthly request limit reached",
Plan::ZedProTrial => "Trial request limit reached",
@@ -32,7 +33,7 @@ impl RenderOnce for UsageBanner {
};
(Severity::Error, message)
} else if (self.usage.amount as f32 / limit as f32) >= 0.9 {
} else if (self.requests as f32 / limit as f32) >= 0.9 {
(Severity::Warning, "Approaching request limit")
} else {
let message = match self.plan {
@@ -80,11 +81,11 @@ impl RenderOnce for UsageBanner {
.child(ProgressBar::new("usage", percent, 100., cx))
}))
.child(
Label::new(match self.usage.limit {
Label::new(match request_limit {
UsageLimit::Limited(limit) => {
format!("{} / {limit}", self.usage.amount)
format!("{} / {limit}", self.requests)
}
UsageLimit::Unlimited => format!("{} / ∞", self.usage.amount),
UsageLimit::Unlimited => format!("{} / ∞", self.requests),
})
.size(LabelSize::Small)
.color(Color::Muted),
@@ -103,131 +104,74 @@ impl Component for UsageBanner {
}
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
let trial_limit = Plan::ZedProTrial.model_requests_limit();
let trial_examples = vec![
single_example(
"Zed Pro Trial - New User",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedProTrial,
RequestUsage {
limit: trial_limit,
amount: 10,
},
))
.child(UsageBanner::new(Plan::ZedProTrial, 10))
.into_any_element(),
),
single_example(
"Zed Pro Trial - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedProTrial,
RequestUsage {
limit: trial_limit,
amount: 135,
},
))
.child(UsageBanner::new(Plan::ZedProTrial, 135))
.into_any_element(),
),
single_example(
"Zed Pro Trial - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedProTrial,
RequestUsage {
limit: trial_limit,
amount: 150,
},
))
.child(UsageBanner::new(Plan::ZedProTrial, 150))
.into_any_element(),
),
];
let free_limit = Plan::Free.model_requests_limit();
let free_examples = vec![
single_example(
"Free - Normal Usage",
div()
.size_full()
.child(UsageBanner::new(
Plan::Free,
RequestUsage {
limit: free_limit,
amount: 25,
},
))
.child(UsageBanner::new(Plan::Free, 25))
.into_any_element(),
),
single_example(
"Free - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(
Plan::Free,
RequestUsage {
limit: free_limit,
amount: 45,
},
))
.child(UsageBanner::new(Plan::Free, 45))
.into_any_element(),
),
single_example(
"Free - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(
Plan::Free,
RequestUsage {
limit: free_limit,
amount: 50,
},
))
.child(UsageBanner::new(Plan::Free, 50))
.into_any_element(),
),
];
let zed_pro_limit = Plan::ZedPro.model_requests_limit();
let zed_pro_examples = vec![
single_example(
"Zed Pro - Normal Usage",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedPro,
RequestUsage {
limit: zed_pro_limit,
amount: 250,
},
))
.child(UsageBanner::new(Plan::ZedPro, 250))
.into_any_element(),
),
single_example(
"Zed Pro - Approaching Limit",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedPro,
RequestUsage {
limit: zed_pro_limit,
amount: 450,
},
))
.child(UsageBanner::new(Plan::ZedPro, 450))
.into_any_element(),
),
single_example(
"Zed Pro - Request Limit Reached",
div()
.size_full()
.child(UsageBanner::new(
Plan::ZedPro,
RequestUsage {
limit: zed_pro_limit,
amount: 500,
},
))
.child(UsageBanner::new(Plan::ZedPro, 500))
.into_any_element(),
),
];

View File

@@ -23,6 +23,7 @@ use gpui::{
use language::LanguageRegistry;
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_library::{PromptLibrary, open_prompt_library};
@@ -488,8 +489,8 @@ impl AssistantPanel {
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
// the provider, we want to show a nudge to sign in.
let show_zed_ai_notice =
client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
let show_zed_ai_notice = client_status.is_signed_out()
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
self.show_zed_ai_notice = show_zed_ai_notice;
cx.notify();

View File

@@ -1226,7 +1226,7 @@ impl InlineAssistant {
editor.highlight_rows::<InlineAssist>(
row_range,
cx.theme().status().info_background,
Default::default(),
false,
cx,
);
}
@@ -1291,7 +1291,7 @@ impl InlineAssistant {
editor.highlight_rows::<DeletedLines>(
Anchor::min()..Anchor::max(),
cx.theme().status().deleted_background,
Default::default(),
false,
cx,
);
editor

View File

@@ -44,6 +44,4 @@ impl Settings for SlashCommandSettings {
.chain(sources.server),
)
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}

View File

@@ -112,27 +112,13 @@ impl AssistantSettings {
}
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
pub struct AssistantSettingsContent {
#[serde(flatten)]
pub inner: Option<AssistantSettingsContentInner>,
}
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContentInner {
pub enum AssistantSettingsContent {
Versioned(Box<VersionedAssistantSettingsContent>),
Legacy(LegacyAssistantSettingsContent),
}
impl AssistantSettingsContentInner {
fn for_v2(content: AssistantSettingsContentV2) -> Self {
AssistantSettingsContentInner::Versioned(Box::new(VersionedAssistantSettingsContent::V2(
content,
)))
}
}
impl JsonSchema for AssistantSettingsContent {
fn schema_name() -> String {
VersionedAssistantSettingsContent::schema_name()
@@ -147,21 +133,26 @@ impl JsonSchema for AssistantSettingsContent {
}
}
impl Default for AssistantSettingsContent {
fn default() -> Self {
Self::Versioned(Box::new(VersionedAssistantSettingsContent::default()))
}
}
impl AssistantSettingsContent {
pub fn is_version_outdated(&self) -> bool {
match &self.inner {
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
match self {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(_) => true,
VersionedAssistantSettingsContent::V2(_) => false,
},
Some(AssistantSettingsContentInner::Legacy(_)) => true,
None => false,
AssistantSettingsContent::Legacy(_) => true,
}
}
fn upgrade(&self) -> AssistantSettingsContentV2 {
match &self.inner {
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
match self {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 {
enabled: settings.enabled,
button: settings.button,
@@ -221,7 +212,7 @@ impl AssistantSettingsContent {
},
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
},
Some(AssistantSettingsContentInner::Legacy(settings)) => AssistantSettingsContentV2 {
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
enabled: None,
button: settings.button,
dock: settings.dock,
@@ -246,13 +237,12 @@ impl AssistantSettingsContent {
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
},
None => AssistantSettingsContentV2::default(),
}
}
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
match &mut self.inner {
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
match self {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref mut settings) => {
settings.dock = Some(dock);
}
@@ -260,17 +250,9 @@ impl AssistantSettingsContent {
settings.dock = Some(dock);
}
},
Some(AssistantSettingsContentInner::Legacy(settings)) => {
AssistantSettingsContent::Legacy(settings) => {
settings.dock = Some(dock);
}
None => {
self.inner = Some(AssistantSettingsContentInner::for_v2(
AssistantSettingsContentV2 {
dock: Some(dock),
..Default::default()
},
))
}
}
}
@@ -278,8 +260,8 @@ impl AssistantSettingsContent {
let model = language_model.id().0.to_string();
let provider = language_model.provider_id().0.to_string();
match &mut self.inner {
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
match self {
AssistantSettingsContent::Versioned(settings) => match **settings {
VersionedAssistantSettingsContent::V1(ref mut settings) => {
match provider.as_ref() {
"zed.dev" => {
@@ -355,80 +337,56 @@ impl AssistantSettingsContent {
settings.default_model = Some(LanguageModelSelection { provider, model });
}
},
Some(AssistantSettingsContentInner::Legacy(settings)) => {
AssistantSettingsContent::Legacy(settings) => {
if let Ok(model) = OpenAiModel::from_id(&language_model.id().0) {
settings.default_open_ai_model = Some(model);
}
}
None => {
self.inner = Some(AssistantSettingsContentInner::for_v2(
AssistantSettingsContentV2 {
default_model: Some(LanguageModelSelection { provider, model }),
..Default::default()
},
));
}
}
}
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
self.v2_setting(|setting| {
setting.inline_assistant_model = Some(LanguageModelSelection { provider, model });
Ok(())
})
.ok();
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.inline_assistant_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
self.v2_setting(|setting| {
setting.commit_message_model = Some(LanguageModelSelection { provider, model });
Ok(())
})
.ok();
}
pub fn v2_setting(
&mut self,
f: impl FnOnce(&mut AssistantSettingsContentV2) -> anyhow::Result<()>,
) -> anyhow::Result<()> {
match self.inner.get_or_insert_with(|| {
AssistantSettingsContentInner::for_v2(AssistantSettingsContentV2 {
..Default::default()
})
}) {
AssistantSettingsContentInner::Versioned(boxed) => {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
f(settings)
} else {
Ok(())
}
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.commit_message_model = Some(LanguageModelSelection { provider, model });
}
_ => Ok(()),
}
}
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
self.v2_setting(|setting| {
setting.thread_summary_model = Some(LanguageModelSelection { provider, model });
Ok(())
})
.ok();
if let AssistantSettingsContent::Versioned(boxed) = self {
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.thread_summary_model = Some(LanguageModelSelection { provider, model });
}
}
}
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
self.v2_setting(|setting| {
setting.always_allow_tool_actions = Some(allow);
Ok(())
})
.ok();
let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.always_allow_tool_actions = Some(allow);
}
}
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
self.v2_setting(|setting| {
setting.default_profile = Some(profile_id);
Ok(())
})
.ok();
let AssistantSettingsContent::Versioned(boxed) = self else {
return;
};
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
settings.default_profile = Some(profile_id);
}
}
pub fn create_profile(
@@ -436,7 +394,11 @@ impl AssistantSettingsContent {
profile_id: AgentProfileId,
profile: AgentProfile,
) -> Result<()> {
self.v2_setting(|settings| {
let AssistantSettingsContent::Versioned(boxed) = self else {
return Ok(());
};
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
let profiles = settings.profiles.get_or_insert_default();
if profiles.contains_key(&profile_id) {
bail!("profile with ID '{profile_id}' already exists");
@@ -462,9 +424,9 @@ impl AssistantSettingsContent {
.collect(),
},
);
}
Ok(())
})
Ok(())
}
}
@@ -499,7 +461,7 @@ impl Default for VersionedAssistantSettingsContent {
}
}
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV2 {
/// Whether the Assistant is enabled.
///
@@ -746,39 +708,6 @@ impl Settings for AssistantSettings {
Ok(settings)
}
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
if let Some(b) = vscode
.read_value("chat.agent.enabled")
.and_then(|b| b.as_bool())
{
match &mut current.inner {
Some(AssistantSettingsContentInner::Versioned(versioned)) => {
match versioned.as_mut() {
VersionedAssistantSettingsContent::V1(setting) => {
setting.enabled = Some(b);
setting.button = Some(b);
}
VersionedAssistantSettingsContent::V2(setting) => {
setting.enabled = Some(b);
setting.button = Some(b);
}
}
}
Some(AssistantSettingsContentInner::Legacy(setting)) => setting.button = Some(b),
None => {
current.inner = Some(AssistantSettingsContentInner::for_v2(
AssistantSettingsContentV2 {
enabled: Some(b),
button: Some(b),
..Default::default()
},
));
}
}
}
}
}
fn merge<T>(target: &mut T, value: Option<T>) {
@@ -822,30 +751,28 @@ mod tests {
settings::SettingsStore::global(cx).update_settings_file::<AssistantSettings>(
fs.clone(),
|settings, _| {
*settings = AssistantSettingsContent {
inner: Some(AssistantSettingsContentInner::for_v2(
AssistantSettingsContentV2 {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
enable_experimental_live_diffs: None,
default_profile: None,
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
},
)),
}
*settings = AssistantSettingsContent::Versioned(Box::new(
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
default_model: Some(LanguageModelSelection {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
inline_assistant_model: None,
commit_message_model: None,
thread_summary_model: None,
inline_alternatives: None,
enabled: None,
button: None,
dock: None,
default_width: None,
default_height: None,
enable_experimental_live_diffs: None,
default_profile: None,
profiles: None,
always_allow_tool_actions: None,
notify_when_agent_waiting: None,
}),
))
},
);
});

View File

@@ -28,7 +28,6 @@ serde.workspace = true
serde_json.workspace = true
text.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
[dev-dependencies]

View File

@@ -39,9 +39,10 @@ impl ActionLog {
self.edited_since_project_diagnostics_check
}
fn track_buffer_internal(
fn track_buffer(
&mut self,
buffer: Entity<Buffer>,
created: bool,
cx: &mut Context<Self>,
) -> &mut TrackedBuffer {
let tracked_buffer = self
@@ -58,11 +59,7 @@ impl ActionLog {
let base_text;
let status;
let unreviewed_changes;
if buffer
.read(cx)
.file()
.map_or(true, |file| !file.disk_state().exists())
{
if created {
base_text = Rope::default();
status = TrackedBufferStatus::Created;
unreviewed_changes = Patch::new(vec![Edit {
@@ -149,7 +146,7 @@ impl ActionLog {
// resurrected externally, we want to clear the changes we
// were tracking and reset the buffer's state.
self.tracked_buffers.remove(&buffer);
self.track_buffer_internal(buffer, cx);
self.track_buffer(buffer, false, cx);
}
cx.notify();
}
@@ -263,15 +260,26 @@ impl ActionLog {
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn track_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer_internal(buffer, cx);
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer, false, cx);
}
/// Track a buffer that was added as context, so we can notify the model about user edits.
pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer, false, cx);
}
/// Track a buffer as read, so we can notify the model about user edits.
pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.track_buffer(buffer.clone(), true, cx);
self.buffer_edited(buffer, cx)
}
/// Mark a buffer as edited, so we can refresh it in the context
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
self.edited_since_project_diagnostics_check = true;
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
tracked_buffer.status = TrackedBufferStatus::Modified;
}
@@ -279,7 +287,7 @@ impl ActionLog {
}
pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
match tracked_buffer.status {
TrackedBufferStatus::Created => {
self.tracked_buffers.remove(&buffer);
@@ -389,7 +397,7 @@ impl ActionLog {
// Clear all tracked changes for this buffer and start over as if we just read it.
self.tracked_buffers.remove(&buffer);
self.track_buffer_internal(buffer.clone(), cx);
self.track_buffer(buffer.clone(), false, cx);
cx.notify();
save
}
@@ -687,20 +695,12 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E")], None, cx)
@@ -765,23 +765,12 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
path!("/dir"),
json!({"file": "abc\ndef\nghi\njkl\nmno\npqr"}),
)
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 0)..Point::new(2, 0), "")], None, cx)
@@ -850,20 +839,12 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
.await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let project = Project::test(fs.clone(), [], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
.unwrap();
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx)
@@ -947,21 +928,25 @@ mod tests {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let fs = FakeFs::new(cx.executor());
fs.insert_tree(path!("/dir"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
let file_path = project
.read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
.unwrap();
// Simulate file2 being recreated by a tool.
let buffer = project
.update(cx, |project, cx| project.open_buffer(file_path, cx))
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.set_text("lorem", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
@@ -1082,9 +1067,8 @@ mod tests {
.update(cx, |project, cx| project.open_buffer(file2_path, cx))
.await
.unwrap();
action_log.update(cx, |log, cx| log.track_buffer(buffer2.clone(), cx));
buffer2.update(cx, |buffer, cx| buffer.set_text("IPSUM", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer2.clone(), cx));
action_log.update(cx, |log, cx| log.will_create_buffer(buffer2.clone(), cx));
project
.update(cx, |project, cx| project.save_buffer(buffer2.clone(), cx))
.await
@@ -1129,7 +1113,7 @@ mod tests {
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
@@ -1264,7 +1248,7 @@ mod tests {
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| {
buffer
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
@@ -1397,9 +1381,8 @@ mod tests {
.await
.unwrap();
cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
buffer.update(cx, |buffer, cx| buffer.set_text("content", cx));
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
});
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
@@ -1455,7 +1438,7 @@ mod tests {
.await
.unwrap();
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
for _ in 0..operations {
match rng.gen_range(0..100) {
@@ -1507,7 +1490,7 @@ mod tests {
log::info!("quiescing...");
cx.run_until_parked();
action_log.update(cx, |log, cx| {
let tracked_buffer = log.track_buffer_internal(buffer.clone(), cx);
let tracked_buffer = log.track_buffer(buffer.clone(), false, cx);
let mut old_text = tracked_buffer.base_text.clone();
let new_text = buffer.read(cx).as_rope();
for edit in tracked_buffer.unreviewed_changes.edits() {

View File

@@ -10,16 +10,14 @@ use std::sync::Arc;
use anyhow::Result;
use gpui::AnyElement;
use gpui::AnyWindowHandle;
use gpui::Context;
use gpui::IntoElement;
use gpui::Window;
use gpui::{App, Entity, SharedString, Task, WeakEntity};
use gpui::{App, Entity, SharedString, Task};
use icons::IconName;
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
use project::Project;
use workspace::Workspace;
pub use crate::action_log::*;
pub use crate::tool_registry::*;
@@ -67,7 +65,6 @@ pub trait ToolCard: 'static + Sized {
&mut self,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement;
}
@@ -79,7 +76,6 @@ pub struct AnyToolCard {
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut App,
) -> AnyElement,
}
@@ -90,14 +86,11 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
entity: gpui::AnyEntity,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut App,
) -> AnyElement {
let entity = entity.downcast::<T>().unwrap();
entity.update(cx, |entity, cx| {
entity
.render(status, window, workspace, cx)
.into_any_element()
entity.render(status, window, cx).into_any_element()
})
}
@@ -109,14 +102,8 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
}
impl AnyToolCard {
pub fn render(
&self,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut App,
) -> AnyElement {
(self.render)(self.entity.clone(), status, window, workspace, cx)
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
(self.render)(self.entity.clone(), status, window, cx)
}
}
@@ -176,7 +163,6 @@ pub trait Tool: 'static + Send + Sync {
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult;
}

View File

@@ -14,11 +14,10 @@ path = "src/assistant_tools.rs"
[dependencies]
anyhow.workspace = true
assistant_tool.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
collections.workspace = true
component.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
@@ -37,14 +36,11 @@ serde_json.workspace = true
ui.workspace = true
util.workspace = true
web_search.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
worktree.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
language = { workspace = true, features = ["test-support"] }

View File

@@ -9,12 +9,12 @@ mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
mod fetch_tool;
mod find_path_tool;
mod grep_tool;
mod list_directory_tool;
mod move_path_tool;
mod now_tool;
mod open_tool;
mod path_search_tool;
mod read_file_tool;
mod rename_tool;
mod replace;
@@ -29,9 +29,9 @@ use std::sync::Arc;
use assistant_tool::ToolRegistry;
use copy_path_tool::CopyPathTool;
use feature_flags::FeatureFlagAppExt;
use gpui::App;
use http_client::HttpClientWithUrl;
use language_model::LanguageModelRegistry;
use move_path_tool::MovePathTool;
use web_search_tool::WebSearchTool;
@@ -45,23 +45,17 @@ use crate::delete_path_tool::DeletePathTool;
use crate::diagnostics_tool::DiagnosticsTool;
use crate::edit_file_tool::EditFileTool;
use crate::fetch_tool::FetchTool;
use crate::find_path_tool::FindPathTool;
use crate::grep_tool::GrepTool;
use crate::list_directory_tool::ListDirectoryTool;
use crate::now_tool::NowTool;
use crate::open_tool::OpenTool;
use crate::path_search_tool::PathSearchTool;
use crate::read_file_tool::ReadFileTool;
use crate::rename_tool::RenameTool;
use crate::symbol_info_tool::SymbolInfoTool;
use crate::terminal_tool::TerminalTool;
use crate::thinking_tool::ThinkingTool;
pub use create_file_tool::CreateFileToolInput;
pub use edit_file_tool::EditFileToolInput;
pub use find_path_tool::FindPathToolInput;
pub use list_directory_tool::ListDirectoryToolInput;
pub use read_file_tool::ReadFileToolInput;
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
assistant_tool::init(cx);
@@ -82,52 +76,41 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
registry.register_tool(OpenTool);
registry.register_tool(CodeSymbolsTool);
registry.register_tool(ContentsTool);
registry.register_tool(FindPathTool);
registry.register_tool(PathSearchTool);
registry.register_tool(ReadFileTool);
registry.register_tool(GrepTool);
registry.register_tool(RenameTool);
registry.register_tool(ThinkingTool);
registry.register_tool(FetchTool::new(http_client));
cx.subscribe(
&LanguageModelRegistry::global(cx),
move |registry, event, cx| match event {
language_model::Event::DefaultModelChanged => {
let using_zed_provider = registry
.read(cx)
.default_model()
.map_or(false, |default| default.is_provided_by_zed());
if using_zed_provider {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
move |is_enabled, cx| {
if is_enabled {
ToolRegistry::global(cx).register_tool(WebSearchTool);
} else {
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
}
_ => {}
},
)
}
})
.detach();
}
#[cfg(test)]
mod tests {
use client::Client;
use clock::FakeSystemClock;
use http_client::FakeHttpClient;
use super::*;
#[gpui::test]
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
settings::init(cx);
let client = Client::new(
Arc::new(FakeSystemClock::new()),
FakeHttpClient::with_200_response(),
crate::init(
Arc::new(http_client::HttpClientWithUrl::new(
FakeHttpClient::with_200_response(),
"https://zed.dev",
None,
)),
cx,
);
language_model::init(client.clone(), cx);
crate::init(client.http_client(), cx);
for tool in ToolRegistry::global(cx).tools() {
let actual_schema = tool

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
use futures::future::join_all;
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -97,7 +97,7 @@ pub struct BatchToolInput {
/// }
/// },
/// {
/// "name": "find_path",
/// "name": "path_search",
/// "input": {
/// "glob": "**/*test*.rs"
/// }
@@ -218,7 +218,6 @@ impl Tool for BatchTool {
messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<BatchToolInput>(input) {
@@ -259,9 +258,7 @@ impl Tool for BatchTool {
let action_log = action_log.clone();
let messages = messages.clone();
let tool_result = cx
.update(|cx| {
tool.run(invocation.input, &messages, project, action_log, window, cx)
})
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
tasks.push(tool_result.output);

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use language::{self, Anchor, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{self, LspAction, Project};
@@ -140,7 +140,6 @@ impl Tool for CodeActionTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
@@ -160,7 +159,7 @@ impl Tool for CodeActionTool {
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
let range = {

View File

@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use collections::IndexMap;
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use gpui::{App, AsyncApp, Entity, Task};
use language::{OutlineItem, ParseStatus, Point};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, Symbol};
@@ -128,7 +128,6 @@ impl Tool for CodeSymbolsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
@@ -175,7 +174,7 @@ pub async fn file_outline(
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
// Wait until the buffer has been fully parsed, so that we can read its outline.

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -102,7 +102,6 @@ impl Tool for ContentsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ContentsToolInput>(input) {
@@ -210,7 +209,7 @@ impl Tool for ContentsTool {
})?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)
@@ -222,7 +221,7 @@ impl Tool for ContentsTool {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)

View File

@@ -1,7 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, AppContext, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -77,7 +76,6 @@ impl Tool for CopyPathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CopyPathToolInput>(input) {

View File

@@ -1,7 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -68,7 +67,6 @@ impl Tool for CreateDirectoryTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {

View File

@@ -1,7 +1,6 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::AnyWindowHandle;
use gpui::{App, Entity, Task};
use language_model::LanguageModelRequestMessage;
use language_model::LanguageModelToolSchemaFormat;
@@ -24,9 +23,6 @@ pub struct CreateFileToolInput {
///
/// You can create a new file by providing a path of "directory1/new_file.txt"
/// </example>
///
/// Make sure to include this field before the `contents` field in the input object
/// so that we can display it immediately.
pub path: String,
/// The text contents of the file to create.
@@ -93,7 +89,6 @@ impl Tool for CreateFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<CreateFileToolInput>(input) {
@@ -117,12 +112,9 @@ impl Tool for CreateFileTool {
.await
.map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
cx.update(|cx| {
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx)
});
buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
action_log.update(cx, |action_log, cx| {
action_log.buffer_edited(buffer.clone(), cx)
action_log.will_create_buffer(buffer.clone(), cx)
});
})?;

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::{SinkExt, StreamExt, channel::mpsc};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{Project, ProjectPath};
use schemars::JsonSchema;
@@ -62,7 +62,6 @@ impl Tool for DeletePathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use language::{DiagnosticSeverity, OffsetRangeExt};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -82,7 +82,6 @@ impl Tool for DiagnosticsTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
match serde_json::from_value::<DiagnosticsToolInput>(input)

View File

@@ -1,40 +1,18 @@
use crate::{
replace::{replace_exact, replace_with_flexible_indent},
schema::json_schema_for,
};
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
use gpui::{
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
};
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{
path::{Path, PathBuf},
sync::Arc,
};
use ui::{Disclosure, Tooltip, Window, prelude::*};
use util::ResultExt;
use workspace::Workspace;
use std::{path::PathBuf, sync::Arc};
use ui::IconName;
use crate::replace::replace_exact;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
///
/// Make sure to include this field before all the others in the input object
/// so that we can display it immediately.
pub display_description: String,
/// The full path of the file to modify in the project.
///
/// WARNING: When specifying which file path need changing, you MUST
@@ -56,6 +34,12 @@ pub struct EditFileToolInput {
/// </example>
pub path: PathBuf,
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
///
/// <example>Fix API endpoint URLs</example>
/// <example>Update copyright year in `page_footer`</example>
pub display_description: String,
/// The text to replace.
pub old_string: String,
@@ -77,7 +61,7 @@ struct PartialInput {
pub struct EditFileTool;
const DEFAULT_UI_TEXT: &str = "Editing file";
const DEFAULT_UI_TEXT: &str = "Edit file";
impl Tool for EditFileTool {
fn name(&self) -> String {
@@ -103,7 +87,7 @@ impl Tool for EditFileTool {
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
Ok(input) => input.display_description,
Err(_) => "Editing file".to_string(),
Err(_) => "Edit file".to_string(),
}
}
@@ -129,7 +113,6 @@ impl Tool for EditFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<EditFileToolInput>(input) {
@@ -137,18 +120,7 @@ impl Tool for EditFileTool {
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
let card = window.and_then(|window| {
window
.update(cx, |_, window, cx| {
cx.new(|cx| {
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
})
})
.ok()
});
let card_clone = card.clone();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
cx.spawn(async move |cx: &mut AsyncApp| {
let project_path = project.read_with(cx, |project, cx| {
project
.find_project_path(&input.path, cx)
@@ -156,38 +128,26 @@ impl Tool for EditFileTool {
})??;
let buffer = project
.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})?
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
if input.old_string.is_empty() {
return Err(anyhow!(
"`old_string` can't be empty, use another tool if you want to create a file."
));
return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file."));
}
if input.old_string == input.new_string {
return Err(anyhow!(
"The `old_string` and `new_string` are identical, so no changes would be made."
));
return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made."));
}
let result = cx
.background_spawn(async move {
// Try to match exactly
let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
.await
// If that fails, try being flexible about indentation
.or_else(|| {
replace_with_flexible_indent(
&input.old_string,
&input.new_string,
&snapshot,
)
})?;
.await
// If that fails, try being flexible about indentation
.or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?;
if diff.edits.is_empty() {
return None;
@@ -217,409 +177,41 @@ impl Tool for EditFileTool {
}
})?;
return Err(err);
return Err(err)
};
let snapshot = cx.update(|cx| {
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx)
});
let snapshot = buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction();
buffer.apply_diff(diff, cx);
buffer.finalize_last_transaction();
buffer.snapshot()
});
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
action_log.update(cx, |log, cx| {
log.buffer_edited(buffer.clone(), cx)
});
snapshot
})?;
project
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
.await?;
project.update( cx, |project, cx| {
project.save_buffer(buffer, cx)
})?.await?;
let new_text = snapshot.text();
let diff_str = cx
.background_spawn({
let old_text = old_text.clone();
let new_text = new_text.clone();
async move { language::unified_diff(&old_text, &new_text) }
})
.await;
let diff_str = cx.background_spawn(async move {
let new_text = snapshot.text();
language::unified_diff(&old_text, &new_text)
}).await;
if let Some(card) = card_clone {
card.update(cx, |card, cx| {
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
})
.log_err();
}
Ok(format!(
"Edited {}:\n\n```diff\n{}\n```",
input.path.display(),
diff_str
))
});
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
ToolResult {
output: task,
card: card.map(AnyToolCard::from),
}
}).into()
}
}
pub struct EditFileToolCard {
path: PathBuf,
editor: Entity<Editor>,
multibuffer: Entity<MultiBuffer>,
project: Entity<Project>,
diff_task: Option<Task<Result<()>>>,
preview_expanded: bool,
full_height_expanded: bool,
editor_unique_id: EntityId,
}
impl EditFileToolCard {
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
Some(project.clone()),
window,
cx,
);
editor.set_show_scrollbars(false, cx);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_scrolling(cx);
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor
});
Self {
editor_unique_id: editor.entity_id(),
path,
project,
editor,
multibuffer,
diff_task: None,
preview_expanded: true,
full_height_expanded: false,
}
}
fn set_diff(
&mut self,
path: Arc<Path>,
old_text: String,
new_text: String,
cx: &mut Context<Self>,
) {
let language_registry = self.project.read(cx).languages().clone();
self.diff_task = Some(cx.spawn(async move |this, cx| {
let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
this.update(cx, |this, cx| {
this.multibuffer.update(cx, |multibuffer, cx| {
let snapshot = buffer.read(cx).snapshot();
let diff = buffer_diff.read(cx);
let diff_hunk_ranges = diff
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
.collect::<Vec<_>>();
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
diff_hunk_ranges,
editor::DEFAULT_MULTIBUFFER_CONTEXT,
cx,
);
debug_assert!(is_newly_added);
multibuffer.add_diff(buffer_diff, cx);
});
cx.notify();
})
}));
}
}
impl ToolCard for EditFileToolCard {
fn render(
&mut self,
status: &ToolUseStatus,
window: &mut Window,
workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let failed = matches!(status, ToolUseStatus::Error(_));
let path_label_button = h_flex()
.id(("edit-tool-path-label-button", self.editor_unique_id))
.w_full()
.max_w_full()
.px_1()
.gap_0p5()
.cursor_pointer()
.rounded_sm()
.opacity(0.8)
.hover(|label| {
label
.opacity(1.)
.bg(cx.theme().colors().element_hover.opacity(0.5))
})
.tooltip(Tooltip::text("Jump to File"))
.child(
h_flex()
.child(
Icon::new(IconName::Pencil)
.size(IconSize::XSmall)
.color(Color::Muted),
)
.child(
div()
.text_size(rems(0.8125))
.child(self.path.display().to_string())
.ml_1p5()
.mr_0p5(),
)
.child(
Icon::new(IconName::ArrowUpRight)
.size(IconSize::XSmall)
.color(Color::Ignored),
),
)
.on_click({
let path = self.path.clone();
let workspace = workspace.clone();
move |_, window, cx| {
workspace
.update(cx, {
|workspace, cx| {
let Some(project_path) =
workspace.project().read(cx).find_project_path(&path, cx)
else {
return;
};
let open_task =
workspace.open_path(project_path, None, true, window, cx);
window
.spawn(cx, async move |cx| {
let item = open_task.await?;
if let Some(active_editor) = item.downcast::<Editor>() {
active_editor
.update_in(cx, |editor, window, cx| {
editor.go_to_singleton_buffer_point(
language::Point::new(0, 0),
window,
cx,
);
})
.log_err();
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
})
.ok();
}
})
.into_any_element();
let codeblock_header_bg = cx
.theme()
.colors()
.element_background
.blend(cx.theme().colors().editor_foreground.opacity(0.025));
let codeblock_header = h_flex()
.flex_none()
.p_1()
.gap_1()
.justify_between()
.rounded_t_md()
.when(!failed, |header| header.bg(codeblock_header_bg))
.child(path_label_button)
.map(|container| {
if failed {
container.child(
Icon::new(IconName::Close)
.size(IconSize::Small)
.color(Color::Error),
)
} else {
container.child(
Disclosure::new(
("edit-file-disclosure", self.editor_unique_id),
self.preview_expanded,
)
.opened_icon(IconName::ChevronUp)
.closed_icon(IconName::ChevronDown)
.on_click(cx.listener(
move |this, _event, _window, _cx| {
this.preview_expanded = !this.preview_expanded;
},
)),
)
}
});
let editor = self.editor.update(cx, |editor, cx| {
editor.render(window, cx).into_any_element()
});
let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
(IconName::ChevronUp, "Collapse Code Block")
} else {
(IconName::ChevronDown, "Expand Code Block")
};
let gradient_overlay = div()
.absolute()
.bottom_0()
.left_0()
.w_full()
.h_2_5()
.rounded_b_lg()
.bg(gpui::linear_gradient(
0.,
gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
));
let border_color = cx.theme().colors().border.opacity(0.6);
v_flex()
.mb_2()
.border_1()
.when(failed, |card| card.border_dashed())
.border_color(border_color)
.rounded_lg()
.overflow_hidden()
.child(codeblock_header)
.when(!failed && self.preview_expanded, |card| {
card.child(
v_flex()
.relative()
.overflow_hidden()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.map(|editor_container| {
if self.full_height_expanded {
editor_container.h_full()
} else {
editor_container.max_h_64()
}
})
.child(div().pl_1().child(editor))
.when(!self.full_height_expanded, |editor_container| {
editor_container.child(gradient_overlay)
}),
)
})
.when(!failed && self.preview_expanded, |card| {
card.child(
h_flex()
.id(("edit-tool-card-inner-hflex", self.editor_unique_id))
.flex_none()
.cursor_pointer()
.h_5()
.justify_center()
.rounded_b_md()
.border_t_1()
.border_color(border_color)
.bg(cx.theme().colors().editor_background)
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
.child(
Icon::new(full_height_icon)
.size(IconSize::Small)
.color(Color::Muted),
)
.tooltip(Tooltip::text(full_height_tooltip_label))
.on_click(cx.listener(move |this, _event, _window, _cx| {
this.full_height_expanded = !this.full_height_expanded;
})),
)
})
}
}
async fn build_buffer(
mut text: String,
path: Arc<Path>,
language_registry: &Arc<language::LanguageRegistry>,
cx: &mut AsyncApp,
) -> Result<Entity<Buffer>> {
let line_ending = LineEnding::detect(&text);
LineEnding::normalize(&mut text);
let text = Rope::from(text);
let language = cx
.update(|_cx| language_registry.language_for_file_path(&path))?
.await
.ok();
let buffer = cx.new(|cx| {
let buffer = TextBuffer::new_normalized(
0,
cx.entity_id().as_non_zero_u64().into(),
line_ending,
text,
);
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
buffer.set_language(language, cx);
buffer
})?;
Ok(buffer)
}
async fn build_buffer_diff(
mut old_text: String,
buffer: &Entity<Buffer>,
language_registry: &Arc<LanguageRegistry>,
cx: &mut AsyncApp,
) -> Result<Entity<BufferDiff>> {
LineEnding::normalize(&mut old_text);
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
let base_buffer = cx
.update(|cx| {
Buffer::build_snapshot(
old_text.clone().into(),
buffer.language().cloned(),
Some(language_registry.clone()),
cx,
)
})?
.await;
let diff_snapshot = cx
.update(|cx| {
BufferDiffSnapshot::new_with_base_buffer(
buffer.text.clone(),
Some(old_text.into()),
base_buffer,
cx,
)
})?
.await;
cx.new(|cx| {
let mut diff = BufferDiff::new(&buffer.text, cx);
diff.set_snapshot(diff_snapshot, &buffer.text, cx);
diff
})
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::AsyncReadExt as _;
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
use gpui::{App, AppContext as _, Entity, Task};
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
use http_client::{AsyncBody, HttpClientWithUrl};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -145,7 +145,6 @@ impl Tool for FetchTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<FetchToolInput>(input) {

View File

@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use language::OffsetRangeExt;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
@@ -20,8 +20,6 @@ use util::paths::PathMatcher;
pub struct GrepToolInput {
/// A regex pattern to search for in the entire project. Note that the regex
/// will be parsed by the Rust `regex` crate.
///
/// Do NOT specify a path here! This will only be matched against the code **content**.
pub regex: String,
/// A glob pattern for the paths of files to include in the search.
@@ -98,7 +96,6 @@ impl Tool for GrepTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
const CONTEXT_LINES: u32 = 2;
@@ -408,7 +405,7 @@ mod tests {
) -> String {
let tool = Arc::new(GrepTool);
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));
let task = cx.update(|cx| tool.run(input, &[], project, action_log, cx));
match task.output.await {
Ok(result) => result,

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -76,7 +76,6 @@ impl Tool for ListDirectoryTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {

View File

@@ -1 +1 @@
Lists files and directories in a given path. Prefer the `grep` or `find_path` tools when searching the codebase.
Lists files and directories in a given path. Prefer the `grep` or `path_search` tools when searching the codebase.

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -89,7 +89,6 @@ impl Tool for MovePathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<MovePathToolInput>(input) {

View File

@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use chrono::{Local, Utc};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -59,7 +59,6 @@ impl Tool for NowTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
let input: NowToolInput = match serde_json::from_value(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -52,7 +52,6 @@ impl Tool for OpenTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: OpenToolInput = match serde_json::from_value(input) {

View File

@@ -1,7 +1,7 @@
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -12,7 +12,7 @@ use util::paths::PathMatcher;
use worktree::Snapshot;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FindPathToolInput {
pub struct PathSearchToolInput {
/// The glob to match against every path in the project.
///
/// <example>
@@ -34,11 +34,11 @@ pub struct FindPathToolInput {
const RESULTS_PER_PAGE: usize = 50;
pub struct FindPathTool;
pub struct PathSearchTool;
impl Tool for FindPathTool {
impl Tool for PathSearchTool {
fn name(&self) -> String {
"find_path".into()
"path_search".into()
}
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
@@ -46,7 +46,7 @@ impl Tool for FindPathTool {
}
fn description(&self) -> String {
include_str!("./find_path_tool/description.md").into()
include_str!("./path_search_tool/description.md").into()
}
fn icon(&self) -> IconName {
@@ -54,11 +54,11 @@ impl Tool for FindPathTool {
}
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
json_schema_for::<FindPathToolInput>(format)
json_schema_for::<PathSearchToolInput>(format)
}
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
Err(_) => "Search paths".to_string(),
}
@@ -70,10 +70,9 @@ impl Tool for FindPathTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
Ok(input) => (input.offset, input.glob),
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
};
@@ -144,7 +143,7 @@ mod test {
use util::path;
#[gpui::test]
async fn test_find_path_tool(cx: &mut TestAppContext) {
async fn test_path_search_tool(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());

View File

@@ -1,4 +1,4 @@
Fast file path pattern matching tool that works with any codebase size
Fast file pattern matching tool that works with any codebase size
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
- Returns matching file paths sorted alphabetically

View File

@@ -1,8 +1,7 @@
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use indoc::formatdoc;
use itertools::Itertools;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
@@ -88,7 +87,6 @@ impl Tool for ReadFileTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
@@ -136,7 +134,7 @@ impl Tool for ReadFileTool {
})?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)
@@ -149,7 +147,7 @@ impl Tool for ReadFileTool {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.track_buffer(buffer, cx);
log.buffer_read(buffer, cx);
})?;
Ok(result)
@@ -195,7 +193,7 @@ mod test {
"path": "root/nonexistent_file.txt"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, None, cx)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;
@@ -225,7 +223,7 @@ mod test {
"path": "root/small_file.txt"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, None, cx)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;
@@ -255,7 +253,7 @@ mod test {
"path": "root/large_file.rs"
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log.clone(), None, cx)
.run(input, &[], project.clone(), action_log.clone(), cx)
.output
})
.await;
@@ -279,7 +277,7 @@ mod test {
"offset": 1
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, None, cx)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;
@@ -325,7 +323,7 @@ mod test {
"end_line": 4
});
Arc::new(ReadFileTool)
.run(input, &[], project.clone(), action_log, None, cx)
.run(input, &[], project.clone(), action_log, cx)
.output
})
.await;

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use language::{self, Buffer, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -87,7 +87,6 @@ impl Tool for RenameTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<RenameToolInput>(input) {
@@ -107,7 +106,7 @@ impl Tool for RenameTool {
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
let position = {

View File

@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
use gpui::{App, AsyncApp, Entity, Task};
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -121,7 +121,6 @@ impl Tool for SymbolInfoTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
@@ -141,7 +140,7 @@ impl Tool for SymbolInfoTool {
};
action_log.update(cx, |action_log, cx| {
action_log.track_buffer(buffer.clone(), cx);
action_log.buffer_read(buffer.clone(), cx);
})?;
let position = {

View File

@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
use gpui::{App, AppContext, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -78,7 +78,6 @@ impl Tool for TerminalTool {
_messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input: TerminalToolInput = match serde_json::from_value(input) {

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
@@ -50,7 +50,6 @@ impl Tool for ThinkingTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
_cx: &mut App,
) -> ToolResult {
// This tool just "thinks out loud" and doesn't perform any actions.

View File

@@ -5,16 +5,13 @@ use crate::ui::ToolCallCardHeader;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
use futures::{Future, FutureExt, TryFutureExt};
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
};
use gpui::{App, AppContext, Context, Entity, IntoElement, Task, Window};
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::{IconName, Tooltip, prelude::*};
use web_search::WebSearchRegistry;
use workspace::Workspace;
use zed_llm_client::{WebSearchCitation, WebSearchResponse};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -57,7 +54,6 @@ impl Tool for WebSearchTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
@@ -115,7 +111,6 @@ impl ToolCard for WebSearchToolCard {
&mut self,
_status: &ToolUseStatus,
_window: &mut Window,
_workspace: WeakEntity<Workspace>,
cx: &mut Context<Self>,
) -> impl IntoElement {
let header = match self.response.as_ref() {
@@ -225,13 +220,8 @@ impl Component for WebSearchTool {
div()
.size_full()
.child(in_progress_search.update(cx, |tool, cx| {
tool.render(
&ToolUseStatus::Pending,
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
tool.render(&ToolUseStatus::Pending, window, cx)
.into_any_element()
}))
.into_any_element(),
),
@@ -240,13 +230,8 @@ impl Component for WebSearchTool {
div()
.size_full()
.child(successful_search.update(cx, |tool, cx| {
tool.render(
&ToolUseStatus::Finished("".into()),
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
tool.render(&ToolUseStatus::Finished("".into()), window, cx)
.into_any_element()
}))
.into_any_element(),
),
@@ -255,13 +240,8 @@ impl Component for WebSearchTool {
div()
.size_full()
.child(error_search.update(cx, |tool, cx| {
tool.render(
&ToolUseStatus::Error("".into()),
window,
WeakEntity::new_invalid(),
cx,
)
.into_any_element()
tool.render(&ToolUseStatus::Error("".into()), window, cx)
.into_any_element()
}))
.into_any_element(),
),

View File

@@ -118,13 +118,6 @@ impl Settings for AutoUpdateSetting {
Ok(Self(auto_update.0))
}
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
vscode.enum_setting("update.mode", current, |s| match s {
"none" | "manual" => Some(AutoUpdateSettingContent(false)),
_ => Some(AutoUpdateSettingContent(true)),
});
}
}
#[derive(Default)]

View File

@@ -32,6 +32,4 @@ impl Settings for CallSettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
sources.json_merge()
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}

View File

@@ -104,8 +104,6 @@ impl Settings for ClientSettings {
}
Ok(result)
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
@@ -132,10 +130,6 @@ impl Settings for ProxySettings {
.or(sources.default.proxy.clone()),
})
}
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
vscode.string_setting("http.proxy", &mut current.proxy);
}
}
pub fn init_settings(cx: &mut App) {
@@ -524,18 +518,6 @@ impl settings::Settings for TelemetrySettings {
.unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
})
}
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
Some(s == "all")
});
vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
Some(matches!(s, "all" | "error" | "crash"))
});
// we could translate telemetry.telemetryLevel, but just because users didn't want
// to send microsoft telemetry doesn't mean they don't want to send it to zed. their
// all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
}
}
impl Client {

View File

@@ -34,12 +34,14 @@ dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
futures.workspace = true
google_ai.workspace = true
hex.workspace = true
http_client.workspace = true
jsonwebtoken.workspace = true
livekit_api.workspace = true
log.workspace = true
nanoid.workspace = true
open_ai.workspace = true
parking_lot.workspace = true
prometheus = "0.14"
prost.workspace = true
@@ -126,7 +128,6 @@ serde_json.workspace = true
session = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
sqlx = { version = "0.8", features = ["sqlite"] }
task.workspace = true
theme.workspace = true
unindent.workspace = true
util.workspace = true

View File

@@ -492,8 +492,7 @@ CREATE TABLE IF NOT EXISTS billing_customers (
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
user_id INTEGER NOT NULL REFERENCES users (id),
has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE,
stripe_customer_id TEXT NOT NULL,
trial_started_at TIMESTAMP
stripe_customer_id TEXT NOT NULL
);
CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);

View File

@@ -1,2 +0,0 @@
alter table billing_customers
add column trial_started_at timestamp without time zone;

View File

@@ -1,2 +0,0 @@
alter table project_repositories
add column head_commit_details varchar;

View File

@@ -287,7 +287,7 @@ async fn create_billing_subscription(
}
}
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
let customer_id = if let Some(existing_customer) = existing_billing_customer {
CustomerId::from_str(&existing_customer.stripe_customer_id)
.context("failed to parse customer ID")?
} else {
@@ -320,15 +320,6 @@ async fn create_billing_subscription(
.await?
}
Some(ProductCode::ZedProTrial) => {
if let Some(existing_billing_customer) = &existing_billing_customer {
if existing_billing_customer.trial_started_at.is_some() {
return Err(Error::http(
StatusCode::FORBIDDEN,
"user already used free trial".into(),
));
}
}
stripe_billing
.checkout_with_zed_pro_trial(
app.config.zed_pro_price_id()?,
@@ -826,24 +817,6 @@ async fn handle_customer_subscription_event(
.await?
.ok_or_else(|| anyhow!("billing customer not found"))?;
if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
if subscription.status == SubscriptionStatus::Trialing {
let current_period_start =
DateTime::from_timestamp(subscription.current_period_start, 0)
.ok_or_else(|| anyhow!("No trial subscription period start"))?;
app.db
.update_billing_customer(
billing_customer.id,
&UpdateBillingCustomerParams {
trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
..Default::default()
},
)
.await?;
}
}
let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
&& subscription
.cancellation_details
@@ -870,28 +843,6 @@ async fn handle_customer_subscription_event(
.get_billing_subscription_by_stripe_subscription_id(&subscription.id)
.await?
{
let llm_db = app
.llm_db
.clone()
.ok_or_else(|| anyhow!("LLM DB not initialized"))?;
let new_period_start_at =
chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
.ok_or_else(|| anyhow!("No subscription period start"))?;
let new_period_end_at =
chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
.ok_or_else(|| anyhow!("No subscription period end"))?;
llm_db
.transfer_existing_subscription_usage(
billing_customer.user_id,
&existing_subscription,
subscription_kind,
new_period_start_at,
new_period_end_at,
)
.await?;
app.db
.update_billing_subscription(
existing_subscription.id,

View File

@@ -11,7 +11,6 @@ pub struct UpdateBillingCustomerParams {
pub user_id: ActiveValue<UserId>,
pub stripe_customer_id: ActiveValue<String>,
pub has_overdue_invoices: ActiveValue<bool>,
pub trial_started_at: ActiveValue<Option<DateTime>>,
}
impl Database {
@@ -46,8 +45,7 @@ impl Database {
user_id: params.user_id.clone(),
stripe_customer_id: params.stripe_customer_id.clone(),
has_overdue_invoices: params.has_overdue_invoices.clone(),
trial_started_at: params.trial_started_at.clone(),
created_at: ActiveValue::not_set(),
..Default::default()
})
.exec(&*tx)
.await?;

View File

@@ -10,7 +10,6 @@ pub struct Model {
pub user_id: UserId,
pub stripe_customer_id: String,
pub has_overdue_invoices: bool,
pub trial_started_at: Option<DateTime>,
pub created_at: DateTime,
}

View File

@@ -1,87 +1,8 @@
use chrono::Timelike;
use time::PrimitiveDateTime;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{UserId, billing_subscription};
use crate::db::UserId;
use super::*;
fn convert_chrono_to_time(datetime: DateTimeUtc) -> anyhow::Result<PrimitiveDateTime> {
use chrono::{Datelike as _, Timelike as _};
let date = time::Date::from_calendar_date(
datetime.year(),
time::Month::try_from(datetime.month() as u8).unwrap(),
datetime.day() as u8,
)?;
let time = time::Time::from_hms_nano(
datetime.hour() as u8,
datetime.minute() as u8,
datetime.second() as u8,
datetime.nanosecond(),
)?;
Ok(PrimitiveDateTime::new(date, time))
}
impl LlmDatabase {
pub async fn create_subscription_usage(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
plan: SubscriptionKind,
model_requests: i32,
edit_predictions: i32,
) -> Result<subscription_usage::Model> {
self.transaction(|tx| async move {
self.create_subscription_usage_in_tx(
user_id,
period_start_at,
period_end_at,
plan,
model_requests,
edit_predictions,
&tx,
)
.await
})
.await
}
async fn create_subscription_usage_in_tx(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
plan: SubscriptionKind,
model_requests: i32,
edit_predictions: i32,
tx: &DatabaseTransaction,
) -> Result<subscription_usage::Model> {
// Clear out the nanoseconds so that these timestamps are comparable with Unix timestamps.
let period_start_at = period_start_at.with_nanosecond(0).unwrap();
let period_end_at = period_end_at.with_nanosecond(0).unwrap();
let period_start_at = convert_chrono_to_time(period_start_at)?;
let period_end_at = convert_chrono_to_time(period_end_at)?;
Ok(
subscription_usage::Entity::insert(subscription_usage::ActiveModel {
id: ActiveValue::not_set(),
user_id: ActiveValue::set(user_id),
period_start_at: ActiveValue::set(period_start_at),
period_end_at: ActiveValue::set(period_end_at),
plan: ActiveValue::set(plan),
model_requests: ActiveValue::set(model_requests),
edit_predictions: ActiveValue::set(edit_predictions),
})
.exec_with_returning(tx)
.await?,
)
}
pub async fn get_subscription_usage_for_period(
&self,
user_id: UserId,
@@ -89,77 +10,12 @@ impl LlmDatabase {
period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
self.get_subscription_usage_for_period_in_tx(
user_id,
period_start_at,
period_end_at,
&tx,
)
.await
})
.await
}
async fn get_subscription_usage_for_period_in_tx(
&self,
user_id: UserId,
period_start_at: DateTimeUtc,
period_end_at: DateTimeUtc,
tx: &DatabaseTransaction,
) -> Result<Option<subscription_usage::Model>> {
Ok(subscription_usage::Entity::find()
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
.one(tx)
.await?)
}
pub async fn transfer_existing_subscription_usage(
&self,
user_id: UserId,
existing_subscription: &billing_subscription::Model,
new_subscription_kind: Option<SubscriptionKind>,
new_period_start_at: DateTimeUtc,
new_period_end_at: DateTimeUtc,
) -> Result<Option<subscription_usage::Model>> {
self.transaction(|tx| async move {
match existing_subscription.kind {
Some(SubscriptionKind::ZedProTrial) => {
let trial_period_start_at = existing_subscription
.current_period_start_at()
.ok_or_else(|| anyhow!("No trial subscription period start"))?;
let trial_period_end_at = existing_subscription
.current_period_end_at()
.ok_or_else(|| anyhow!("No trial subscription period end"))?;
let existing_usage = self
.get_subscription_usage_for_period_in_tx(
user_id,
trial_period_start_at,
trial_period_end_at,
&tx,
)
.await?;
if let Some(existing_usage) = existing_usage {
return Ok(Some(
self.create_subscription_usage_in_tx(
user_id,
new_period_start_at,
new_period_end_at,
new_subscription_kind.unwrap_or(existing_usage.plan),
existing_usage.model_requests,
existing_usage.edit_predictions,
&tx,
)
.await?,
));
}
}
_ => {}
}
Ok(None)
Ok(subscription_usage::Entity::find()
.filter(subscription_usage::Column::UserId.eq(user_id))
.filter(subscription_usage::Column::PeriodStartAt.eq(period_start_at))
.filter(subscription_usage::Column::PeriodEndAt.eq(period_end_at))
.one(&*tx)
.await?)
})
.await
}

View File

@@ -1,5 +1,4 @@
mod provider_tests;
mod subscription_usage_tests;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;

View File

@@ -1,69 +0,0 @@
use chrono::{Duration, Utc};
use pretty_assertions::assert_eq;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{UserId, billing_subscription};
use crate::llm::db::LlmDatabase;
use crate::test_llm_db;
test_llm_db!(
test_transfer_existing_subscription_usage,
test_transfer_existing_subscription_usage_postgres
);
async fn test_transfer_existing_subscription_usage(db: &mut LlmDatabase) {
let user_id = UserId(1);
let now = Utc::now();
let trial_period_start_at = now - Duration::days(14);
let trial_period_end_at = now;
let new_period_start_at = now;
let new_period_end_at = now + Duration::days(30);
let existing_subscription = billing_subscription::Model {
kind: Some(SubscriptionKind::ZedProTrial),
stripe_current_period_start: Some(trial_period_start_at.timestamp()),
stripe_current_period_end: Some(trial_period_end_at.timestamp()),
..Default::default()
};
let existing_usage = db
.create_subscription_usage(
user_id,
trial_period_start_at,
trial_period_end_at,
SubscriptionKind::ZedProTrial,
25,
1_000,
)
.await
.unwrap();
let transferred_usage = db
.transfer_existing_subscription_usage(
user_id,
&existing_subscription,
Some(SubscriptionKind::ZedPro),
new_period_start_at,
new_period_end_at,
)
.await
.unwrap();
assert!(
transferred_usage.is_some(),
"subscription usage not transferred successfully"
);
let transferred_usage = transferred_usage.unwrap();
assert_eq!(
transferred_usage.model_requests,
existing_usage.model_requests
);
assert_eq!(
transferred_usage.edit_predictions,
existing_usage.edit_predictions
);
}

View File

@@ -1,5 +1,4 @@
use crate::Cents;
use crate::db::billing_subscription::SubscriptionKind;
use crate::db::{billing_subscription, user};
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::{Config, db::billing_preference};
@@ -33,8 +32,6 @@ pub struct LlmTokenClaims {
pub plan: Plan,
#[serde(default)]
pub subscription_period: Option<(NaiveDateTime, NaiveDateTime)>,
#[serde(default)]
pub can_use_web_search_tool: bool,
}
const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
@@ -46,6 +43,7 @@ impl LlmTokenClaims {
billing_preferences: Option<billing_preference::Model>,
feature_flags: &Vec<String>,
has_legacy_llm_subscription: bool,
plan: rpc::proto::Plan,
subscription: Option<billing_subscription::Model>,
system_id: Option<String>,
config: &Config,
@@ -72,7 +70,6 @@ impl LlmTokenClaims {
bypass_account_age_check: feature_flags
.iter()
.any(|flag| flag == "bypass-account-age-check"),
can_use_web_search_tool: feature_flags.iter().any(|flag| flag == "assistant2"),
has_llm_subscription: has_legacy_llm_subscription,
max_monthly_spend_in_cents: billing_preferences
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0, |preferences| {
@@ -81,14 +78,11 @@ impl LlmTokenClaims {
custom_llm_monthly_allowance_in_cents: user
.custom_llm_monthly_allowance_in_cents
.map(|allowance| allowance as u32),
plan: subscription
.as_ref()
.and_then(|subscription| subscription.kind)
.map_or(Plan::Free, |kind| match kind {
SubscriptionKind::ZedFree => Plan::Free,
SubscriptionKind::ZedPro => Plan::ZedPro,
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
}),
plan: match plan {
rpc::proto::Plan::Free => Plan::Free,
rpc::proto::Plan::ZedPro => Plan::ZedPro,
rpc::proto::Plan::ZedProTrial => Plan::ZedProTrial,
},
subscription_period: maybe!({
let subscription = subscription?;
let period_start_at = subscription.current_period_start_at()?;

View File

@@ -3,7 +3,7 @@ mod connection_pool;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::llm::LlmTokenClaims;
use crate::{
AppState, Error, Result, auth,
AppState, Config, Error, RateLimit, Result, auth,
db::{
self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
@@ -33,8 +33,11 @@ use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use http_client::HttpClient;
use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
use reqwest_client::ReqwestClient;
use rpc::proto::split_repository_update;
use sha2::Digest;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use futures::{
@@ -131,6 +134,7 @@ struct Session {
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
supermaven_client: Option<Arc<SupermavenAdminApi>>,
http_client: Arc<dyn HttpClient>,
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
@@ -423,7 +427,29 @@ impl Server {
.add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
.add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context);
.add_message_handler(update_context)
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
let app_state = app_state.clone();
async move {
count_language_model_tokens(request, response, session, &app_state.config)
.await
}
}
})
.add_request_handler(get_cached_embeddings)
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
compute_embeddings(
request,
response,
session,
app_state.config.openai_api_key.clone(),
)
}
});
Arc::new(server)
}
@@ -752,6 +778,7 @@ impl Server {
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
app_state: this.app_state.clone(),
http_client,
geoip_country_code,
system_id,
_executor: executor.clone(),
@@ -3670,6 +3697,223 @@ async fn acknowledge_buffer_version(
Ok(())
}
async fn count_language_model_tokens(
request: proto::CountLanguageModelTokens,
response: Response<proto::CountLanguageModelTokens>,
session: Session,
config: &Config,
) -> Result<()> {
authorize_access_to_legacy_llm_endpoints(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
proto::Plan::Free | proto::Plan::ZedProTrial => {
Box::new(FreeCountLanguageModelTokensRateLimit)
}
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
Some(proto::LanguageModelProvider::Google) => {
let api_key = config
.google_ai_api_key
.as_ref()
.context("no Google AI API key configured on the server")?;
google_ai::count_tokens(
session.http_client.as_ref(),
google_ai::API_URL,
api_key,
serde_json::from_str(&request.request)?,
)
.await?
}
_ => return Err(anyhow!("unsupported provider"))?,
};
response.send(proto::CountLanguageModelTokensResponse {
token_count: result.total_tokens as u32,
})?;
Ok(())
}
struct ZedProCountLanguageModelTokensRateLimit;
impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
fn capacity(&self) -> usize {
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"zed-pro:count-language-model-tokens"
}
}
struct FreeCountLanguageModelTokensRateLimit;
impl RateLimit for FreeCountLanguageModelTokensRateLimit {
fn capacity(&self) -> usize {
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600 / 10) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"free:count-language-model-tokens"
}
}
struct ZedProComputeEmbeddingsRateLimit;
impl RateLimit for ZedProComputeEmbeddingsRateLimit {
fn capacity(&self) -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"zed-pro:compute-embeddings"
}
}
struct FreeComputeEmbeddingsRateLimit;
impl RateLimit for FreeComputeEmbeddingsRateLimit {
fn capacity(&self) -> usize {
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000 / 10) // Picked arbitrarily
}
fn refill_duration(&self) -> chrono::Duration {
chrono::Duration::hours(1)
}
fn db_name(&self) -> &'static str {
"free:compute-embeddings"
}
}
async fn compute_embeddings(
request: proto::ComputeEmbeddings,
response: Response<proto::ComputeEmbeddings>,
session: Session,
api_key: Option<Arc<str>>,
) -> Result<()> {
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_legacy_llm_endpoints(&session).await?;
let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
};
session
.app_state
.rate_limiter
.check(&*rate_limit, session.user_id())
.await?;
let embeddings = match request.model.as_str() {
"openai/text-embedding-3-small" => {
open_ai::embed(
session.http_client.as_ref(),
OPEN_AI_API_URL,
&api_key,
OpenAiEmbeddingModel::TextEmbedding3Small,
request.texts.iter().map(|text| text.as_str()),
)
.await?
}
provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
};
let embeddings = request
.texts
.iter()
.map(|text| {
let mut hasher = sha2::Sha256::new();
hasher.update(text.as_bytes());
let result = hasher.finalize();
result.to_vec()
})
.zip(
embeddings
.data
.into_iter()
.map(|embedding| embedding.embedding),
)
.collect::<HashMap<_, _>>();
let db = session.db().await;
db.save_embeddings(&request.model, &embeddings)
.await
.context("failed to save embeddings")
.trace_err();
response.send(proto::ComputeEmbeddingsResponse {
embeddings: embeddings
.into_iter()
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
.collect(),
})?;
Ok(())
}
async fn get_cached_embeddings(
request: proto::GetCachedEmbeddings,
response: Response<proto::GetCachedEmbeddings>,
session: Session,
) -> Result<()> {
authorize_access_to_legacy_llm_endpoints(&session).await?;
let db = session.db().await;
let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
response.send(proto::GetCachedEmbeddingsResponse {
embeddings: embeddings
.into_iter()
.map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
.collect(),
})?;
Ok(())
}
/// This is leftover from before the LLM service.
///
/// The endpoints protected by this check will be moved there eventually.
async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
if session.is_staff() {
Ok(())
} else {
Err(anyhow!("permission denied"))?
}
}
/// Get a Supermaven API key for the user
async fn get_supermaven_api_key(
_request: proto::GetSupermavenApiKey,
@@ -3903,6 +4147,7 @@ async fn get_llm_api_token(
billing_preferences,
&flags,
has_legacy_llm_subscription,
session.current_plan(&db).await?,
billing_subscription,
session.system_id.clone(),
&session.app_state.config,

View File

@@ -416,16 +416,9 @@ impl StripeBilling {
let mut params = stripe::CreateCheckoutSession::new();
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
trial_period_days: Some(14),
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
}
}),
..Default::default()
});
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.payment_method_collection =
Some(stripe::CheckoutSessionPaymentMethodCollection::IfRequired);
params.customer = Some(customer_id);
params.client_reference_id = Some(github_login);
params.line_items = Some(vec![stripe::CreateCheckoutSessionLineItems {

View File

@@ -1,7 +1,7 @@
use call::ActiveCall;
use dap::DebugRequestType;
use dap::requests::{Initialize, Launch, StackTrace};
use dap::{SourceBreakpoint, requests::SetBreakpoints};
use dap::DebugRequestType;
use dap::{requests::SetBreakpoints, SourceBreakpoint};
use debugger_ui::debugger_panel::DebugPanel;
use debugger_ui::session::DebugSession;
use editor::Editor;
@@ -13,7 +13,7 @@ use std::{
path::Path,
sync::atomic::{AtomicBool, Ordering},
};
use workspace::{Workspace, dock::Panel};
use workspace::{dock::Panel, Workspace};
use super::{TestClient, TestServer};

View File

@@ -1544,7 +1544,6 @@ async fn test_mutual_editor_inlay_hint_cache_update(
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
settings.defaults.inlay_hints = Some(InlayHintSettings {
enabled: true,
show_value_hints: true,
edit_debounce_ms: 0,
scroll_debounce_ms: 0,
show_type_hints: true,
@@ -1560,7 +1559,6 @@ async fn test_mutual_editor_inlay_hint_cache_update(
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
settings.defaults.inlay_hints = Some(InlayHintSettings {
show_value_hints: true,
enabled: true,
edit_debounce_ms: 0,
scroll_debounce_ms: 0,
@@ -1780,7 +1778,6 @@ async fn test_inlay_hint_refresh_is_forwarded(
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
settings.defaults.inlay_hints = Some(InlayHintSettings {
show_value_hints: true,
enabled: false,
edit_debounce_ms: 0,
scroll_debounce_ms: 0,
@@ -1797,7 +1794,6 @@ async fn test_inlay_hint_refresh_is_forwarded(
SettingsStore::update_global(cx, |store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
settings.defaults.inlay_hints = Some(InlayHintSettings {
show_value_hints: true,
enabled: true,
edit_debounce_ms: 0,
scroll_debounce_ms: 0,

View File

@@ -6,18 +6,17 @@ use collab_ui::{
channel_view::ChannelView,
notifications::project_shared_notification::ProjectSharedNotification,
};
use editor::{Editor, MultiBuffer, PathKey};
use editor::{Editor, ExcerptRange, MultiBuffer};
use gpui::{
AppContext as _, BackgroundExecutor, BorrowAppContext, Entity, SharedString, TestAppContext,
VisualContext, VisualTestContext, point,
VisualTestContext, point,
};
use language::Capability;
use project::WorktreeSettings;
use rpc::proto::PeerId;
use serde_json::json;
use settings::SettingsStore;
use text::{Point, ToPoint};
use util::{path, test::sample_text};
use util::path;
use workspace::{SplitDirection, Workspace, item::ItemHandle as _};
use super::TestClient;
@@ -296,20 +295,8 @@ async fn test_basic_following(
.unwrap()
});
let mut result = MultiBuffer::new(Capability::ReadWrite);
result.set_excerpts_for_path(
PathKey::for_buffer(&buffer_a1, cx),
buffer_a1,
[Point::row_range(1..2)],
1,
cx,
);
result.set_excerpts_for_path(
PathKey::for_buffer(&buffer_a2, cx),
buffer_a2,
[Point::row_range(5..6)],
1,
cx,
);
result.push_excerpts(buffer_a1, [ExcerptRange::new(0..3)], cx);
result.push_excerpts(buffer_a2, [ExcerptRange::new(4..7)], cx);
result
});
let multibuffer_editor_a = workspace_a.update_in(cx_a, |workspace, window, cx| {
@@ -2083,83 +2070,6 @@ async fn share_workspace(
.await
}
#[gpui::test]
async fn test_following_after_replacement(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
let (_server, client_a, client_b, channel) = TestServer::start2(cx_a, cx_b).await;
let (workspace, cx_a) = client_a.build_test_workspace(cx_a).await;
join_channel(channel, &client_a, cx_a).await.unwrap();
share_workspace(&workspace, cx_a).await.unwrap();
let buffer = workspace.update(cx_a, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
project.create_local_buffer(&sample_text(26, 5, 'a'), None, cx)
})
});
let multibuffer = cx_a.new(|cx| {
let mut mb = MultiBuffer::new(Capability::ReadWrite);
mb.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer.clone(),
[Point::row_range(1..1), Point::row_range(5..5)],
1,
cx,
);
mb
});
let snapshot = buffer.update(cx_a, |buffer, _| buffer.snapshot());
let editor: Entity<Editor> = cx_a.new_window_entity(|window, cx| {
Editor::for_multibuffer(
multibuffer.clone(),
Some(workspace.read(cx).project().clone()),
window,
cx,
)
});
workspace.update_in(cx_a, |workspace, window, cx| {
workspace.add_item_to_center(Box::new(editor.clone()) as _, window, cx)
});
editor.update_in(cx_a, |editor, window, cx| {
editor.change_selections(None, window, cx, |s| {
s.select_ranges([Point::row_range(4..4)]);
})
});
let positions = editor.update(cx_a, |editor, _| {
editor
.selections
.disjoint_anchor_ranges()
.map(|range| range.start.text_anchor.to_point(&snapshot))
.collect::<Vec<_>>()
});
multibuffer.update(cx_a, |multibuffer, cx| {
multibuffer.set_excerpts_for_path(
PathKey::for_buffer(&buffer, cx),
buffer,
[Point::row_range(1..5)],
1,
cx,
);
});
let (workspace_b, cx_b) = client_b.join_workspace(channel, cx_b).await;
cx_b.run_until_parked();
let editor_b = workspace_b
.update(cx_b, |workspace, cx| {
workspace
.active_item(cx)
.and_then(|item| item.downcast::<Editor>())
})
.unwrap();
let new_positions = editor_b.update(cx_b, |editor, _| {
editor
.selections
.disjoint_anchor_ranges()
.map(|range| range.start.text_anchor.to_point(&snapshot))
.collect::<Vec<_>>()
});
assert_eq!(positions, new_positions);
}
#[gpui::test]
async fn test_following_to_channel_notes_other_workspace(
cx_a: &mut TestAppContext,

View File

@@ -2,13 +2,11 @@ use crate::tests::TestServer;
use call::ActiveCall;
use collections::{HashMap, HashSet};
use debugger_ui::debugger_panel::DebugPanel;
use extension::ExtensionHostProxy;
use fs::{FakeFs, Fs as _, RemoveOptions};
use futures::StreamExt as _;
use gpui::{
AppContext as _, BackgroundExecutor, SemanticVersion, TestAppContext, UpdateGlobal as _,
VisualContext,
};
use http_client::BlockedHttpClient;
use language::{
@@ -26,7 +24,6 @@ use project::{
};
use remote::SshRemoteClient;
use remote_server::{HeadlessAppState, HeadlessProject};
use rpc::proto;
use serde_json::json;
use settings::SettingsStore;
use std::{path::Path, sync::Arc};
@@ -579,108 +576,3 @@ async fn test_ssh_collaboration_formatting_with_prettier(
"Prettier formatting was not applied to client buffer after host's request"
);
}
#[gpui::test]
async fn test_remote_server_debugger(cx_a: &mut TestAppContext, server_cx: &mut TestAppContext) {
cx_a.update(|cx| {
release_channel::init(SemanticVersion::default(), cx);
command_palette_hooks::init(cx);
if std::env::var("RUST_LOG").is_ok() {
env_logger::try_init().ok();
}
});
server_cx.update(|cx| {
release_channel::init(SemanticVersion::default(), cx);
});
let (opts, server_ssh) = SshRemoteClient::fake_server(cx_a, server_cx);
let remote_fs = FakeFs::new(server_cx.executor());
remote_fs
.insert_tree(
path!("/code"),
json!({
"lib.rs": "fn one() -> usize { 1 }"
}),
)
.await;
// User A connects to the remote project via SSH.
server_cx.update(HeadlessProject::init);
let remote_http_client = Arc::new(BlockedHttpClient);
let node = NodeRuntime::unavailable();
let languages = Arc::new(LanguageRegistry::new(server_cx.executor()));
let _headless_project = server_cx.new(|cx| {
client::init_settings(cx);
HeadlessProject::new(
HeadlessAppState {
session: server_ssh,
fs: remote_fs.clone(),
http_client: remote_http_client,
node_runtime: node,
languages,
extension_host_proxy: Arc::new(ExtensionHostProxy::new()),
},
cx,
)
});
let client_ssh = SshRemoteClient::fake_client(opts, cx_a).await;
let mut server = TestServer::start(server_cx.executor()).await;
let client_a = server.create_client(cx_a, "user_a").await;
cx_a.update(|cx| {
debugger_ui::init(cx);
command_palette_hooks::init(cx);
});
let (project_a, _) = client_a
.build_ssh_project(path!("/code"), client_ssh.clone(), cx_a)
.await;
let (workspace, cx_a) = client_a.build_workspace(&project_a, cx_a);
let debugger_panel = workspace
.update_in(cx_a, |_workspace, window, cx| {
cx.spawn_in(window, DebugPanel::load)
})
.await
.unwrap();
workspace.update_in(cx_a, |workspace, window, cx| {
workspace.add_panel(debugger_panel, window, cx);
});
cx_a.run_until_parked();
let debug_panel = workspace
.update(cx_a, |workspace, cx| workspace.panel::<DebugPanel>(cx))
.unwrap();
let workspace_window = cx_a
.window_handle()
.downcast::<workspace::Workspace>()
.unwrap();
let session = debugger_ui::tests::start_debug_session(&workspace_window, cx_a, |_| {}).unwrap();
cx_a.run_until_parked();
debug_panel.update(cx_a, |debug_panel, cx| {
assert_eq!(
debug_panel.active_session().unwrap().read(cx).session(cx),
session
)
});
session.update(cx_a, |session, _| {
assert_eq!(session.binary().command, "ssh");
});
let shutdown_session = workspace.update(cx_a, |workspace, cx| {
workspace.project().update(cx, |project, cx| {
project.dap_store().update(cx, |dap_store, cx| {
dap_store.shutdown_session(session.read(cx).session_id(), cx)
})
})
});
client_ssh.update(cx_a, |a, _| {
a.shutdown_processes(Some(proto::ShutdownRemoteServer {}))
});
shutdown_session.await.unwrap();
}

View File

@@ -86,8 +86,6 @@ impl Settings for CollaborationPanelSettings {
) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
impl Settings for ChatPanelSettings {
@@ -101,8 +99,6 @@ impl Settings for ChatPanelSettings {
) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
impl Settings for NotificationPanelSettings {
@@ -116,8 +112,6 @@ impl Settings for NotificationPanelSettings {
) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
impl Settings for MessageEditorSettings {
@@ -131,6 +125,4 @@ impl Settings for MessageEditorSettings {
) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow, bail};
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
use gpui::{AnyWindowHandle, App, Entity, Task};
use gpui::{App, Entity, Task};
use icons::IconName;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::Project;
@@ -77,7 +77,6 @@ impl Tool for ContextServerTool {
_messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_action_log: Entity<ActionLog>,
_window: Option<AnyWindowHandle>,
cx: &mut App,
) -> ToolResult {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {

View File

@@ -58,42 +58,4 @@ impl Settings for ContextServerSettings {
) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
// we don't handle "inputs" replacement strings, see perplexity-key in this example:
// https://code.visualstudio.com/docs/copilot/chat/mcp-servers#_configuration-example
#[derive(Deserialize)]
struct VsCodeServerCommand {
command: String,
args: Option<Vec<String>>,
env: Option<HashMap<String, String>>,
// note: we don't support envFile and type
}
impl From<VsCodeServerCommand> for ServerCommand {
fn from(cmd: VsCodeServerCommand) -> Self {
Self {
path: cmd.command,
args: cmd.args.unwrap_or_default(),
env: cmd.env,
}
}
}
if let Some(mcp) = vscode.read_value("mcp").and_then(|v| v.as_object()) {
current
.context_servers
.extend(mcp.iter().filter_map(|(k, v)| {
Some((
k.clone().into(),
ServerConfig {
command: Some(
serde_json::from_value::<VsCodeServerCommand>(v.clone())
.ok()?
.into(),
),
settings: None,
},
))
}));
}
}
}

View File

@@ -36,7 +36,6 @@ gpui.workspace = true
http_client.workspace = true
language.workspace = true
log.workspace = true
lsp-types.workspace = true
node_runtime.workspace = true
parking_lot.workspace = true
paths.workspace = true

View File

@@ -284,10 +284,6 @@ pub async fn fetch_latest_adapter_version_from_github(
})
}
pub trait InlineValueProvider {
fn provide(&self, variables: Vec<(String, lsp_types::Range)>) -> Vec<lsp_types::InlineValue>;
}
#[async_trait(?Send)]
pub trait DebugAdapter: 'static + Send + Sync {
fn name(&self) -> DebugAdapterName;
@@ -377,12 +373,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
user_installed_path: Option<PathBuf>,
cx: &mut AsyncApp,
) -> Result<DebugAdapterBinary>;
fn inline_value_provider(&self) -> Option<Box<dyn InlineValueProvider>> {
None
}
}
#[cfg(any(test, feature = "test-support"))]
pub struct FakeAdapter {}

View File

@@ -54,8 +54,6 @@ impl Settings for DebuggerSettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> anyhow::Result<Self> {
sources.json_merge()
}
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
}
impl Global for DebuggerSettings {}

View File

@@ -22,7 +22,7 @@ use std::{
time::Duration,
};
use task::TcpArgumentsTemplate;
use util::{ResultExt as _, TryFutureExt};
use util::ResultExt as _;
use crate::{adapters::DebugAdapterBinary, debugger_settings::DebuggerSettings};
@@ -126,7 +126,6 @@ pub(crate) struct TransportDelegate {
pending_requests: Requests,
transport: Transport,
server_tx: Arc<Mutex<Option<Sender<Message>>>>,
_tasks: Vec<gpui::Task<Option<()>>>,
}
impl TransportDelegate {
@@ -141,7 +140,6 @@ impl TransportDelegate {
log_handlers: Default::default(),
current_requests: Default::default(),
pending_requests: Default::default(),
_tasks: Default::default(),
};
let messages = this.start_handlers(transport_pipes, cx).await?;
Ok((messages, this))
@@ -168,43 +166,35 @@ impl TransportDelegate {
cx.update(|cx| {
if let Some(stdout) = params.stdout.take() {
self._tasks.push(
cx.background_executor()
.spawn(Self::handle_adapter_log(stdout, log_handler.clone()).log_err()),
);
cx.background_executor()
.spawn(Self::handle_adapter_log(stdout, log_handler.clone()))
.detach_and_log_err(cx);
}
self._tasks.push(
cx.background_executor().spawn(
Self::handle_output(
params.output,
client_tx,
self.pending_requests.clone(),
log_handler.clone(),
)
.log_err(),
),
);
cx.background_executor()
.spawn(Self::handle_output(
params.output,
client_tx,
self.pending_requests.clone(),
log_handler.clone(),
))
.detach_and_log_err(cx);
if let Some(stderr) = params.stderr.take() {
self._tasks.push(
cx.background_executor()
.spawn(Self::handle_error(stderr, self.log_handlers.clone()).log_err()),
);
cx.background_executor()
.spawn(Self::handle_error(stderr, self.log_handlers.clone()))
.detach_and_log_err(cx);
}
self._tasks.push(
cx.background_executor().spawn(
Self::handle_input(
params.input,
client_rx,
self.current_requests.clone(),
self.pending_requests.clone(),
log_handler.clone(),
)
.log_err(),
),
);
cx.background_executor()
.spawn(Self::handle_input(
params.input,
client_rx,
self.current_requests.clone(),
self.pending_requests.clone(),
log_handler.clone(),
))
.detach_and_log_err(cx);
})?;
{
@@ -377,7 +367,6 @@ impl TransportDelegate {
where
Stderr: AsyncRead + Unpin + Send + 'static,
{
log::debug!("Handle error started");
let mut buffer = String::new();
let mut reader = BufReader::new(stderr);

View File

@@ -26,7 +26,6 @@ async-trait.workspace = true
dap.workspace = true
gpui.workspace = true
language.workspace = true
lsp-types.workspace = true
paths.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -2,7 +2,7 @@ use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
use anyhow::{Result, bail};
use async_trait::async_trait;
use dap::adapters::{InlineValueProvider, latest_github_release};
use dap::adapters::latest_github_release;
use gpui::AsyncApp;
use task::{DebugRequest, DebugTaskDefinition};
@@ -150,25 +150,4 @@ impl DebugAdapter for CodeLldbDebugAdapter {
connection: None,
})
}
fn inline_value_provider(&self) -> Option<Box<dyn InlineValueProvider>> {
Some(Box::new(CodeLldbInlineValueProvider))
}
}
struct CodeLldbInlineValueProvider;
impl InlineValueProvider for CodeLldbInlineValueProvider {
fn provide(&self, variables: Vec<(String, lsp_types::Range)>) -> Vec<lsp_types::InlineValue> {
variables
.into_iter()
.map(|(variable, range)| {
lsp_types::InlineValue::VariableLookup(lsp_types::InlineValueVariableLookup {
range,
variable_name: Some(variable),
case_sensitive_lookup: true,
})
})
.collect()
}
}

View File

@@ -1,5 +1,5 @@
use crate::*;
use dap::{StartDebuggingRequestArguments, adapters::InlineValueProvider};
use dap::{DebugRequest, StartDebuggingRequestArguments};
use gpui::AsyncApp;
use std::{collections::HashMap, ffi::OsStr, path::PathBuf};
use task::DebugTaskDefinition;
@@ -160,34 +160,4 @@ impl DebugAdapter for PythonDebugAdapter {
request_args: self.request_args(config),
})
}
fn inline_value_provider(&self) -> Option<Box<dyn InlineValueProvider>> {
Some(Box::new(PythonInlineValueProvider))
}
}
struct PythonInlineValueProvider;
impl InlineValueProvider for PythonInlineValueProvider {
fn provide(&self, variables: Vec<(String, lsp_types::Range)>) -> Vec<lsp_types::InlineValue> {
variables
.into_iter()
.map(|(variable, range)| {
if variable.contains(".") || variable.contains("[") {
lsp_types::InlineValue::EvaluatableExpression(
lsp_types::InlineValueEvaluatableExpression {
range,
expression: Some(variable),
},
)
} else {
lsp_types::InlineValue::VariableLookup(lsp_types::InlineValueVariableLookup {
range,
variable_name: Some(variable),
case_sensitive_lookup: true,
})
}
})
.collect()
}
}

View File

@@ -12,9 +12,6 @@ workspace = true
path = "src/debugger_tools.rs"
doctest = false
[features]
test-support = []
[dependencies]
anyhow.workspace = true
dap.workspace = true

View File

@@ -41,7 +41,7 @@ struct DapLogView {
_subscriptions: Vec<Subscription>,
}
pub struct LogStore {
struct LogStore {
projects: HashMap<WeakEntity<Project>, ProjectState>,
debug_clients: HashMap<SessionId, DebugAdapterState>,
rpc_tx: UnboundedSender<(SessionId, IoKind, String)>,
@@ -101,7 +101,7 @@ impl DebugAdapterState {
}
impl LogStore {
pub fn new(cx: &Context<Self>) -> Self {
fn new(cx: &Context<Self>) -> Self {
let (rpc_tx, mut rpc_rx) = unbounded::<(SessionId, IoKind, String)>();
cx.spawn(async move |this, cx| {
while let Some((client_id, io_kind, message)) = rpc_rx.next().await {
@@ -845,29 +845,3 @@ impl EventEmitter<Event> for LogStore {}
impl EventEmitter<Event> for DapLogView {}
impl EventEmitter<EditorEvent> for DapLogView {}
impl EventEmitter<SearchEvent> for DapLogView {}
#[cfg(any(test, feature = "test-support"))]
impl LogStore {
pub fn contained_session_ids(&self) -> Vec<SessionId> {
self.debug_clients.keys().cloned().collect()
}
pub fn rpc_messages_for_session_id(&self, session_id: SessionId) -> Vec<String> {
self.debug_clients
.get(&session_id)
.expect("This session should exist if a test is calling")
.rpc_messages
.messages
.clone()
.into()
}
pub fn log_messages_for_session_id(&self, session_id: SessionId) -> Vec<String> {
self.debug_clients
.get(&session_id)
.expect("This session should exist if a test is calling")
.log_messages
.clone()
.into()
}
}

View File

@@ -20,9 +20,6 @@ test-support = [
"project/test-support",
"util/test-support",
"workspace/test-support",
"env_logger",
"unindent",
"debugger_tools"
]
[dependencies]
@@ -40,7 +37,6 @@ gpui.workspace = true
language.workspace = true
log.workspace = true
menu.workspace = true
parking_lot.workspace = true
picker.workspace = true
pretty_assertions.workspace = true
project.workspace = true
@@ -57,13 +53,9 @@ ui.workspace = true
util.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
env_logger = { workspace = true, optional = true }
debugger_tools = { workspace = true, optional = true }
unindent = { workspace = true, optional = true }
[dev-dependencies]
dap = { workspace = true, features = ["test-support"] }
debugger_tools = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -1,7 +1,7 @@
use dap::DebugRequest;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::Subscription;
use gpui::{DismissEvent, Entity, EventEmitter, Focusable, Render};
use gpui::{Subscription, WeakEntity};
use picker::{Picker, PickerDelegate};
use std::sync::Arc;
@@ -9,9 +9,7 @@ use sysinfo::System;
use ui::{Context, Tooltip, prelude::*};
use ui::{ListItem, ListItemSpacing};
use util::debug_panic;
use workspace::{ModalView, Workspace};
use crate::debugger_panel::DebugPanel;
use workspace::ModalView;
#[derive(Debug, Clone)]
pub(super) struct Candidate {
@@ -24,19 +22,19 @@ pub(crate) struct AttachModalDelegate {
selected_index: usize,
matches: Vec<StringMatch>,
placeholder_text: Arc<str>,
workspace: WeakEntity<Workspace>,
project: Entity<project::Project>,
pub(crate) debug_config: task::DebugTaskDefinition,
candidates: Arc<[Candidate]>,
}
impl AttachModalDelegate {
fn new(
workspace: Entity<Workspace>,
project: Entity<project::Project>,
debug_config: task::DebugTaskDefinition,
candidates: Arc<[Candidate]>,
) -> Self {
Self {
workspace: workspace.downgrade(),
project,
debug_config,
candidates,
selected_index: 0,
@@ -53,7 +51,7 @@ pub struct AttachModal {
impl AttachModal {
pub fn new(
workspace: Entity<Workspace>,
project: Entity<project::Project>,
debug_config: task::DebugTaskDefinition,
modal: bool,
window: &mut Window,
@@ -77,11 +75,11 @@ impl AttachModal {
.collect();
processes.sort_by_key(|k| k.name.clone());
let processes = processes.into_iter().collect();
Self::with_processes(workspace, debug_config, processes, modal, window, cx)
Self::with_processes(project, debug_config, processes, modal, window, cx)
}
pub(super) fn with_processes(
workspace: Entity<Workspace>,
project: Entity<project::Project>,
debug_config: task::DebugTaskDefinition,
processes: Arc<[Candidate]>,
modal: bool,
@@ -90,7 +88,7 @@ impl AttachModal {
) -> Self {
let picker = cx.new(|cx| {
Picker::uniform_list(
AttachModalDelegate::new(workspace, debug_config, processes),
AttachModalDelegate::new(project, debug_config, processes),
window,
cx,
)
@@ -204,7 +202,7 @@ impl PickerDelegate for AttachModalDelegate {
})
}
fn confirm(&mut self, _: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
fn confirm(&mut self, _: bool, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
let candidate = self
.matches
.get(self.selected_index())
@@ -227,17 +225,14 @@ impl PickerDelegate for AttachModalDelegate {
}
}
let definition = self.debug_config.clone();
let panel = self
.workspace
.update(cx, |workspace, cx| workspace.panel::<DebugPanel>(cx))
.ok()
.flatten();
if let Some(panel) = panel {
panel.update(cx, |panel, cx| {
panel.start_session(definition, window, cx);
});
}
let config = self.debug_config.clone();
self.project
.update(cx, |project, cx| {
let ret = project.start_debug_session(config, cx);
ret
})
.detach_and_log_err(cx);
cx.emit(DismissEvent);
}

View File

@@ -6,7 +6,6 @@ use crate::{new_session_modal::NewSessionModal, session::DebugSession};
use anyhow::{Result, anyhow};
use collections::HashMap;
use command_palette_hooks::CommandPaletteFilter;
use dap::StartDebuggingRequestArguments;
use dap::{
ContinuedEvent, LoadedSourceEvent, ModuleEvent, OutputEvent, StoppedEvent, ThreadEvent,
client::SessionId, debugger_settings::DebuggerSettings,
@@ -18,7 +17,6 @@ use gpui::{
actions, anchored, deferred,
};
use project::debugger::session::{Session, SessionStateEvent};
use project::{
Project,
debugger::{
@@ -32,9 +30,10 @@ use settings::Settings;
use std::any::TypeId;
use std::path::Path;
use std::sync::Arc;
use task::{DebugTaskDefinition, DebugTaskTemplate};
use task::DebugTaskDefinition;
use terminal_view::terminal_panel::TerminalPanel;
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
use util::debug_panic;
use workspace::{
Workspace,
dock::{DockPosition, Panel, PanelEvent},
@@ -64,7 +63,7 @@ pub struct DebugPanel {
active_session: Option<Entity<DebugSession>>,
/// This represents the last debug definition that was created in the new session modal
pub(crate) past_debug_definition: Option<DebugTaskDefinition>,
project: Entity<Project>,
project: WeakEntity<Project>,
workspace: WeakEntity<Workspace>,
focus_handle: FocusHandle,
context_menu: Option<(Entity<ContextMenu>, Point<Pixels>, Subscription)>,
@@ -98,10 +97,10 @@ impl DebugPanel {
window,
|panel, _, event: &tasks_ui::ShowAttachModal, window, cx| {
panel.workspace.update(cx, |workspace, cx| {
let workspace_handle = cx.entity().clone();
let project = workspace.project().clone();
workspace.toggle_modal(window, cx, |window, cx| {
crate::attach_modal::AttachModal::new(
workspace_handle,
project,
event.debug_config.clone(),
true,
window,
@@ -128,7 +127,7 @@ impl DebugPanel {
_subscriptions,
past_debug_definition: None,
focus_handle: cx.focus_handle(),
project,
project: project.downgrade(),
workspace: workspace.weak_handle(),
context_menu: None,
};
@@ -220,7 +219,7 @@ impl DebugPanel {
pub fn load(
workspace: WeakEntity<Workspace>,
cx: &mut AsyncWindowContext,
cx: AsyncWindowContext,
) -> Task<Result<Entity<Self>>> {
cx.spawn(async move |cx| {
workspace.update_in(cx, |workspace, window, cx| {
@@ -246,226 +245,114 @@ impl DebugPanel {
});
})
.detach();
workspace.set_debugger_provider(DebuggerProvider(debug_panel.clone()));
debug_panel
})
})
}
pub fn start_session(
&mut self,
definition: DebugTaskDefinition,
window: &mut Window,
cx: &mut Context<Self>,
) {
let task_contexts = self
.workspace
.update(cx, |workspace, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})
.ok();
let dap_store = self.project.read(cx).dap_store().clone();
cx.spawn_in(window, async move |this, cx| {
let task_context = if let Some(task) = task_contexts {
task.await
.active_worktree_context
.map_or(task::TaskContext::default(), |context| context.1)
} else {
task::TaskContext::default()
};
let (session, task) = dap_store.update(cx, |dap_store, cx| {
let template = DebugTaskTemplate {
locator: None,
definition: definition.clone(),
};
let session = if let Some(debug_config) = template
.to_zed_format()
.resolve_task("debug_task", &task_context)
.and_then(|resolved_task| resolved_task.resolved_debug_adapter_config())
{
dap_store.new_session(debug_config.definition, None, cx)
} else {
dap_store.new_session(definition.clone(), None, cx)
};
(session.clone(), dap_store.boot_session(session, cx))
})?;
match task.await {
Err(e) => {
this.update(cx, |this, cx| {
this.workspace
.update(cx, |workspace, cx| {
workspace.show_error(&e, cx);
})
.ok();
})
.ok();
session
.update(cx, |session, cx| session.shutdown(cx))?
.await;
}
Ok(_) => Self::register_session(this, session, cx).await?,
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
async fn register_session(
this: WeakEntity<Self>,
session: Entity<Session>,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let adapter_name = session.update(cx, |session, _| session.adapter_name())?;
this.update_in(cx, |_, window, cx| {
cx.subscribe_in(
&session,
window,
move |_, session, event: &SessionStateEvent, window, cx| match event {
SessionStateEvent::Restart => {
let mut curr_session = session.clone();
while let Some(parent_session) = curr_session
.read_with(cx, |session, _| session.parent_session().cloned())
{
curr_session = parent_session;
}
let definition = curr_session.update(cx, |session, _| session.definition());
let task = curr_session.update(cx, |session, cx| session.shutdown(cx));
let definition = definition.clone();
cx.spawn_in(window, async move |this, cx| {
task.await;
this.update_in(cx, |this, window, cx| {
this.start_session(definition, window, cx)
})
})
.detach_and_log_err(cx);
}
_ => {}
},
)
.detach();
})
.ok();
let serialized_layout = persistence::get_serialized_pane_layout(adapter_name).await;
let workspace = this.update_in(cx, |this, window, cx| {
this.sessions.retain(|session| {
session
.read(cx)
.mode()
.as_running()
.map_or(false, |running_state| {
!running_state.read(cx).session().read(cx).is_terminated()
})
});
let session_item = DebugSession::running(
this.project.clone(),
this.workspace.clone(),
session,
cx.weak_entity(),
serialized_layout,
window,
cx,
);
if let Some(running) = session_item.read(cx).mode().as_running().cloned() {
// We might want to make this an event subscription and only notify when a new thread is selected
// This is used to filter the command menu correctly
cx.observe(&running, |_, _, cx| cx.notify()).detach();
}
this.sessions.push(session_item.clone());
this.activate_session(session_item, window, cx);
this.workspace.clone()
})?;
workspace.update_in(cx, |workspace, window, cx| {
workspace.focus_panel::<Self>(window, cx);
})?;
Ok(())
}
pub fn start_child_session(
&mut self,
request: &StartDebuggingRequestArguments,
parent_session: Entity<Session>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(worktree) = parent_session.read(cx).worktree() else {
log::error!("Attempted to start a child session from non local debug session");
return;
};
let dap_store_handle = self.project.read(cx).dap_store().clone();
let breakpoint_store = self.project.read(cx).breakpoint_store();
let definition = parent_session.read(cx).definition().clone();
let mut binary = parent_session.read(cx).binary().clone();
binary.request_args = request.clone();
cx.spawn_in(window, async move |this, cx| {
let (session, task) = dap_store_handle.update(cx, |dap_store, cx| {
let session =
dap_store.new_session(definition.clone(), Some(parent_session.clone()), cx);
let task = session.update(cx, |session, cx| {
session.boot(
binary,
worktree,
breakpoint_store,
dap_store_handle.downgrade(),
cx,
)
});
(session, task)
})?;
match task.await {
Err(e) => {
this.update(cx, |this, cx| {
this.workspace
.update(cx, |workspace, cx| {
workspace.show_error(&e, cx);
})
.ok();
})
.ok();
session
.update(cx, |session, cx| session.shutdown(cx))?
.await;
}
Ok(_) => Self::register_session(this, session, cx).await?,
}
anyhow::Ok(())
})
.detach_and_log_err(cx);
}
pub fn active_session(&self) -> Option<Entity<DebugSession>> {
self.active_session.clone()
}
pub fn debug_panel_items_by_client(
&self,
client_id: &SessionId,
cx: &Context<Self>,
) -> Vec<Entity<DebugSession>> {
self.sessions
.iter()
.filter(|item| item.read(cx).session_id(cx) == *client_id)
.map(|item| item.clone())
.collect()
}
pub fn debug_panel_item_by_client(
&self,
client_id: SessionId,
cx: &mut Context<Self>,
) -> Option<Entity<DebugSession>> {
self.sessions
.iter()
.find(|item| {
let item = item.read(cx);
item.session_id(cx) == client_id
})
.cloned()
}
fn handle_dap_store_event(
&mut self,
_dap_store: &Entity<DapStore>,
dap_store: &Entity<DapStore>,
event: &dap_store::DapStoreEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
match event {
dap_store::DapStoreEvent::DebugSessionInitialized(session_id) => {
let Some(session) = dap_store.read(cx).session_by_id(session_id) else {
return log::error!(
"Couldn't get session with id: {session_id:?} from DebugClientStarted event"
);
};
let adapter_name = session.read(cx).adapter_name();
let session_id = *session_id;
cx.spawn_in(window, async move |this, cx| {
let serialized_layout =
persistence::get_serialized_pane_layout(adapter_name).await;
this.update_in(cx, |this, window, cx| {
let Some(project) = this.project.upgrade() else {
return log::error!(
"Debug Panel out lived it's weak reference to Project"
);
};
if this
.sessions
.iter()
.any(|item| item.read(cx).session_id(cx) == session_id)
{
// We already have an item for this session.
debug_panic!("We should never reuse session ids");
return;
}
this.sessions.retain(|session| {
session
.read(cx)
.mode()
.as_running()
.map_or(false, |running_state| {
!running_state.read(cx).session().read(cx).is_terminated()
})
});
let session_item = DebugSession::running(
project,
this.workspace.clone(),
session,
cx.weak_entity(),
serialized_layout,
window,
cx,
);
if let Some(running) = session_item.read(cx).mode().as_running().cloned() {
// We might want to make this an event subscription and only notify when a new thread is selected
// This is used to filter the command menu correctly
cx.observe(&running, |_, _, cx| cx.notify()).detach();
}
this.sessions.push(session_item.clone());
this.activate_session(session_item, window, cx);
})
})
.detach();
}
dap_store::DapStoreEvent::RunInTerminal {
title,
cwd,
@@ -487,12 +374,6 @@ impl DebugPanel {
)
.detach_and_log_err(cx);
}
dap_store::DapStoreEvent::SpawnChildSession {
request,
parent_session,
} => {
self.start_child_session(request, parent_session.clone(), window, cx);
}
_ => {}
}
}
@@ -527,7 +408,7 @@ impl DebugPanel {
cwd,
title,
},
task::RevealStrategy::Never,
task::RevealStrategy::Always,
window,
cx,
);
@@ -587,6 +468,8 @@ impl DebugPanel {
let session = this.dap_store().read(cx).session_by_id(session_id);
session.map(|session| !session.read(cx).is_terminated())
})
.ok()
.flatten()
.unwrap_or_default();
cx.spawn_in(window, async move |this, cx| {
@@ -1010,6 +893,7 @@ impl DebugPanel {
impl EventEmitter<PanelEvent> for DebugPanel {}
impl EventEmitter<DebugPanelEvent> for DebugPanel {}
impl EventEmitter<project::Event> for DebugPanel {}
impl Focusable for DebugPanel {
fn focus_handle(&self, _: &App) -> FocusHandle {
@@ -1155,15 +1039,3 @@ impl Render for DebugPanel {
.into_any()
}
}
struct DebuggerProvider(Entity<DebugPanel>);
impl workspace::DebuggerProvider for DebuggerProvider {
fn start_session(&self, definition: DebugTaskDefinition, window: &mut Window, cx: &mut App) {
self.0.update(cx, |_, cx| {
cx.defer_in(window, |this, window, cx| {
this.start_session(definition, window, cx);
})
})
}
}

View File

@@ -16,7 +16,7 @@ mod new_session_modal;
mod persistence;
pub(crate) mod session;
#[cfg(any(test, feature = "test-support"))]
#[cfg(test)]
pub mod tests;
actions!(
@@ -247,7 +247,7 @@ pub fn init(cx: &mut App) {
let stack_id = state.selected_stack_frame_id(cx);
state.session().update(cx, |session, cx| {
session.evaluate(text, None, stack_id, None, cx).detach();
session.evaluate(text, None, stack_id, None, cx);
});
});
Some(())

View File

@@ -4,12 +4,14 @@ use std::{
path::{Path, PathBuf},
};
use anyhow::{Result, anyhow};
use dap::{DapRegistry, DebugRequest};
use editor::{Editor, EditorElement, EditorStyle};
use gpui::{
App, AppContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, TextStyle,
WeakEntity,
};
use project::Project;
use settings::Settings;
use task::{DebugTaskDefinition, DebugTaskTemplate, LaunchRequest};
use theme::ThemeSettings;
@@ -19,6 +21,7 @@ use ui::{
LabelCommon as _, ParentElement, RenderOnce, SharedString, Styled, StyledExt, ToggleButton,
ToggleState, Toggleable, Window, div, h_flex, relative, rems, v_flex,
};
use util::ResultExt;
use workspace::{ModalView, Workspace};
use crate::{attach_modal::AttachModal, debugger_panel::DebugPanel};
@@ -85,11 +88,11 @@ impl NewSessionModal {
}
}
fn debug_config(&self, cx: &App, debugger: &str) -> DebugTaskDefinition {
fn debug_config(&self, cx: &App) -> Option<DebugTaskDefinition> {
let request = self.mode.debug_task(cx);
DebugTaskDefinition {
adapter: debugger.to_owned(),
label: suggested_label(&request, debugger),
Some(DebugTaskDefinition {
adapter: self.debugger.clone()?.to_string(),
label: suggested_label(&request, self.debugger.as_deref()?),
request,
initialize_args: self.initialize_args.clone(),
tcp_connection: None,
@@ -97,26 +100,26 @@ impl NewSessionModal {
ToggleState::Selected => Some(true),
_ => None,
},
}
})
}
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) {
let Some(debugger) = self.debugger.as_ref() else {
// todo: show in UI.
log::error!("No debugger selected");
return;
};
let config = self.debug_config(cx, debugger);
let debug_panel = self.debug_panel.clone();
fn start_new_session(&self, window: &mut Window, cx: &mut Context<Self>) -> Result<()> {
let workspace = self.workspace.clone();
let config = self
.debug_config(cx)
.ok_or_else(|| anyhow!("Failed to create a debug config"))?;
let task_contexts = self
.workspace
let _ = self.debug_panel.update(cx, |panel, _| {
panel.past_debug_definition = Some(config.clone());
});
let task_contexts = workspace
.update(cx, |workspace, cx| {
tasks_ui::task_contexts(workspace, window, cx)
})
.ok();
cx.spawn_in(window, async move |this, cx| {
cx.spawn(async move |this, cx| {
let task_context = if let Some(task) = task_contexts {
task.await
.active_worktree_context
@@ -124,8 +127,9 @@ impl NewSessionModal {
} else {
task::TaskContext::default()
};
let project = workspace.update(cx, |workspace, _| workspace.project().clone())?;
debug_panel.update_in(cx, |debug_panel, window, cx| {
let task = project.update(cx, |this, cx| {
let template = DebugTaskTemplate {
locator: None,
definition: config.clone(),
@@ -135,18 +139,23 @@ impl NewSessionModal {
.resolve_task("debug_task", &task_context)
.and_then(|resolved_task| resolved_task.resolved_debug_adapter_config())
{
debug_panel.start_session(debug_config.definition, window, cx)
this.start_debug_session(debug_config.definition, cx)
} else {
debug_panel.start_session(config, window, cx)
this.start_debug_session(config, cx)
}
})?;
this.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
let spawn_result = task.await;
if spawn_result.is_ok() {
this.update(cx, |_, cx| {
cx.emit(DismissEvent);
})
.ok();
}
spawn_result?;
anyhow::Result::<_, anyhow::Error>::Ok(())
})
.detach_and_log_err(cx);
Ok(())
}
fn update_attach_picker(
@@ -240,12 +249,15 @@ impl NewSessionModal {
);
}
DebugRequest::Attach(_) => {
let Some(workspace) = this.workspace.upgrade() else {
let Ok(project) = this
.workspace
.read_with(cx, |this, _| this.project().clone())
else {
return;
};
this.mode = NewSessionMode::attach(
this.debugger.clone(),
workspace,
project,
window,
cx,
);
@@ -345,7 +357,7 @@ struct AttachMode {
impl AttachMode {
fn new(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
project: Entity<Project>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Entity<Self> {
@@ -358,7 +370,7 @@ impl AttachMode {
stop_on_entry: Some(false),
};
let attach_picker = cx.new(|cx| {
let modal = AttachModal::new(workspace, debug_definition.clone(), false, window, cx);
let modal = AttachModal::new(project, debug_definition.clone(), false, window, cx);
window.focus(&modal.focus_handle(cx));
modal
@@ -458,11 +470,11 @@ impl RenderOnce for NewSessionMode {
impl NewSessionMode {
fn attach(
debugger: Option<SharedString>,
workspace: Entity<Workspace>,
project: Entity<Project>,
window: &mut Window,
cx: &mut Context<NewSessionModal>,
) -> Self {
Self::Attach(AttachMode::new(debugger, workspace, window, cx))
Self::Attach(AttachMode::new(debugger, project, window, cx))
}
fn launch(
past_launch_config: Option<LaunchRequest>,
@@ -557,12 +569,15 @@ impl Render for NewSessionModal {
.toggle_state(matches!(self.mode, NewSessionMode::Attach(_)))
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
let Some(workspace) = this.workspace.upgrade() else {
let Ok(project) = this
.workspace
.read_with(cx, |this, _| this.project().clone())
else {
return;
};
this.mode = NewSessionMode::attach(
this.debugger.clone(),
workspace,
project,
window,
cx,
);
@@ -616,7 +631,7 @@ impl Render for NewSessionModal {
.child(
Button::new("debugger-spawn", "Start")
.on_click(cx.listener(|this, _, window, cx| {
this.start_new_session(window, cx);
this.start_new_session(window, cx).log_err();
}))
.disabled(self.debugger.is_none()),
),

View File

@@ -88,12 +88,6 @@ impl DebugSession {
}
}
pub fn session(&self, cx: &App) -> Entity<Session> {
match &self.mode {
DebugSessionState::Running(entity) => entity.read(cx).session().clone(),
}
}
pub(crate) fn shutdown(&mut self, cx: &mut Context<Self>) {
match &self.mode {
DebugSessionState::Running(state) => state.update(cx, |state, cx| state.shutdown(cx)),
@@ -121,7 +115,13 @@ impl DebugSession {
};
self.label
.get_or_init(|| session.read(cx).label())
.get_or_init(|| {
session
.read(cx)
.as_local()
.expect("Remote Debug Sessions are not implemented yet")
.label()
})
.to_owned()
}

View File

@@ -411,26 +411,13 @@ impl RunningState {
.log_err();
if let Some(thread_id) = thread_id {
this.select_thread(*thread_id, window, cx);
this.select_thread(*thread_id, cx);
}
}
SessionEvent::Threads => {
let threads = this.session.update(cx, |this, cx| this.threads(cx));
this.select_current_thread(&threads, window, cx);
this.select_current_thread(&threads, cx);
}
SessionEvent::CapabilitiesLoaded => {
let capabilities = this.capabilities(cx);
if !capabilities.supports_modules_request.unwrap_or(false) {
this.remove_pane_item(DebuggerPaneItem::Modules, window, cx);
}
if !capabilities
.supports_loaded_sources_request
.unwrap_or(false)
{
this.remove_pane_item(DebuggerPaneItem::LoadedSources, window, cx);
}
}
_ => {}
}
cx.notify()
@@ -460,14 +447,35 @@ impl RunningState {
workspace::PaneGroup::with_root(root)
} else {
pane_close_subscriptions.clear();
let module_list = if session
.read(cx)
.capabilities()
.supports_modules_request
.unwrap_or(false)
{
Some(&module_list)
} else {
None
};
let loaded_source_list = if session
.read(cx)
.capabilities()
.supports_loaded_sources_request
.unwrap_or(false)
{
Some(&loaded_source_list)
} else {
None
};
let root = Self::default_pane_layout(
project,
&workspace,
&stack_frame_list,
&variable_list,
&module_list,
&loaded_source_list,
module_list,
loaded_source_list,
&console,
&breakpoint_list,
&mut pane_close_subscriptions,
@@ -504,6 +512,11 @@ impl RunningState {
window: &mut Window,
cx: &mut Context<Self>,
) {
debug_assert!(
item_kind.is_supported(self.session.read(cx).capabilities()),
"We should only allow removing supported item kinds"
);
if let Some((pane, item_id)) = self.panes.panes().iter().find_map(|pane| {
Some(pane).zip(
pane.read(cx)
@@ -731,7 +744,6 @@ impl RunningState {
pub fn select_current_thread(
&mut self,
threads: &Vec<(Thread, ThreadStatus)>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let selected_thread = self
@@ -744,7 +756,7 @@ impl RunningState {
};
if Some(ThreadId(selected_thread.id)) != self.thread_id {
self.select_thread(ThreadId(selected_thread.id), window, cx);
self.select_thread(ThreadId(selected_thread.id), cx);
}
}
@@ -757,7 +769,7 @@ impl RunningState {
.map(|id| self.session().read(cx).thread_status(id))
}
fn select_thread(&mut self, thread_id: ThreadId, window: &mut Window, cx: &mut Context<Self>) {
fn select_thread(&mut self, thread_id: ThreadId, cx: &mut Context<Self>) {
if self.thread_id.is_some_and(|id| id == thread_id) {
return;
}
@@ -765,7 +777,8 @@ impl RunningState {
self.thread_id = Some(thread_id);
self.stack_frame_list
.update(cx, |list, cx| list.schedule_refresh(true, window, cx));
.update(cx, |list, cx| list.refresh(cx));
cx.notify();
}
pub fn continue_thread(&mut self, cx: &mut Context<Self>) {
@@ -917,9 +930,9 @@ impl RunningState {
for (thread, _) in threads {
let state = state.clone();
let thread_id = thread.id;
this = this.entry(thread.name, None, move |window, cx| {
this = this.entry(thread.name, None, move |_, cx| {
state.update(cx, |state, cx| {
state.select_thread(ThreadId(thread_id), window, cx);
state.select_thread(ThreadId(thread_id), cx);
});
});
}
@@ -933,8 +946,8 @@ impl RunningState {
workspace: &WeakEntity<Workspace>,
stack_frame_list: &Entity<StackFrameList>,
variable_list: &Entity<VariableList>,
module_list: &Entity<ModuleList>,
loaded_source_list: &Entity<LoadedSourceList>,
module_list: Option<&Entity<ModuleList>>,
loaded_source_list: Option<&Entity<LoadedSourceList>>,
console: &Entity<Console>,
breakpoints: &Entity<BreakpointList>,
subscriptions: &mut HashMap<EntityId, Subscription>,
@@ -990,36 +1003,41 @@ impl RunningState {
window,
cx,
);
this.add_item(
Box::new(SubView::new(
module_list.focus_handle(cx),
module_list.clone().into(),
DebuggerPaneItem::Modules,
if let Some(module_list) = module_list {
this.add_item(
Box::new(SubView::new(
module_list.focus_handle(cx),
module_list.clone().into(),
DebuggerPaneItem::Modules,
None,
cx,
)),
false,
false,
None,
window,
cx,
)),
false,
false,
None,
window,
cx,
);
);
this.activate_item(0, false, false, window, cx);
}
this.add_item(
Box::new(SubView::new(
loaded_source_list.focus_handle(cx),
loaded_source_list.clone().into(),
DebuggerPaneItem::LoadedSources,
if let Some(loaded_source_list) = loaded_source_list {
this.add_item(
Box::new(SubView::new(
loaded_source_list.focus_handle(cx),
loaded_source_list.clone().into(),
DebuggerPaneItem::LoadedSources,
None,
cx,
)),
false,
false,
None,
window,
cx,
)),
false,
false,
None,
window,
cx,
);
this.activate_item(0, false, false, window, cx);
);
this.activate_item(1, false, false, window, cx);
}
});
let rightmost_pane = new_debugger_pane(workspace.clone(), project.clone(), window, cx);

View File

@@ -141,16 +141,14 @@ impl Console {
expression
});
self.session.update(cx, |session, cx| {
session
.evaluate(
expression,
Some(dap::EvaluateArgumentsContext::Variables),
self.stack_frame_list.read(cx).selected_stack_frame_id(),
None,
cx,
)
.detach();
self.session.update(cx, |state, cx| {
state.evaluate(
expression,
Some(dap::EvaluateArgumentsContext::Variables),
self.stack_frame_list.read(cx).selected_stack_frame_id(),
None,
cx,
);
});
}

View File

@@ -1,6 +1,5 @@
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Result, anyhow};
use dap::StackFrameId;
@@ -10,7 +9,6 @@ use gpui::{
};
use language::PointUtf16;
use project::debugger::breakpoint_store::ActiveStackFrame;
use project::debugger::session::{Session, SessionEvent, StackFrame};
use project::{ProjectItem, ProjectPath};
use ui::{Scrollbar, ScrollbarState, Tooltip, prelude::*};
@@ -30,11 +28,11 @@ pub struct StackFrameList {
_subscription: Subscription,
session: Entity<Session>,
state: WeakEntity<RunningState>,
invalidate: bool,
entries: Vec<StackFrameEntry>,
workspace: WeakEntity<Workspace>,
selected_stack_frame_id: Option<StackFrameId>,
scrollbar_state: ScrollbarState,
_refresh_task: Task<()>,
}
#[allow(clippy::large_enum_variant)]
@@ -70,17 +68,14 @@ impl StackFrameList {
);
let _subscription =
cx.subscribe_in(&session, window, |this, _, event, window, cx| match event {
SessionEvent::Threads => {
this.schedule_refresh(false, window, cx);
}
SessionEvent::Stopped(..) | SessionEvent::StackTrace => {
this.schedule_refresh(true, window, cx);
cx.subscribe_in(&session, window, |this, _, event, _, cx| match event {
SessionEvent::Stopped(_) | SessionEvent::StackTrace | SessionEvent::Threads => {
this.refresh(cx);
}
_ => {}
});
let mut this = Self {
Self {
scrollbar_state: ScrollbarState::new(list.clone()),
list,
session,
@@ -88,12 +83,10 @@ impl StackFrameList {
focus_handle,
state,
_subscription,
invalidate: true,
entries: Default::default(),
selected_stack_frame_id: None,
_refresh_task: Task::ready(()),
};
this.schedule_refresh(true, window, cx);
this
}
}
#[cfg(test)]
@@ -143,32 +136,10 @@ impl StackFrameList {
self.selected_stack_frame_id
}
pub(super) fn schedule_refresh(
&mut self,
select_first: bool,
window: &mut Window,
cx: &mut Context<Self>,
) {
const REFRESH_DEBOUNCE: Duration = Duration::from_millis(20);
self._refresh_task = cx.spawn_in(window, async move |this, cx| {
let debounce = this
.update(cx, |this, cx| {
let new_stack_frames = this.stack_frames(cx);
new_stack_frames.is_empty() && !this.entries.is_empty()
})
.ok()
.unwrap_or_default();
if debounce {
cx.background_executor().timer(REFRESH_DEBOUNCE).await;
}
this.update_in(cx, |this, window, cx| {
this.build_entries(select_first, window, cx);
cx.notify();
})
.ok();
})
pub(super) fn refresh(&mut self, cx: &mut Context<Self>) {
self.invalidate = true;
self.entries.clear();
cx.notify();
}
pub fn build_entries(
@@ -266,7 +237,6 @@ impl StackFrameList {
return Task::ready(Err(anyhow!("Project path not found")));
};
let stack_frame_id = stack_frame.id;
cx.spawn_in(window, async move |this, cx| {
let (worktree, relative_path) = this
.update(cx, |this, cx| {
@@ -315,22 +285,12 @@ impl StackFrameList {
.await?;
this.update(cx, |this, cx| {
let Some(thread_id) = this.state.read_with(cx, |state, _| state.thread_id)? else {
return Err(anyhow!("No selected thread ID found"));
};
this.workspace.update(cx, |workspace, cx| {
let breakpoint_store = workspace.project().read(cx).breakpoint_store();
breakpoint_store.update(cx, |store, cx| {
store.set_active_position(
ActiveStackFrame {
session_id: this.session.read(cx).session_id(),
thread_id,
stack_frame_id,
path: abs_path,
position,
},
(this.session.read(cx).session_id(), abs_path, position),
cx,
);
})
@@ -555,7 +515,13 @@ impl StackFrameList {
}
impl Render for StackFrameList {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
if self.invalidate {
self.build_entries(self.entries.is_empty(), window, cx);
self.invalidate = false;
cx.notify();
}
div()
.size_full()
.p_1()

View File

@@ -1,29 +1,16 @@
use std::sync::Arc;
use anyhow::{Result, anyhow};
use dap::{DebugRequest, client::DebugAdapterClient};
use gpui::{Entity, TestAppContext, WindowHandle};
use project::{Project, debugger::session::Session};
use project::Project;
use settings::SettingsStore;
use task::DebugTaskDefinition;
use terminal_view::terminal_panel::TerminalPanel;
use workspace::Workspace;
use crate::{debugger_panel::DebugPanel, session::DebugSession};
#[cfg(test)]
mod attach_modal;
#[cfg(test)]
mod console;
#[cfg(test)]
mod dap_logger;
#[cfg(test)]
mod debugger_panel;
#[cfg(test)]
mod module_list;
#[cfg(test)]
mod stack_frame_list;
#[cfg(test)]
mod variable_list;
pub fn init_test(cx: &mut gpui::TestAppContext) {
@@ -55,7 +42,7 @@ pub async fn init_test_workspace(
let debugger_panel = workspace_handle
.update(cx, |_, window, cx| {
cx.spawn_in(window, async move |this, cx| {
DebugPanel::load(this, cx).await
DebugPanel::load(this, cx.clone()).await
})
})
.unwrap()
@@ -95,46 +82,3 @@ pub fn active_debug_session_panel(
})
.unwrap()
}
pub fn start_debug_session_with<T: Fn(&Arc<DebugAdapterClient>) + 'static>(
workspace: &WindowHandle<Workspace>,
cx: &mut gpui::TestAppContext,
config: DebugTaskDefinition,
configure: T,
) -> Result<Entity<Session>> {
let _subscription = project::debugger::test::intercept_debug_sessions(cx, configure);
workspace.update(cx, |workspace, window, cx| {
workspace.start_debug_session(config, window, cx)
})?;
cx.run_until_parked();
let session = workspace.read_with(cx, |workspace, cx| {
workspace
.panel::<DebugPanel>(cx)
.and_then(|panel| panel.read(cx).active_session())
.and_then(|session| session.read(cx).mode().as_running().cloned())
.map(|running| running.read(cx).session().clone())
.ok_or_else(|| anyhow!("Failed to get active session"))
})??;
Ok(session)
}
pub fn start_debug_session<T: Fn(&Arc<DebugAdapterClient>) + 'static>(
workspace: &WindowHandle<Workspace>,
cx: &mut gpui::TestAppContext,
configure: T,
) -> Result<Entity<Session>> {
start_debug_session_with(
workspace,
cx,
DebugTaskDefinition {
adapter: "fake-adapter".to_string(),
request: DebugRequest::Launch(Default::default()),
label: "test".to_string(),
initialize_args: None,
tcp_connection: None,
stop_on_entry: None,
},
configure,
)
}

View File

@@ -1,4 +1,4 @@
use crate::{attach_modal::Candidate, tests::start_debug_session_with, *};
use crate::{attach_modal::Candidate, *};
use attach_modal::AttachModal;
use dap::{FakeAdapter, client::SessionId};
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
@@ -26,8 +26,8 @@ async fn test_direct_attach_to_process(executor: BackgroundExecutor, cx: &mut Te
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session_with(
&workspace,
let session = debugger::test::start_debug_session_with(
&project,
cx,
DebugTaskDefinition {
adapter: "fake-adapter".to_string(),
@@ -47,6 +47,7 @@ async fn test_direct_attach_to_process(executor: BackgroundExecutor, cx: &mut Te
});
},
)
.await
.unwrap();
cx.run_until_parked();
@@ -98,10 +99,9 @@ async fn test_show_attach_modal_and_select_process(
});
let attach_modal = workspace
.update(cx, |workspace, window, cx| {
let workspace_handle = cx.entity();
workspace.toggle_modal(window, cx, |window, cx| {
AttachModal::with_processes(
workspace_handle,
project.clone(),
DebugTaskDefinition {
adapter: FakeAdapter::ADAPTER_NAME.into(),
request: dap::DebugRequest::Attach(AttachRequest::default()),

View File

@@ -1,7 +1,4 @@
use crate::{
tests::{active_debug_session_panel, start_debug_session},
*,
};
use crate::{tests::active_debug_session_panel, *};
use dap::requests::StackTrace;
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::{FakeFs, Project};
@@ -31,7 +28,9 @@ async fn test_handle_output_event(executor: BackgroundExecutor, cx: &mut TestApp
})
.unwrap();
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<StackTrace, _>(move |_, _| {

View File

@@ -1,118 +0,0 @@
use crate::tests::{init_test, init_test_workspace, start_debug_session};
use dap::requests::{StackTrace, Threads};
use debugger_tools::LogStore;
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::Project;
use serde_json::json;
use std::cell::OnceCell;
#[gpui::test]
async fn test_dap_logger_captures_all_session_rpc_messages(
executor: BackgroundExecutor,
cx: &mut TestAppContext,
) {
let log_store_cell = std::rc::Rc::new(OnceCell::new());
cx.update(|cx| {
let log_store_cell = log_store_cell.clone();
cx.observe_new::<LogStore>(move |_, _, cx| {
log_store_cell.set(cx.entity()).unwrap();
})
.detach();
debugger_tools::init(cx);
});
init_test(cx);
let log_store = log_store_cell.get().unwrap().clone();
// Create a filesystem with a simple project
let fs = project::FakeFs::new(executor.clone());
fs.insert_tree(
"/project",
json!({
"main.rs": "fn main() {\n println!(\"Hello, world!\");\n}"
}),
)
.await;
assert!(
log_store.read_with(cx, |log_store, _| log_store
.contained_session_ids()
.is_empty()),
"log_store shouldn't contain any session IDs before any sessions were created"
);
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
// Start a debug session
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session_id = session.read_with(cx, |session, _| session.session_id());
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
assert_eq!(
log_store.read_with(cx, |log_store, _| log_store.contained_session_ids().len()),
1,
);
assert!(
log_store.read_with(cx, |log_store, _| log_store
.contained_session_ids()
.contains(&session_id)),
"log_store should contain the session IDs of the started session"
);
assert!(
!log_store.read_with(cx, |log_store, _| log_store
.rpc_messages_for_session_id(session_id)
.is_empty()),
"We should have the initialization sequence in the log store"
);
// Set up basic responses for common requests
client.on_request::<Threads, _>(move |_, _| {
Ok(dap::ThreadsResponse {
threads: vec![dap::Thread {
id: 1,
name: "Thread 1".into(),
}],
})
});
client.on_request::<StackTrace, _>(move |_, _| {
Ok(dap::StackTraceResponse {
stack_frames: Vec::default(),
total_frames: None,
})
});
// Run until all pending tasks are executed
cx.run_until_parked();
// Simulate a stopped event to generate more DAP messages
client
.fake_event(dap::messages::Events::Stopped(dap::StoppedEvent {
reason: dap::StoppedEventReason::Pause,
description: None,
thread_id: Some(1),
preserve_focus_hint: None,
text: None,
all_threads_stopped: None,
hit_breakpoint_ids: None,
}))
.await;
cx.run_until_parked();
// Shutdown the debug session
let shutdown_session = project.update(cx, |project, cx| {
project.dap_store().update(cx, |dap_store, cx| {
dap_store.shutdown_session(session.read(cx).session_id(), cx)
})
});
shutdown_session.await.unwrap();
cx.run_until_parked();
}

View File

@@ -1,4 +1,4 @@
use crate::{tests::start_debug_session, *};
use crate::*;
use dap::{
ErrorResponse, Message, RunInTerminalRequestArguments, SourceBreakpoint,
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
@@ -48,7 +48,9 @@ async fn test_basic_show_debug_panel(executor: BackgroundExecutor, cx: &mut Test
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<Threads, _>(move |_, _| {
@@ -185,7 +187,9 @@ async fn test_we_can_only_have_one_panel_per_debug_session(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<Threads, _>(move |_, _| {
@@ -350,7 +354,9 @@ async fn test_handle_successful_run_in_terminal_reverse_request(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client
@@ -413,86 +419,6 @@ async fn test_handle_successful_run_in_terminal_reverse_request(
shutdown_session.await.unwrap();
}
#[gpui::test]
async fn test_handle_start_debugging_request(
executor: BackgroundExecutor,
cx: &mut TestAppContext,
) {
init_test(cx);
let fs = FakeFs::new(executor.clone());
fs.insert_tree(
"/project",
json!({
"main.rs": "First line\nSecond line\nThird line\nFourth line",
}),
)
.await;
let project = Project::test(fs, ["/project".as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
let fake_config = json!({"one": "two"});
let launched_with = Arc::new(parking_lot::Mutex::new(None));
let _subscription = project::debugger::test::intercept_debug_sessions(cx, {
let launched_with = launched_with.clone();
move |client| {
let launched_with = launched_with.clone();
client.on_request::<dap::requests::Launch, _>(move |_, args| {
launched_with.lock().replace(args.raw);
Ok(())
});
client.on_request::<dap::requests::Attach, _>(move |_, _| {
assert!(false, "should not get attach request");
Ok(())
});
}
});
client
.fake_reverse_request::<StartDebugging>(StartDebuggingRequestArguments {
request: StartDebuggingRequestArgumentsRequest::Launch,
configuration: fake_config.clone(),
})
.await;
cx.run_until_parked();
workspace
.update(cx, |workspace, _window, cx| {
let debug_panel = workspace.panel::<DebugPanel>(cx).unwrap();
let active_session = debug_panel
.read(cx)
.active_session()
.unwrap()
.read(cx)
.session(cx);
let parent_session = active_session.read(cx).parent_session().unwrap();
assert_eq!(
active_session.read(cx).definition(),
parent_session.read(cx).definition()
);
})
.unwrap();
assert_eq!(&fake_config, launched_with.lock().as_ref().unwrap());
let shutdown_session = project.update(cx, |project, cx| {
project.dap_store().update(cx, |dap_store, cx| {
dap_store.shutdown_session(session.read(cx).session_id(), cx)
})
});
shutdown_session.await.unwrap();
}
// // covers that we always send a response back, if something when wrong,
// // while spawning the terminal
#[gpui::test]
@@ -518,7 +444,9 @@ async fn test_handle_error_run_in_terminal_reverse_request(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client
@@ -594,7 +522,9 @@ async fn test_handle_start_debugging_reverse_request(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<dap::requests::Threads, _>(move |_, _| {
@@ -699,7 +629,9 @@ async fn test_shutdown_children_when_parent_session_shutdown(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let parent_session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let parent_session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = parent_session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<dap::requests::Threads, _>(move |_, _| {
@@ -805,7 +737,9 @@ async fn test_shutdown_parent_session_if_all_children_are_shutdown(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let parent_session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let parent_session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = parent_session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_response::<StartDebugging, _>(move |_| {}).await;
@@ -924,7 +858,7 @@ async fn test_debug_panel_item_thread_status_reset_on_failure(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |client| {
let session = debugger::test::start_debug_session(&project, cx, |client| {
client.on_request::<dap::requests::Initialize, _>(move |_, _| {
Ok(dap::Capabilities {
supports_step_back: Some(true),
@@ -932,6 +866,7 @@ async fn test_debug_panel_item_thread_status_reset_on_failure(
})
});
})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
@@ -1138,7 +1073,9 @@ async fn test_send_breakpoints_when_editor_has_been_saved(
.update(cx, |_, _, cx| worktree.read(cx).id())
.unwrap();
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
let buffer = project
@@ -1353,7 +1290,9 @@ async fn test_unsetting_breakpoints_on_clear_breakpoint_action(
editor.toggle_breakpoint(&actions::ToggleBreakpoint, window, cx);
});
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
let called_set_breakpoints = Arc::new(AtomicBool::new(false));
@@ -1419,7 +1358,7 @@ async fn test_debug_session_is_shutdown_when_attach_and_launch_request_fails(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
start_debug_session(&workspace, cx, |client| {
let task = project::debugger::test::start_debug_session(&project, cx, |client| {
client.on_request::<dap::requests::Initialize, _>(|_, _| {
Err(ErrorResponse {
error: Some(Message {
@@ -1433,8 +1372,12 @@ async fn test_debug_session_is_shutdown_when_attach_and_launch_request_fails(
}),
})
});
})
.ok();
});
assert!(
task.await.is_err(),
"Session should failed to start if launch request fails"
);
cx.run_until_parked();

View File

@@ -1,13 +1,16 @@
use crate::{
debugger_panel::DebugPanel,
tests::{active_debug_session_panel, init_test, init_test_workspace, start_debug_session},
tests::{active_debug_session_panel, init_test, init_test_workspace},
};
use dap::{
StoppedEvent,
requests::{Initialize, Modules},
};
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::{FakeFs, Project};
use project::{
FakeFs, Project,
debugger::{self},
};
use std::sync::{
Arc,
atomic::{AtomicBool, AtomicI32, Ordering},
@@ -28,7 +31,7 @@ async fn test_module_list(executor: BackgroundExecutor, cx: &mut TestAppContext)
.unwrap();
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |client| {
let session = debugger::test::start_debug_session(&project, cx, |client| {
client.on_request::<Initialize, _>(move |_, _| {
Ok(dap::Capabilities {
supports_modules_request: Some(true),
@@ -36,6 +39,7 @@ async fn test_module_list(executor: BackgroundExecutor, cx: &mut TestAppContext)
})
});
})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());

View File

@@ -1,7 +1,7 @@
use crate::{
debugger_panel::DebugPanel,
session::running::stack_frame_list::StackFrameEntry,
tests::{active_debug_session_panel, init_test, init_test_workspace, start_debug_session},
tests::{active_debug_session_panel, init_test, init_test_workspace},
};
use dap::{
StackFrame,
@@ -9,7 +9,7 @@ use dap::{
};
use editor::{Editor, ToPoint as _};
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::{FakeFs, Project};
use project::{FakeFs, Project, debugger};
use serde_json::json;
use std::sync::Arc;
use unindent::Unindent as _;
@@ -50,7 +50,9 @@ async fn test_fetch_initial_stack_frames_and_go_to_stack_frame(
let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<Scopes, _>(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] }));
@@ -152,7 +154,7 @@ async fn test_fetch_initial_stack_frames_and_go_to_stack_frame(
cx.run_until_parked();
// select first thread
active_debug_session_panel(workspace, cx).update_in(cx, |session, window, cx| {
active_debug_session_panel(workspace, cx).update_in(cx, |session, _, cx| {
session
.mode()
.as_running()
@@ -162,7 +164,6 @@ async fn test_fetch_initial_stack_frames_and_go_to_stack_frame(
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
@@ -228,7 +229,9 @@ async fn test_select_stack_frame(executor: BackgroundExecutor, cx: &mut TestAppC
});
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<Threads, _>(move |_, _| {
@@ -331,7 +334,7 @@ async fn test_select_stack_frame(executor: BackgroundExecutor, cx: &mut TestAppC
cx.run_until_parked();
// select first thread
active_debug_session_panel(workspace, cx).update_in(cx, |session, window, cx| {
active_debug_session_panel(workspace, cx).update_in(cx, |session, _, cx| {
session
.mode()
.as_running()
@@ -341,7 +344,6 @@ async fn test_select_stack_frame(executor: BackgroundExecutor, cx: &mut TestAppC
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});
@@ -493,7 +495,9 @@ async fn test_collapsed_entries(executor: BackgroundExecutor, cx: &mut TestAppCo
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<Threads, _>(move |_, _| {
@@ -706,7 +710,7 @@ async fn test_collapsed_entries(executor: BackgroundExecutor, cx: &mut TestAppCo
cx.run_until_parked();
// select first thread
active_debug_session_panel(workspace, cx).update_in(cx, |session, window, cx| {
active_debug_session_panel(workspace, cx).update_in(cx, |session, _, cx| {
session
.mode()
.as_running()
@@ -716,7 +720,6 @@ async fn test_collapsed_entries(executor: BackgroundExecutor, cx: &mut TestAppCo
&running_state
.session()
.update(cx, |session, cx| session.threads(cx)),
window,
cx,
);
});

View File

@@ -6,7 +6,7 @@ use std::sync::{
use crate::{
DebugPanel,
session::running::variable_list::{CollapseSelectedEntry, ExpandSelectedEntry},
tests::{active_debug_session_panel, init_test, init_test_workspace, start_debug_session},
tests::{active_debug_session_panel, init_test, init_test_workspace},
};
use collections::HashMap;
use dap::{
@@ -15,7 +15,7 @@ use dap::{
};
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use menu::{SelectFirst, SelectNext, SelectPrevious};
use project::{FakeFs, Project};
use project::{FakeFs, Project, debugger};
use serde_json::json;
use unindent::Unindent as _;
use util::path;
@@ -54,7 +54,9 @@ async fn test_basic_fetch_initial_scope_and_variables(
})
.unwrap();
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<dap::requests::Threads, _>(move |_, _| {
@@ -264,7 +266,9 @@ async fn test_fetch_variables_for_multiple_scopes(
.unwrap();
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<dap::requests::Threads, _>(move |_, _| {
@@ -524,7 +528,9 @@ async fn test_keyboard_navigation(executor: BackgroundExecutor, cx: &mut TestApp
})
.unwrap();
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<dap::requests::Threads, _>(move |_, _| {
@@ -1307,7 +1313,9 @@ async fn test_variable_list_only_sends_requests_when_rendering(
let workspace = init_test_workspace(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<dap::requests::Threads, _>(move |_, _| {
@@ -1552,7 +1560,9 @@ async fn test_it_fetches_scopes_variables_when_you_select_a_stack_frame(
.unwrap();
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
let session = debugger::test::start_debug_session(&project, cx, |_| {})
.await
.unwrap();
let client = session.update(cx, |session, _| session.adapter_client().unwrap());
client.on_request::<dap::requests::Threads, _>(move |_, _| {

View File

@@ -28,7 +28,6 @@ impl DiagnosticRenderer {
diagnostic_group: Vec<DiagnosticEntry<Point>>,
buffer_id: BufferId,
diagnostics_editor: Option<WeakEntity<ProjectDiagnosticsEditor>>,
merge_same_row: bool,
cx: &mut App,
) -> Vec<DiagnosticBlock> {
let Some(primary_ix) = diagnostic_group
@@ -46,7 +45,7 @@ impl DiagnosticRenderer {
if entry.diagnostic.is_primary {
continue;
}
if entry.range.start.row == primary.range.start.row && merge_same_row {
if entry.range.start.row == primary.range.start.row {
same_row.push(entry)
} else if entry.range.start.row.abs_diff(primary.range.start.row) < 5 {
close.push(entry)
@@ -55,48 +54,28 @@ impl DiagnosticRenderer {
}
}
let mut markdown = String::new();
let diagnostic = &primary.diagnostic;
markdown.push_str(&Markdown::escape(&diagnostic.message));
let mut markdown =
Markdown::escape(&if let Some(source) = primary.diagnostic.source.as_ref() {
format!("{}: {}", source, primary.diagnostic.message)
} else {
primary.diagnostic.message
})
.to_string();
for entry in same_row {
markdown.push_str("\n- hint: ");
markdown.push_str(&Markdown::escape(&entry.diagnostic.message))
}
if diagnostic.source.is_some() || diagnostic.code.is_some() {
markdown.push_str(" (");
}
if let Some(source) = diagnostic.source.as_ref() {
markdown.push_str(&Markdown::escape(&source));
}
if diagnostic.source.is_some() && diagnostic.code.is_some() {
markdown.push(' ');
}
if let Some(code) = diagnostic.code.as_ref() {
if let Some(description) = diagnostic.code_description.as_ref() {
markdown.push('[');
markdown.push_str(&Markdown::escape(&code.to_string()));
markdown.push_str("](");
markdown.push_str(&Markdown::escape(description.as_ref()));
markdown.push(')');
} else {
markdown.push_str(&Markdown::escape(&code.to_string()));
}
}
if diagnostic.source.is_some() || diagnostic.code.is_some() {
markdown.push(')');
}
for (ix, entry) in &distant {
markdown.push_str("\n- hint: [");
markdown.push_str(&Markdown::escape(&entry.diagnostic.message));
markdown.push_str(&format!(
"](file://#diagnostic-{buffer_id}-{group_id}-{ix})\n",
))
markdown.push_str(&format!("](file://#diagnostic-{group_id}-{ix})\n",))
}
let mut results = vec![DiagnosticBlock {
initial_range: primary.range,
severity: primary.diagnostic.severity,
buffer_id,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
}];
@@ -112,6 +91,7 @@ impl DiagnosticRenderer {
results.push(DiagnosticBlock {
initial_range: entry.range,
severity: entry.diagnostic.severity,
buffer_id,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
});
@@ -125,12 +105,15 @@ impl DiagnosticRenderer {
};
let mut markdown = Markdown::escape(&markdown).to_string();
markdown.push_str(&format!(
" ([back](file://#diagnostic-{buffer_id}-{group_id}-{primary_ix}))"
" ([back](file://#diagnostic-{group_id}-{primary_ix}))"
));
// problem: group-id changes...
// - only an issue in diagnostics because caching
results.push(DiagnosticBlock {
initial_range: entry.range,
severity: entry.diagnostic.severity,
buffer_id,
diagnostics_editor: diagnostics_editor.clone(),
markdown: cx.new(|cx| Markdown::new(markdown.into(), None, None, cx)),
});
@@ -149,7 +132,7 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer {
editor: WeakEntity<Editor>,
cx: &mut App,
) -> Vec<BlockProperties<Anchor>> {
let blocks = Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, true, cx);
let blocks = Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, cx);
blocks
.into_iter()
.map(|block| {
@@ -168,40 +151,13 @@ impl editor::DiagnosticRenderer for DiagnosticRenderer {
})
.collect()
}
fn render_hover(
&self,
diagnostic_group: Vec<DiagnosticEntry<Point>>,
range: Range<Point>,
buffer_id: BufferId,
cx: &mut App,
) -> Option<Entity<Markdown>> {
let blocks =
Self::diagnostic_blocks_for_group(diagnostic_group, buffer_id, None, false, cx);
blocks.into_iter().find_map(|block| {
if block.initial_range == range {
Some(block.markdown)
} else {
None
}
})
}
fn open_link(
&self,
editor: &mut Editor,
link: SharedString,
window: &mut Window,
cx: &mut Context<Editor>,
) {
DiagnosticBlock::open_link(editor, &None, link, window, cx);
}
}
#[derive(Clone)]
pub(crate) struct DiagnosticBlock {
pub(crate) initial_range: Range<Point>,
pub(crate) severity: DiagnosticSeverity,
pub(crate) buffer_id: BufferId,
pub(crate) markdown: Entity<Markdown>,
pub(crate) diagnostics_editor: Option<WeakEntity<ProjectDiagnosticsEditor>>,
}
@@ -225,6 +181,7 @@ impl DiagnosticBlock {
let settings = ThemeSettings::get_global(cx);
let editor_line_height = (settings.line_height() * settings.buffer_font_size(cx)).round();
let line_height = editor_line_height;
let buffer_id = self.buffer_id;
let diagnostics_editor = self.diagnostics_editor.clone();
div()
@@ -238,11 +195,14 @@ impl DiagnosticBlock {
MarkdownElement::new(self.markdown.clone(), hover_markdown_style(bcx.window, cx))
.on_url_click({
move |link, window, cx| {
editor
.update(cx, |editor, cx| {
Self::open_link(editor, &diagnostics_editor, link, window, cx)
})
.ok();
Self::open_link(
editor.clone(),
&diagnostics_editor,
link,
window,
buffer_id,
cx,
)
}
}),
)
@@ -250,71 +210,79 @@ impl DiagnosticBlock {
}
pub fn open_link(
editor: &mut Editor,
editor: WeakEntity<Editor>,
diagnostics_editor: &Option<WeakEntity<ProjectDiagnosticsEditor>>,
link: SharedString,
window: &mut Window,
cx: &mut Context<Editor>,
buffer_id: BufferId,
cx: &mut App,
) {
let Some(diagnostic_link) = link.strip_prefix("file://#diagnostic-") else {
editor::hover_popover::open_markdown_url(link, window, cx);
return;
};
let Some((buffer_id, group_id, ix)) = maybe!({
let mut parts = diagnostic_link.split('-');
let buffer_id: u64 = parts.next()?.parse().ok()?;
let group_id: usize = parts.next()?.parse().ok()?;
let ix: usize = parts.next()?.parse().ok()?;
Some((BufferId::new(buffer_id).ok()?, group_id, ix))
}) else {
return;
};
if let Some(diagnostics_editor) = diagnostics_editor {
if let Some(diagnostic) = diagnostics_editor
.update(cx, |diagnostics, _| {
diagnostics
.diagnostics
.get(&buffer_id)
.cloned()
.unwrap_or_default()
.into_iter()
.filter(|d| d.diagnostic.group_id == group_id)
.nth(ix)
})
.ok()
.flatten()
{
let multibuffer = editor.buffer().read(cx);
let Some(snapshot) = multibuffer
.buffer(buffer_id)
.map(|entity| entity.read(cx).snapshot())
else {
editor
.update(cx, |editor, cx| {
let Some(diagnostic_link) = link.strip_prefix("file://#diagnostic-") else {
editor::hover_popover::open_markdown_url(link, window, cx);
return;
};
let Some((group_id, ix)) = maybe!({
let (group_id, ix) = diagnostic_link.split_once('-')?;
let group_id: usize = group_id.parse().ok()?;
let ix: usize = ix.parse().ok()?;
Some((group_id, ix))
}) else {
return;
};
for (excerpt_id, range) in multibuffer.excerpts_for_buffer(buffer_id, cx) {
if range.context.overlaps(&diagnostic.range, &snapshot) {
Self::jump_to(
editor,
Anchor::range_in_buffer(excerpt_id, buffer_id, diagnostic.range),
window,
cx,
);
return;
if let Some(diagnostics_editor) = diagnostics_editor {
if let Some(diagnostic) = diagnostics_editor
.update(cx, |diagnostics, _| {
diagnostics
.diagnostics
.get(&buffer_id)
.cloned()
.unwrap_or_default()
.into_iter()
.filter(|d| d.diagnostic.group_id == group_id)
.nth(ix)
})
.ok()
.flatten()
{
let multibuffer = editor.buffer().read(cx);
let Some(snapshot) = multibuffer
.buffer(buffer_id)
.map(|entity| entity.read(cx).snapshot())
else {
return;
};
for (excerpt_id, range) in multibuffer.excerpts_for_buffer(buffer_id, cx) {
if range.context.overlaps(&diagnostic.range, &snapshot) {
Self::jump_to(
editor,
Anchor::range_in_buffer(
excerpt_id,
buffer_id,
diagnostic.range,
),
window,
cx,
);
return;
}
}
}
}
}
} else {
if let Some(diagnostic) = editor
.snapshot(window, cx)
.buffer_snapshot
.diagnostic_group(buffer_id, group_id)
.nth(ix)
{
Self::jump_to(editor, diagnostic.range, window, cx)
}
};
} else {
if let Some(diagnostic) = editor
.snapshot(window, cx)
.buffer_snapshot
.diagnostic_group(buffer_id, group_id)
.nth(ix)
{
Self::jump_to(editor, diagnostic.range, window, cx)
}
};
})
.ok();
}
fn jump_to<T: ToOffset>(

Some files were not shown because too many files have changed in this diff Show More