Compare commits
30 Commits
collect-ta
...
markdown-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5efc26ae7f | ||
|
|
65e751ca33 | ||
|
|
17cf04558b | ||
|
|
36ae564b61 | ||
|
|
110195cdae | ||
|
|
b7d5e6480a | ||
|
|
0fa9f05313 | ||
|
|
051f49ce9a | ||
|
|
e5670ba081 | ||
|
|
e4262f97af | ||
|
|
944a0df436 | ||
|
|
a1be61949d | ||
|
|
a092e2dc03 | ||
|
|
b1c7fa1dac | ||
|
|
df66237428 | ||
|
|
ca513f52bf | ||
|
|
e9c9a8a269 | ||
|
|
315321bf8c | ||
|
|
c747a57b7e | ||
|
|
f73c8e5841 | ||
|
|
f7a0834f54 | ||
|
|
83d513aef4 | ||
|
|
b440e1a467 | ||
|
|
5c4f9e57d8 | ||
|
|
05f8001ee9 | ||
|
|
b93c67438c | ||
|
|
fdec966226 | ||
|
|
9041f734fd | ||
|
|
844c7ad22e | ||
|
|
926f377c6c |
8
Cargo.lock
generated
8
Cargo.lock
generated
@@ -546,6 +546,7 @@ dependencies = [
|
||||
"language_model",
|
||||
"lmstudio",
|
||||
"log",
|
||||
"mistral",
|
||||
"ollama",
|
||||
"open_ai",
|
||||
"paths",
|
||||
@@ -2813,7 +2814,6 @@ dependencies = [
|
||||
"anyhow",
|
||||
"async-recursion 0.3.2",
|
||||
"async-tungstenite",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"clock",
|
||||
"cocoa 0.26.0",
|
||||
@@ -2825,7 +2825,6 @@ dependencies = [
|
||||
"gpui_tokio",
|
||||
"http_client",
|
||||
"http_client_tls",
|
||||
"httparse",
|
||||
"log",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
@@ -2846,7 +2845,6 @@ dependencies = [
|
||||
"time",
|
||||
"tiny_http",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-socks",
|
||||
"url",
|
||||
"util",
|
||||
@@ -4982,6 +4980,7 @@ dependencies = [
|
||||
"clap",
|
||||
"client",
|
||||
"collections",
|
||||
"debug_adapter_extension",
|
||||
"dirs 4.0.0",
|
||||
"dotenv",
|
||||
"env_logger 0.11.8",
|
||||
@@ -7947,6 +7946,7 @@ dependencies = [
|
||||
"editor",
|
||||
"file_icons",
|
||||
"gpui",
|
||||
"language",
|
||||
"log",
|
||||
"project",
|
||||
"schemars",
|
||||
@@ -12885,6 +12885,7 @@ dependencies = [
|
||||
"clock",
|
||||
"dap",
|
||||
"dap_adapters",
|
||||
"debug_adapter_extension",
|
||||
"env_logger 0.11.8",
|
||||
"extension",
|
||||
"extension_host",
|
||||
@@ -19633,6 +19634,7 @@ dependencies = [
|
||||
"dap",
|
||||
"dap_adapters",
|
||||
"db",
|
||||
"debug_adapter_extension",
|
||||
"debugger_tools",
|
||||
"debugger_ui",
|
||||
"diagnostics",
|
||||
|
||||
@@ -72,7 +72,9 @@
|
||||
"alt-left": "editor::SelectToPreviousWordStart",
|
||||
"alt-right": "editor::SelectToNextWordEnd",
|
||||
"pagedown": "editor::SelectPageDown",
|
||||
"ctrl-v": "editor::SelectPageDown",
|
||||
"pageup": "editor::SelectPageUp",
|
||||
"alt-v": "editor::SelectPageUp",
|
||||
"ctrl-f": "editor::SelectRight",
|
||||
"ctrl-b": "editor::SelectLeft",
|
||||
"ctrl-n": "editor::SelectDown",
|
||||
|
||||
@@ -72,7 +72,9 @@
|
||||
"alt-left": "editor::SelectToPreviousWordStart",
|
||||
"alt-right": "editor::SelectToNextWordEnd",
|
||||
"pagedown": "editor::SelectPageDown",
|
||||
"ctrl-v": "editor::SelectPageDown",
|
||||
"pageup": "editor::SelectPageUp",
|
||||
"alt-v": "editor::SelectPageUp",
|
||||
"ctrl-f": "editor::SelectRight",
|
||||
"ctrl-b": "editor::SelectLeft",
|
||||
"ctrl-n": "editor::SelectDown",
|
||||
|
||||
@@ -485,7 +485,7 @@ impl ActivityIndicator {
|
||||
this.dismiss_error_message(&DismissErrorMessage, window, cx)
|
||||
})),
|
||||
}),
|
||||
AutoUpdateStatus::Updated { binary_path } => Some(Content {
|
||||
AutoUpdateStatus::Updated { binary_path, .. } => Some(Content {
|
||||
icon: None,
|
||||
message: "Click to restart and update Zed".to_string(),
|
||||
on_click: Some(Arc::new({
|
||||
|
||||
@@ -85,6 +85,7 @@ actions!(
|
||||
KeepAll,
|
||||
Follow,
|
||||
ResetTrialUpsell,
|
||||
ResetTrialEndUpsell,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use db::kvp::KEY_VALUE_STORE;
|
||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||
use markdown::Markdown;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -66,8 +66,8 @@ use crate::ui::AgentOnboardingModal;
|
||||
use crate::{
|
||||
AddContextServer, AgentDiffPane, ContextStore, DeleteRecentlyOpenThread, ExpandMessageEditor,
|
||||
Follow, InlineAssistant, NewTextThread, NewThread, OpenActiveThreadAsMarkdown, OpenAgentDiff,
|
||||
OpenHistory, ResetTrialUpsell, TextThreadStore, ThreadEvent, ToggleContextPicker,
|
||||
ToggleNavigationMenu, ToggleOptionsMenu,
|
||||
OpenHistory, ResetTrialEndUpsell, ResetTrialUpsell, TextThreadStore, ThreadEvent,
|
||||
ToggleContextPicker, ToggleNavigationMenu, ToggleOptionsMenu,
|
||||
};
|
||||
|
||||
const AGENT_PANEL_KEY: &str = "agent_panel";
|
||||
@@ -157,7 +157,10 @@ pub fn init(cx: &mut App) {
|
||||
window.refresh();
|
||||
})
|
||||
.register_action(|_workspace, _: &ResetTrialUpsell, _window, cx| {
|
||||
set_trial_upsell_dismissed(false, cx);
|
||||
TrialUpsell::set_dismissed(false, cx);
|
||||
})
|
||||
.register_action(|_workspace, _: &ResetTrialEndUpsell, _window, cx| {
|
||||
TrialEndUpsell::set_dismissed(false, cx);
|
||||
});
|
||||
},
|
||||
)
|
||||
@@ -1932,12 +1935,23 @@ impl AgentPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn should_render_trial_end_upsell(&self, cx: &mut Context<Self>) -> bool {
|
||||
if TrialEndUpsell::dismissed() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let plan = self.user_store.read(cx).current_plan();
|
||||
let has_previous_trial = self.user_store.read(cx).trial_started_at().is_some();
|
||||
|
||||
matches!(plan, Some(Plan::Free)) && has_previous_trial
|
||||
}
|
||||
|
||||
fn should_render_upsell(&self, cx: &mut Context<Self>) -> bool {
|
||||
if !matches!(self.active_view, ActiveView::Thread { .. }) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if self.hide_trial_upsell || dismissed_trial_upsell() {
|
||||
if self.hide_trial_upsell || TrialUpsell::dismissed() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1983,125 +1997,115 @@ impl AgentPanel {
|
||||
move |toggle_state, _window, cx| {
|
||||
let toggle_state_bool = toggle_state.selected();
|
||||
|
||||
set_trial_upsell_dismissed(toggle_state_bool, cx);
|
||||
TrialUpsell::set_dismissed(toggle_state_bool, cx);
|
||||
},
|
||||
);
|
||||
|
||||
Some(
|
||||
div().p_2().child(
|
||||
v_flex()
|
||||
let contents = div()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.flex()
|
||||
.flex_col()
|
||||
.child(Headline::new("Build better with Zed Pro").size(HeadlineSize::Small))
|
||||
.child(
|
||||
Label::new("Try Zed Pro for free for 14 days - no credit card required.")
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
Label::new(
|
||||
"Use your own API keys or enable usage-based billing once you hit the cap.",
|
||||
)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.elevation_2(cx)
|
||||
.rounded(px(8.))
|
||||
.bg(cx.theme().colors().background.alpha(0.5))
|
||||
.p(px(3.))
|
||||
|
||||
.px_neg_1()
|
||||
.justify_between()
|
||||
.items_center()
|
||||
.child(h_flex().items_center().gap_1().child(checkbox))
|
||||
.child(
|
||||
div()
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.flex()
|
||||
.flex_col()
|
||||
.size_full()
|
||||
.border_1()
|
||||
.rounded(px(5.))
|
||||
.border_color(cx.theme().colors().text.alpha(0.1))
|
||||
.overflow_hidden()
|
||||
.relative()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.px_4()
|
||||
.py_3()
|
||||
.child(
|
||||
div()
|
||||
.absolute()
|
||||
.top_0()
|
||||
.right(px(-1.0))
|
||||
.w(px(441.))
|
||||
.h(px(167.))
|
||||
.child(
|
||||
Vector::new(VectorName::Grid, rems_from_px(441.), rems_from_px(167.)).color(ui::Color::Custom(cx.theme().colors().text.alpha(0.1)))
|
||||
)
|
||||
Button::new("dismiss-button", "Not Now")
|
||||
.style(ButtonStyle::Transparent)
|
||||
.color(Color::Muted)
|
||||
.on_click({
|
||||
let agent_panel = cx.entity();
|
||||
move |_, _, cx| {
|
||||
agent_panel.update(cx, |this, cx| {
|
||||
this.hide_trial_upsell = true;
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.absolute()
|
||||
.top(px(-8.0))
|
||||
.right_0()
|
||||
.w(px(400.))
|
||||
.h(px(92.))
|
||||
.child(
|
||||
Vector::new(VectorName::AiGrid, rems_from_px(400.), rems_from_px(92.)).color(ui::Color::Custom(cx.theme().colors().text.alpha(0.32)))
|
||||
)
|
||||
)
|
||||
// .child(
|
||||
// div()
|
||||
// .absolute()
|
||||
// .top_0()
|
||||
// .right(px(360.))
|
||||
// .size(px(401.))
|
||||
// .overflow_hidden()
|
||||
// .bg(cx.theme().colors().panel_background)
|
||||
// )
|
||||
.child(
|
||||
div()
|
||||
.absolute()
|
||||
.top_0()
|
||||
.right_0()
|
||||
.w(px(660.))
|
||||
.h(px(401.))
|
||||
.overflow_hidden()
|
||||
.bg(linear_gradient(
|
||||
75.,
|
||||
linear_color_stop(cx.theme().colors().panel_background.alpha(0.01), 1.0),
|
||||
linear_color_stop(cx.theme().colors().panel_background, 0.45),
|
||||
))
|
||||
)
|
||||
.child(Headline::new("Build better with Zed Pro").size(HeadlineSize::Small))
|
||||
.child(Label::new("Try Zed Pro for free for 14 days - no credit card required.").size(LabelSize::Small))
|
||||
.child(Label::new("Use your own API keys or enable usage-based billing once you hit the cap.").color(Color::Muted))
|
||||
Button::new("cta-button", "Start Trial")
|
||||
.style(ButtonStyle::Transparent)
|
||||
.on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
Some(self.render_upsell_container(cx, contents))
|
||||
}
|
||||
|
||||
fn render_trial_end_upsell(
|
||||
&self,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<impl IntoElement> {
|
||||
if !self.should_render_trial_end_upsell(cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(
|
||||
self.render_upsell_container(
|
||||
cx,
|
||||
div()
|
||||
.size_full()
|
||||
.gap_2()
|
||||
.flex()
|
||||
.flex_col()
|
||||
.child(
|
||||
Headline::new("Your Zed Pro trial has expired.").size(HeadlineSize::Small),
|
||||
)
|
||||
.child(
|
||||
Label::new("You've been automatically reset to the free plan.")
|
||||
.size(LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.px_neg_1()
|
||||
.justify_between()
|
||||
.items_center()
|
||||
.child(div())
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.px_neg_1()
|
||||
.justify_between()
|
||||
.items_center()
|
||||
.child(h_flex().items_center().gap_1().child(checkbox))
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
Button::new("dismiss-button", "Not Now")
|
||||
.style(ButtonStyle::Transparent)
|
||||
.color(Color::Muted)
|
||||
.on_click({
|
||||
let agent_panel = cx.entity();
|
||||
move |_, _, cx| {
|
||||
agent_panel.update(
|
||||
cx,
|
||||
|this, cx| {
|
||||
let hidden =
|
||||
this.hide_trial_upsell;
|
||||
println!("hidden: {}", hidden);
|
||||
this.hide_trial_upsell = true;
|
||||
let new_hidden =
|
||||
this.hide_trial_upsell;
|
||||
println!(
|
||||
"new_hidden: {}",
|
||||
new_hidden
|
||||
);
|
||||
|
||||
cx.notify();
|
||||
},
|
||||
);
|
||||
}
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Button::new("cta-button", "Start Trial")
|
||||
.style(ButtonStyle::Transparent)
|
||||
.on_click(|_, _, cx| {
|
||||
cx.open_url(&zed_urls::account_url(cx))
|
||||
}),
|
||||
),
|
||||
Button::new("dismiss-button", "Stay on Free")
|
||||
.style(ButtonStyle::Transparent)
|
||||
.color(Color::Muted)
|
||||
.on_click({
|
||||
let agent_panel = cx.entity();
|
||||
move |_, _, cx| {
|
||||
agent_panel.update(cx, |_this, cx| {
|
||||
TrialEndUpsell::set_dismissed(true, cx);
|
||||
cx.notify();
|
||||
});
|
||||
}
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
Button::new("cta-button", "Upgrade to Zed Pro")
|
||||
.style(ButtonStyle::Transparent)
|
||||
.on_click(|_, _, cx| {
|
||||
cx.open_url(&zed_urls::account_url(cx))
|
||||
}),
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -2109,6 +2113,91 @@ impl AgentPanel {
|
||||
)
|
||||
}
|
||||
|
||||
fn render_upsell_container(&self, cx: &mut Context<Self>, content: Div) -> Div {
|
||||
div().p_2().child(
|
||||
v_flex()
|
||||
.w_full()
|
||||
.elevation_2(cx)
|
||||
.rounded(px(8.))
|
||||
.bg(cx.theme().colors().background.alpha(0.5))
|
||||
.p(px(3.))
|
||||
.child(
|
||||
div()
|
||||
.gap_2()
|
||||
.flex()
|
||||
.flex_col()
|
||||
.size_full()
|
||||
.border_1()
|
||||
.rounded(px(5.))
|
||||
.border_color(cx.theme().colors().text.alpha(0.1))
|
||||
.overflow_hidden()
|
||||
.relative()
|
||||
.bg(cx.theme().colors().panel_background)
|
||||
.px_4()
|
||||
.py_3()
|
||||
.child(
|
||||
div()
|
||||
.absolute()
|
||||
.top_0()
|
||||
.right(px(-1.0))
|
||||
.w(px(441.))
|
||||
.h(px(167.))
|
||||
.child(
|
||||
Vector::new(
|
||||
VectorName::Grid,
|
||||
rems_from_px(441.),
|
||||
rems_from_px(167.),
|
||||
)
|
||||
.color(ui::Color::Custom(cx.theme().colors().text.alpha(0.1))),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.absolute()
|
||||
.top(px(-8.0))
|
||||
.right_0()
|
||||
.w(px(400.))
|
||||
.h(px(92.))
|
||||
.child(
|
||||
Vector::new(
|
||||
VectorName::AiGrid,
|
||||
rems_from_px(400.),
|
||||
rems_from_px(92.),
|
||||
)
|
||||
.color(ui::Color::Custom(cx.theme().colors().text.alpha(0.32))),
|
||||
),
|
||||
)
|
||||
// .child(
|
||||
// div()
|
||||
// .absolute()
|
||||
// .top_0()
|
||||
// .right(px(360.))
|
||||
// .size(px(401.))
|
||||
// .overflow_hidden()
|
||||
// .bg(cx.theme().colors().panel_background)
|
||||
// )
|
||||
.child(
|
||||
div()
|
||||
.absolute()
|
||||
.top_0()
|
||||
.right_0()
|
||||
.w(px(660.))
|
||||
.h(px(401.))
|
||||
.overflow_hidden()
|
||||
.bg(linear_gradient(
|
||||
75.,
|
||||
linear_color_stop(
|
||||
cx.theme().colors().panel_background.alpha(0.01),
|
||||
1.0,
|
||||
),
|
||||
linear_color_stop(cx.theme().colors().panel_background, 0.45),
|
||||
)),
|
||||
)
|
||||
.child(content),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_active_thread_or_empty_state(
|
||||
&self,
|
||||
window: &mut Window,
|
||||
@@ -2827,6 +2916,7 @@ impl Render for AgentPanel {
|
||||
.on_action(cx.listener(Self::toggle_zoom))
|
||||
.child(self.render_toolbar(window, cx))
|
||||
.children(self.render_trial_upsell(window, cx))
|
||||
.children(self.render_trial_end_upsell(window, cx))
|
||||
.map(|parent| match &self.active_view {
|
||||
ActiveView::Thread { .. } => parent
|
||||
.relative()
|
||||
@@ -3014,25 +3104,14 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
const DISMISSED_TRIAL_UPSELL_KEY: &str = "dismissed-trial-upsell";
|
||||
struct TrialUpsell;
|
||||
|
||||
fn dismissed_trial_upsell() -> bool {
|
||||
db::kvp::KEY_VALUE_STORE
|
||||
.read_kvp(DISMISSED_TRIAL_UPSELL_KEY)
|
||||
.log_err()
|
||||
.map_or(false, |s| s.is_some())
|
||||
impl Dismissable for TrialUpsell {
|
||||
const KEY: &'static str = "dismissed-trial-upsell";
|
||||
}
|
||||
|
||||
fn set_trial_upsell_dismissed(is_dismissed: bool, cx: &mut App) {
|
||||
db::write_and_log(cx, move || async move {
|
||||
if is_dismissed {
|
||||
db::kvp::KEY_VALUE_STORE
|
||||
.write_kvp(DISMISSED_TRIAL_UPSELL_KEY.into(), "1".into())
|
||||
.await
|
||||
} else {
|
||||
db::kvp::KEY_VALUE_STORE
|
||||
.delete_kvp(DISMISSED_TRIAL_UPSELL_KEY.into())
|
||||
.await
|
||||
}
|
||||
})
|
||||
struct TrialEndUpsell;
|
||||
|
||||
impl Dismissable for TrialEndUpsell {
|
||||
const KEY: &'static str = "dismissed-trial-end-upsell";
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist};
|
||||
use crate::{RemoveAllContext, ToggleContextPicker};
|
||||
use client::ErrorExt;
|
||||
use collections::VecDeque;
|
||||
use db::kvp::Dismissable;
|
||||
use editor::display_map::EditorMargins;
|
||||
use editor::{
|
||||
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
|
||||
@@ -33,7 +34,6 @@ use ui::utils::WithRemSize;
|
||||
use ui::{
|
||||
CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub struct PromptEditor<T> {
|
||||
@@ -722,7 +722,7 @@ impl<T: 'static> PromptEditor<T> {
|
||||
.child(CheckboxWithLabel::new(
|
||||
"dont-show-again",
|
||||
Label::new("Don't show again"),
|
||||
if dismissed_rate_limit_notice() {
|
||||
if RateLimitNotice::dismissed() {
|
||||
ui::ToggleState::Selected
|
||||
} else {
|
||||
ui::ToggleState::Unselected
|
||||
@@ -734,7 +734,7 @@ impl<T: 'static> PromptEditor<T> {
|
||||
ui::ToggleState::Selected => true,
|
||||
};
|
||||
|
||||
set_rate_limit_notice_dismissed(is_dismissed, cx)
|
||||
RateLimitNotice::set_dismissed(is_dismissed, cx);
|
||||
},
|
||||
))
|
||||
.child(
|
||||
@@ -974,7 +974,7 @@ impl PromptEditor<BufferCodegen> {
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
&& !RateLimitNotice::dismissed()
|
||||
{
|
||||
self.show_rate_limit_notice = true;
|
||||
cx.notify();
|
||||
@@ -1180,27 +1180,10 @@ impl PromptEditor<TerminalCodegen> {
|
||||
}
|
||||
}
|
||||
|
||||
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
|
||||
struct RateLimitNotice;
|
||||
|
||||
fn dismissed_rate_limit_notice() -> bool {
|
||||
db::kvp::KEY_VALUE_STORE
|
||||
.read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
|
||||
.log_err()
|
||||
.map_or(false, |s| s.is_some())
|
||||
}
|
||||
|
||||
fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut App) {
|
||||
db::write_and_log(cx, move || async move {
|
||||
if is_dismissed {
|
||||
db::kvp::KEY_VALUE_STORE
|
||||
.write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
|
||||
.await
|
||||
} else {
|
||||
db::kvp::KEY_VALUE_STORE
|
||||
.delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
|
||||
.await
|
||||
}
|
||||
})
|
||||
impl Dismissable for RateLimitNotice {
|
||||
const KEY: &'static str = "dismissed-rate-limit-notice";
|
||||
}
|
||||
|
||||
pub enum CodegenStatus {
|
||||
|
||||
@@ -23,6 +23,7 @@ log.workspace = true
|
||||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
deepseek = { workspace = true, features = ["schemars"] }
|
||||
mistral = { workspace = true, features = ["schemars"] }
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
|
||||
@@ -10,6 +10,7 @@ use deepseek::Model as DeepseekModel;
|
||||
use gpui::{App, Pixels, SharedString};
|
||||
use language_model::{CloudModel, LanguageModel};
|
||||
use lmstudio::Model as LmStudioModel;
|
||||
use mistral::Model as MistralModel;
|
||||
use ollama::Model as OllamaModel;
|
||||
use schemars::{JsonSchema, schema::Schema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -71,6 +72,11 @@ pub enum AssistantProviderContentV1 {
|
||||
default_model: Option<DeepseekModel>,
|
||||
api_url: Option<String>,
|
||||
},
|
||||
#[serde(rename = "mistral")]
|
||||
Mistral {
|
||||
default_model: Option<MistralModel>,
|
||||
api_url: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug)]
|
||||
@@ -249,6 +255,12 @@ impl AssistantSettingsContent {
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
AssistantProviderContentV1::Mistral { default_model, .. } => {
|
||||
default_model.map(|model| LanguageModelSelection {
|
||||
provider: "mistral".into(),
|
||||
model: model.id().to_string(),
|
||||
})
|
||||
}
|
||||
}),
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
@@ -700,6 +712,7 @@ impl JsonSchema for LanguageModelProviderSetting {
|
||||
"zed.dev".into(),
|
||||
"copilot_chat".into(),
|
||||
"deepseek".into(),
|
||||
"mistral".into(),
|
||||
]),
|
||||
..Default::default()
|
||||
}
|
||||
|
||||
@@ -39,13 +39,22 @@ struct UpdateRequestBody {
|
||||
destination: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub enum VersionCheckType {
|
||||
Sha(String),
|
||||
Semantic(SemanticVersion),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub enum AutoUpdateStatus {
|
||||
Idle,
|
||||
Checking,
|
||||
Downloading,
|
||||
Installing,
|
||||
Updated { binary_path: PathBuf },
|
||||
Updated {
|
||||
binary_path: PathBuf,
|
||||
version: VersionCheckType,
|
||||
},
|
||||
Errored,
|
||||
}
|
||||
|
||||
@@ -62,7 +71,7 @@ pub struct AutoUpdater {
|
||||
pending_poll: Option<Task<Option<()>>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct JsonRelease {
|
||||
pub version: String,
|
||||
pub url: String,
|
||||
@@ -307,7 +316,7 @@ impl AutoUpdater {
|
||||
}
|
||||
|
||||
pub fn poll(&mut self, cx: &mut Context<Self>) {
|
||||
if self.pending_poll.is_some() || self.status.is_updated() {
|
||||
if self.pending_poll.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -483,36 +492,63 @@ impl AutoUpdater {
|
||||
Self::get_release(this, asset, os, arch, None, release_channel, cx).await
|
||||
}
|
||||
|
||||
fn installed_update_version(&self) -> Option<VersionCheckType> {
|
||||
match &self.status {
|
||||
AutoUpdateStatus::Updated { version, .. } => Some(version.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn update(this: Entity<Self>, mut cx: AsyncApp) -> Result<()> {
|
||||
let (client, current_version, release_channel) = this.update(&mut cx, |this, cx| {
|
||||
this.status = AutoUpdateStatus::Checking;
|
||||
cx.notify();
|
||||
(
|
||||
this.http_client.clone(),
|
||||
this.current_version,
|
||||
ReleaseChannel::try_global(cx),
|
||||
)
|
||||
})?;
|
||||
let (client, current_version, installed_update_version, release_channel) =
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.status = AutoUpdateStatus::Checking;
|
||||
cx.notify();
|
||||
(
|
||||
this.http_client.clone(),
|
||||
this.current_version,
|
||||
this.installed_update_version(),
|
||||
ReleaseChannel::try_global(cx),
|
||||
)
|
||||
})?;
|
||||
|
||||
let release =
|
||||
Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
|
||||
|
||||
let should_download = match *RELEASE_CHANNEL {
|
||||
ReleaseChannel::Nightly => cx
|
||||
.update(|cx| AppCommitSha::try_global(cx).map(|sha| release.version != sha.0))
|
||||
.ok()
|
||||
.flatten()
|
||||
.unwrap_or(true),
|
||||
_ => release.version.parse::<SemanticVersion>()? > current_version,
|
||||
let update_version_to_install = match *RELEASE_CHANNEL {
|
||||
ReleaseChannel::Nightly => {
|
||||
let should_download = cx
|
||||
.update(|cx| AppCommitSha::try_global(cx).map(|sha| release.version != sha.0))
|
||||
.ok()
|
||||
.flatten()
|
||||
.unwrap_or(true);
|
||||
|
||||
should_download.then(|| VersionCheckType::Sha(release.version.clone()))
|
||||
}
|
||||
_ => {
|
||||
let installed_version =
|
||||
installed_update_version.unwrap_or(VersionCheckType::Semantic(current_version));
|
||||
match installed_version {
|
||||
VersionCheckType::Sha(_) => {
|
||||
log::warn!("Unexpected SHA-based version in non-nightly build");
|
||||
Some(installed_version)
|
||||
}
|
||||
VersionCheckType::Semantic(semantic_comparison_version) => {
|
||||
let latest_release_version = release.version.parse::<SemanticVersion>()?;
|
||||
let should_download = latest_release_version > semantic_comparison_version;
|
||||
should_download.then(|| VersionCheckType::Semantic(latest_release_version))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !should_download {
|
||||
let Some(update_version) = update_version_to_install else {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.status = AutoUpdateStatus::Idle;
|
||||
cx.notify();
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.status = AutoUpdateStatus::Downloading;
|
||||
@@ -534,7 +570,7 @@ impl AutoUpdater {
|
||||
);
|
||||
|
||||
let downloaded_asset = installer_dir.path().join(filename);
|
||||
download_release(&downloaded_asset, release, client, &cx).await?;
|
||||
download_release(&downloaded_asset, release.clone(), client, &cx).await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.status = AutoUpdateStatus::Installing;
|
||||
@@ -551,7 +587,10 @@ impl AutoUpdater {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.set_should_show_update_notification(true, cx)
|
||||
.detach_and_log_err(cx);
|
||||
this.status = AutoUpdateStatus::Updated { binary_path };
|
||||
this.status = AutoUpdateStatus::Updated {
|
||||
binary_path,
|
||||
version: update_version,
|
||||
};
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ test-support = ["clock/test-support", "collections/test-support", "gpui/test-sup
|
||||
anyhow.workspace = true
|
||||
async-recursion = "0.3"
|
||||
async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manual-roots"] }
|
||||
base64.workspace = true
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
@@ -30,7 +29,6 @@ gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
http_client.workspace = true
|
||||
http_client_tls.workspace = true
|
||||
httparse = "1.10"
|
||||
log.workspace = true
|
||||
paths.workspace = true
|
||||
parking_lot.workspace = true
|
||||
@@ -49,7 +47,6 @@ text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tiny_http = "0.8"
|
||||
tokio-native-tls = "0.3"
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
mod proxy;
|
||||
mod socks;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
pub mod zed_urls;
|
||||
@@ -24,13 +24,13 @@ use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use parking_lot::RwLock;
|
||||
use postage::watch;
|
||||
use proxy::connect_proxy_stream;
|
||||
use rand::prelude::*;
|
||||
use release_channel::{AppVersion, ReleaseChannel};
|
||||
use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use socks::connect_socks_proxy_stream;
|
||||
use std::pin::Pin;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
@@ -1156,7 +1156,7 @@ impl Client {
|
||||
let handle = cx.update(|cx| gpui_tokio::Tokio::handle(cx)).ok().unwrap();
|
||||
let _guard = handle.enter();
|
||||
match proxy {
|
||||
Some(proxy) => connect_proxy_stream(&proxy, rpc_host).await?,
|
||||
Some(proxy) => connect_socks_proxy_stream(&proxy, rpc_host).await?,
|
||||
None => Box::new(TcpStream::connect(rpc_host).await?),
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
//! client proxy
|
||||
|
||||
mod http_proxy;
|
||||
mod socks_proxy;
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use http_client::Url;
|
||||
use http_proxy::{HttpProxyType, connect_http_proxy_stream, parse_http_proxy};
|
||||
use socks_proxy::{SocksVersion, connect_socks_proxy_stream, parse_socks_proxy};
|
||||
|
||||
pub(crate) async fn connect_proxy_stream(
|
||||
proxy: &Url,
|
||||
rpc_host: (&str, u16),
|
||||
) -> Result<Box<dyn AsyncReadWrite>> {
|
||||
let Some(((proxy_domain, proxy_port), proxy_type)) = parse_proxy_type(proxy) else {
|
||||
// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
|
||||
// SOCKS proxies are often used in contexts where security and privacy are critical,
|
||||
// so any fallback could expose users to significant risks.
|
||||
return Err(anyhow!("Parsing proxy url failed"));
|
||||
};
|
||||
|
||||
// Connect to proxy and wrap protocol later
|
||||
let stream = tokio::net::TcpStream::connect((proxy_domain.as_str(), proxy_port))
|
||||
.await
|
||||
.context("Failed to connect to proxy")?;
|
||||
|
||||
let proxy_stream = match proxy_type {
|
||||
ProxyType::SocksProxy(proxy) => connect_socks_proxy_stream(stream, proxy, rpc_host).await?,
|
||||
ProxyType::HttpProxy(proxy) => {
|
||||
connect_http_proxy_stream(stream, proxy, rpc_host, &proxy_domain).await?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(proxy_stream)
|
||||
}
|
||||
|
||||
enum ProxyType<'t> {
|
||||
SocksProxy(SocksVersion<'t>),
|
||||
HttpProxy(HttpProxyType<'t>),
|
||||
}
|
||||
|
||||
fn parse_proxy_type<'t>(proxy: &'t Url) -> Option<((String, u16), ProxyType<'t>)> {
|
||||
let scheme = proxy.scheme();
|
||||
let host = proxy.host()?.to_string();
|
||||
let port = proxy.port_or_known_default()?;
|
||||
let proxy_type = match scheme {
|
||||
scheme if scheme.starts_with("socks") => {
|
||||
Some(ProxyType::SocksProxy(parse_socks_proxy(scheme, proxy)))
|
||||
}
|
||||
scheme if scheme.starts_with("http") => {
|
||||
Some(ProxyType::HttpProxy(parse_http_proxy(scheme, proxy)))
|
||||
}
|
||||
_ => None,
|
||||
}?;
|
||||
|
||||
Some(((host, port), proxy_type))
|
||||
}
|
||||
|
||||
pub(crate) trait AsyncReadWrite:
|
||||
tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static
|
||||
{
|
||||
}
|
||||
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static> AsyncReadWrite
|
||||
for T
|
||||
{
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
use anyhow::{Context, Result};
|
||||
use base64::Engine;
|
||||
use httparse::{EMPTY_HEADER, Response};
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
|
||||
net::TcpStream,
|
||||
};
|
||||
use tokio_native_tls::{TlsConnector, native_tls};
|
||||
use url::Url;
|
||||
|
||||
use super::AsyncReadWrite;
|
||||
|
||||
pub(super) enum HttpProxyType<'t> {
|
||||
HTTP(Option<HttpProxyAuthorization<'t>>),
|
||||
HTTPS(Option<HttpProxyAuthorization<'t>>),
|
||||
}
|
||||
|
||||
pub(super) struct HttpProxyAuthorization<'t> {
|
||||
username: &'t str,
|
||||
password: &'t str,
|
||||
}
|
||||
|
||||
pub(super) fn parse_http_proxy<'t>(scheme: &str, proxy: &'t Url) -> HttpProxyType<'t> {
|
||||
let auth = proxy.password().map(|password| HttpProxyAuthorization {
|
||||
username: proxy.username(),
|
||||
password,
|
||||
});
|
||||
if scheme.starts_with("https") {
|
||||
HttpProxyType::HTTPS(auth)
|
||||
} else {
|
||||
HttpProxyType::HTTP(auth)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn connect_http_proxy_stream(
|
||||
stream: TcpStream,
|
||||
http_proxy: HttpProxyType<'_>,
|
||||
rpc_host: (&str, u16),
|
||||
proxy_domain: &str,
|
||||
) -> Result<Box<dyn AsyncReadWrite>> {
|
||||
match http_proxy {
|
||||
HttpProxyType::HTTP(auth) => http_connect(stream, rpc_host, auth).await,
|
||||
HttpProxyType::HTTPS(auth) => https_connect(stream, rpc_host, auth, proxy_domain).await,
|
||||
}
|
||||
.context("error connecting to http/https proxy")
|
||||
}
|
||||
|
||||
async fn http_connect<T>(
|
||||
stream: T,
|
||||
target: (&str, u16),
|
||||
auth: Option<HttpProxyAuthorization<'_>>,
|
||||
) -> Result<Box<dyn AsyncReadWrite>>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let mut stream = BufStream::new(stream);
|
||||
let request = make_request(target, auth);
|
||||
stream.write_all(request.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
check_response(&mut stream).await?;
|
||||
Ok(Box::new(stream))
|
||||
}
|
||||
|
||||
async fn https_connect<T>(
|
||||
stream: T,
|
||||
target: (&str, u16),
|
||||
auth: Option<HttpProxyAuthorization<'_>>,
|
||||
proxy_domain: &str,
|
||||
) -> Result<Box<dyn AsyncReadWrite>>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
|
||||
let stream = tls_connector.connect(proxy_domain, stream).await?;
|
||||
http_connect(stream, target, auth).await
|
||||
}
|
||||
|
||||
fn make_request(target: (&str, u16), auth: Option<HttpProxyAuthorization<'_>>) -> String {
|
||||
let (host, port) = target;
|
||||
let mut request = format!(
|
||||
"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\nProxy-Connection: Keep-Alive\r\n"
|
||||
);
|
||||
if let Some(HttpProxyAuthorization { username, password }) = auth {
|
||||
let auth =
|
||||
base64::prelude::BASE64_STANDARD.encode(format!("{username}:{password}").as_bytes());
|
||||
let auth = format!("Proxy-Authorization: Basic {auth}\r\n");
|
||||
request.push_str(&auth);
|
||||
}
|
||||
request.push_str("\r\n");
|
||||
request
|
||||
}
|
||||
|
||||
async fn check_response<T>(stream: &mut BufStream<T>) -> Result<()>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let response = recv_response(stream).await?;
|
||||
let mut dummy_headers = [EMPTY_HEADER; MAX_RESPONSE_HEADERS];
|
||||
let mut parser = Response::new(&mut dummy_headers);
|
||||
parser.parse(response.as_bytes())?;
|
||||
|
||||
match parser.code {
|
||||
Some(code) => {
|
||||
if code == 200 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"Proxy connection failed with HTTP code: {code}"
|
||||
))
|
||||
}
|
||||
}
|
||||
None => Err(anyhow::anyhow!(
|
||||
"Proxy connection failed with no HTTP code: {}",
|
||||
parser.reason.unwrap_or("Unknown reason")
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
const MAX_RESPONSE_HEADER_LENGTH: usize = 4096;
|
||||
const MAX_RESPONSE_HEADERS: usize = 16;
|
||||
|
||||
async fn recv_response<T>(stream: &mut BufStream<T>) -> Result<String>
|
||||
where
|
||||
T: AsyncReadWrite,
|
||||
{
|
||||
let mut response = String::new();
|
||||
loop {
|
||||
if stream.read_line(&mut response).await? == 0 {
|
||||
return Err(anyhow::anyhow!("End of stream"));
|
||||
}
|
||||
|
||||
if MAX_RESPONSE_HEADER_LENGTH < response.len() {
|
||||
return Err(anyhow::anyhow!("Maximum response header length exceeded"));
|
||||
}
|
||||
|
||||
if response.ends_with("\r\n\r\n") {
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use url::Url;
|
||||
|
||||
use super::{HttpProxyAuthorization, HttpProxyType, parse_http_proxy};
|
||||
|
||||
#[test]
|
||||
fn test_parse_http_proxy() {
|
||||
let proxy = Url::parse("http://proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_http_proxy(scheme, &proxy);
|
||||
assert!(matches!(version, HttpProxyType::HTTP(None)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_http_proxy_with_auth() {
|
||||
let proxy = Url::parse("http://username:password@proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_http_proxy(scheme, &proxy);
|
||||
assert!(matches!(
|
||||
version,
|
||||
HttpProxyType::HTTP(Some(HttpProxyAuthorization {
|
||||
username: "username",
|
||||
password: "password"
|
||||
}))
|
||||
))
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,15 @@
|
||||
//! socks proxy
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tokio::net::TcpStream;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use http_client::Url;
|
||||
use tokio_socks::tcp::{Socks4Stream, Socks5Stream};
|
||||
use url::Url;
|
||||
|
||||
use super::AsyncReadWrite;
|
||||
|
||||
/// Identification to a Socks V4 Proxy
|
||||
pub(super) struct Socks4Identification<'a> {
|
||||
struct Socks4Identification<'a> {
|
||||
user_id: &'a str,
|
||||
}
|
||||
|
||||
/// Authorization to a Socks V5 Proxy
|
||||
pub(super) struct Socks5Authorization<'a> {
|
||||
struct Socks5Authorization<'a> {
|
||||
username: &'a str,
|
||||
password: &'a str,
|
||||
}
|
||||
@@ -22,50 +18,45 @@ pub(super) struct Socks5Authorization<'a> {
|
||||
///
|
||||
/// V4 allows idenfication using a user_id
|
||||
/// V5 allows authorization using a username and password
|
||||
pub(super) enum SocksVersion<'a> {
|
||||
enum SocksVersion<'a> {
|
||||
V4(Option<Socks4Identification<'a>>),
|
||||
V5(Option<Socks5Authorization<'a>>),
|
||||
}
|
||||
|
||||
pub(super) fn parse_socks_proxy<'t>(scheme: &str, proxy: &'t Url) -> SocksVersion<'t> {
|
||||
if scheme.starts_with("socks4") {
|
||||
let identification = match proxy.username() {
|
||||
"" => None,
|
||||
username => Some(Socks4Identification { user_id: username }),
|
||||
};
|
||||
SocksVersion::V4(identification)
|
||||
} else {
|
||||
let authorization = proxy.password().map(|password| Socks5Authorization {
|
||||
username: proxy.username(),
|
||||
password,
|
||||
});
|
||||
SocksVersion::V5(authorization)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn connect_socks_proxy_stream(
|
||||
stream: TcpStream,
|
||||
socks_version: SocksVersion<'_>,
|
||||
pub(crate) async fn connect_socks_proxy_stream(
|
||||
proxy: &Url,
|
||||
rpc_host: (&str, u16),
|
||||
) -> Result<Box<dyn AsyncReadWrite>> {
|
||||
match socks_version {
|
||||
let Some((socks_proxy, version)) = parse_socks_proxy(proxy) else {
|
||||
// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
|
||||
// SOCKS proxies are often used in contexts where security and privacy are critical,
|
||||
// so any fallback could expose users to significant risks.
|
||||
return Err(anyhow!("Parsing proxy url failed"));
|
||||
};
|
||||
|
||||
// Connect to proxy and wrap protocol later
|
||||
let stream = tokio::net::TcpStream::connect(socks_proxy)
|
||||
.await
|
||||
.context("Failed to connect to socks proxy")?;
|
||||
|
||||
let socks: Box<dyn AsyncReadWrite> = match version {
|
||||
SocksVersion::V4(None) => {
|
||||
let socks = Socks4Stream::connect_with_socket(stream, rpc_host)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Ok(Box::new(socks))
|
||||
Box::new(socks)
|
||||
}
|
||||
SocksVersion::V4(Some(Socks4Identification { user_id })) => {
|
||||
let socks = Socks4Stream::connect_with_userid_and_socket(stream, rpc_host, user_id)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Ok(Box::new(socks))
|
||||
Box::new(socks)
|
||||
}
|
||||
SocksVersion::V5(None) => {
|
||||
let socks = Socks5Stream::connect_with_socket(stream, rpc_host)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Ok(Box::new(socks))
|
||||
Box::new(socks)
|
||||
}
|
||||
SocksVersion::V5(Some(Socks5Authorization { username, password })) => {
|
||||
let socks = Socks5Stream::connect_with_password_and_socket(
|
||||
@@ -73,9 +64,44 @@ pub(super) async fn connect_socks_proxy_stream(
|
||||
)
|
||||
.await
|
||||
.context("error connecting to socks")?;
|
||||
Ok(Box::new(socks))
|
||||
Box::new(socks)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(socks)
|
||||
}
|
||||
|
||||
fn parse_socks_proxy(proxy: &Url) -> Option<((String, u16), SocksVersion<'_>)> {
|
||||
let scheme = proxy.scheme();
|
||||
let socks_version = if scheme.starts_with("socks4") {
|
||||
let identification = match proxy.username() {
|
||||
"" => None,
|
||||
username => Some(Socks4Identification { user_id: username }),
|
||||
};
|
||||
SocksVersion::V4(identification)
|
||||
} else if scheme.starts_with("socks") {
|
||||
let authorization = proxy.password().map(|password| Socks5Authorization {
|
||||
username: proxy.username(),
|
||||
password,
|
||||
});
|
||||
SocksVersion::V5(authorization)
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let host = proxy.host()?.to_string();
|
||||
let port = proxy.port_or_known_default()?;
|
||||
|
||||
Some(((host, port), socks_version))
|
||||
}
|
||||
|
||||
pub(crate) trait AsyncReadWrite:
|
||||
tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static
|
||||
{
|
||||
}
|
||||
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static> AsyncReadWrite
|
||||
for T
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -87,18 +113,20 @@ mod tests {
|
||||
#[test]
|
||||
fn parse_socks4() {
|
||||
let proxy = Url::parse("socks4://proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
assert!(matches!(version, SocksVersion::V4(None)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_socks4_with_identification() {
|
||||
let proxy = Url::parse("socks4://userid@proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
assert!(matches!(
|
||||
version,
|
||||
SocksVersion::V4(Some(Socks4Identification { user_id: "userid" }))
|
||||
@@ -108,18 +136,20 @@ mod tests {
|
||||
#[test]
|
||||
fn parse_socks5() {
|
||||
let proxy = Url::parse("socks5://proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
assert!(matches!(version, SocksVersion::V5(None)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_socks5_with_authorization() {
|
||||
let proxy = Url::parse("socks5://username:password@proxy.example.com:1080").unwrap();
|
||||
let scheme = proxy.scheme();
|
||||
|
||||
let version = parse_socks_proxy(scheme, &proxy);
|
||||
let ((host, port), version) = parse_socks_proxy(&proxy).unwrap();
|
||||
assert_eq!(host, "proxy.example.com");
|
||||
assert_eq!(port, 1080);
|
||||
assert!(matches!(
|
||||
version,
|
||||
SocksVersion::V5(Some(Socks5Authorization {
|
||||
@@ -128,4 +158,19 @@ mod tests {
|
||||
}))
|
||||
))
|
||||
}
|
||||
|
||||
/// If parsing the proxy URL fails, we must avoid falling back to an insecure connection.
|
||||
/// SOCKS proxies are often used in contexts where security and privacy are critical,
|
||||
/// so any fallback could expose users to significant risks.
|
||||
#[tokio::test]
|
||||
async fn fails_on_bad_proxy() {
|
||||
// Should fail connecting because http is not a valid Socks proxy scheme
|
||||
let proxy = Url::parse("http://localhost:2313").unwrap();
|
||||
|
||||
let result = connect_socks_proxy_stream(&proxy, ("test", 1080)).await;
|
||||
match result {
|
||||
Err(e) => assert_eq!(e.to_string(), "Parsing proxy url failed"),
|
||||
Ok(_) => panic!("Connecting on bad proxy should fail"),
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -17,9 +17,8 @@ use stripe::{
|
||||
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirm,
|
||||
CreateBillingPortalSessionFlowDataSubscriptionUpdateConfirmItems,
|
||||
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
|
||||
EventType, Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId,
|
||||
SubscriptionStatus,
|
||||
CreateBillingPortalSessionFlowDataType, Customer, CustomerId, EventObject, EventType,
|
||||
Expandable, ListEvents, PaymentMethod, Subscription, SubscriptionId, SubscriptionStatus,
|
||||
};
|
||||
use util::{ResultExt, maybe};
|
||||
|
||||
@@ -280,7 +279,7 @@ async fn list_billing_subscriptions(
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Deserialize)]
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ProductCode {
|
||||
ZedPro,
|
||||
@@ -291,7 +290,7 @@ enum ProductCode {
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CreateBillingSubscriptionBody {
|
||||
github_user_id: i32,
|
||||
product: Option<ProductCode>,
|
||||
product: ProductCode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -310,13 +309,6 @@ async fn create_billing_subscription(
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let Some(stripe_client) = app.stripe_client.clone() else {
|
||||
log::error!("failed to retrieve Stripe client");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::error!("failed to retrieve Stripe billing object");
|
||||
Err(Error::http(
|
||||
@@ -325,11 +317,16 @@ async fn create_billing_subscription(
|
||||
))?
|
||||
};
|
||||
|
||||
if app.db.has_active_billing_subscription(user.id).await? {
|
||||
return Err(Error::http(
|
||||
StatusCode::CONFLICT,
|
||||
"user already has an active subscription".into(),
|
||||
));
|
||||
if let Some(existing_subscription) = app.db.get_active_billing_subscription(user.id).await? {
|
||||
let is_checkout_allowed = body.product == ProductCode::ZedProTrial
|
||||
&& existing_subscription.kind == Some(SubscriptionKind::ZedFree);
|
||||
|
||||
if !is_checkout_allowed {
|
||||
return Err(Error::http(
|
||||
StatusCode::CONFLICT,
|
||||
"user already has an active subscription".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
|
||||
@@ -346,35 +343,9 @@ async fn create_billing_subscription(
|
||||
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
||||
.context("failed to parse customer ID")?
|
||||
} else {
|
||||
let existing_customer = if let Some(email) = user.email_address.as_deref() {
|
||||
let customers = Customer::list(
|
||||
&stripe_client,
|
||||
&stripe::ListCustomers {
|
||||
email: Some(email),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customers.data.first().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(existing_customer) = existing_customer {
|
||||
existing_customer.id
|
||||
} else {
|
||||
let customer = Customer::create(
|
||||
&stripe_client,
|
||||
CreateCustomer {
|
||||
email: user.email_address.as_deref(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
}
|
||||
stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?
|
||||
};
|
||||
|
||||
let success_url = format!(
|
||||
@@ -383,12 +354,12 @@ async fn create_billing_subscription(
|
||||
);
|
||||
|
||||
let checkout_session_url = match body.product {
|
||||
Some(ProductCode::ZedPro) => {
|
||||
ProductCode::ZedPro => {
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro(customer_id, &user.github_login, &success_url)
|
||||
.await?
|
||||
}
|
||||
Some(ProductCode::ZedProTrial) => {
|
||||
ProductCode::ZedProTrial => {
|
||||
if let Some(existing_billing_customer) = &existing_billing_customer {
|
||||
if existing_billing_customer.trial_started_at.is_some() {
|
||||
return Err(Error::http(
|
||||
@@ -409,17 +380,11 @@ async fn create_billing_subscription(
|
||||
)
|
||||
.await?
|
||||
}
|
||||
Some(ProductCode::ZedFree) => {
|
||||
ProductCode::ZedFree => {
|
||||
stripe_billing
|
||||
.checkout_with_zed_free(customer_id, &user.github_login, &success_url)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
return Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"No product selected".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(CreateBillingSubscriptionResponse {
|
||||
@@ -972,7 +937,7 @@ async fn poll_stripe_events(
|
||||
.create_processed_stripe_event(&processed_event_params)
|
||||
.await?;
|
||||
|
||||
return Ok(());
|
||||
continue;
|
||||
}
|
||||
|
||||
let process_result = match event.type_ {
|
||||
@@ -1142,31 +1107,51 @@ async fn sync_subscription(
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
// If the user already has an active billing subscription, ignore the
|
||||
// event and return an `Ok` to signal that it was processed
|
||||
// successfully.
|
||||
//
|
||||
// There is the possibility that this could cause us to not create a
|
||||
// subscription in the following scenario:
|
||||
//
|
||||
// 1. User has an active subscription A
|
||||
// 2. User cancels subscription A
|
||||
// 3. User creates a new subscription B
|
||||
// 4. We process the new subscription B before the cancellation of subscription A
|
||||
// 5. User ends up with no subscriptions
|
||||
//
|
||||
// In theory this situation shouldn't arise as we try to process the events in the order they occur.
|
||||
if app
|
||||
if let Some(existing_subscription) = app
|
||||
.db
|
||||
.has_active_billing_subscription(billing_customer.user_id)
|
||||
.get_active_billing_subscription(billing_customer.user_id)
|
||||
.await?
|
||||
{
|
||||
log::info!(
|
||||
"user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
|
||||
user_id = billing_customer.user_id,
|
||||
subscription_id = subscription.id
|
||||
);
|
||||
return Ok(billing_customer);
|
||||
if existing_subscription.kind == Some(SubscriptionKind::ZedFree)
|
||||
&& subscription_kind == Some(SubscriptionKind::ZedProTrial)
|
||||
{
|
||||
let stripe_subscription_id = existing_subscription
|
||||
.stripe_subscription_id
|
||||
.parse::<stripe::SubscriptionId>()
|
||||
.context("failed to parse Stripe subscription ID from database")?;
|
||||
|
||||
Subscription::cancel(
|
||||
&stripe_client,
|
||||
&stripe_subscription_id,
|
||||
stripe::CancelSubscription {
|
||||
invoice_now: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
// If the user already has an active billing subscription, ignore the
|
||||
// event and return an `Ok` to signal that it was processed
|
||||
// successfully.
|
||||
//
|
||||
// There is the possibility that this could cause us to not create a
|
||||
// subscription in the following scenario:
|
||||
//
|
||||
// 1. User has an active subscription A
|
||||
// 2. User cancels subscription A
|
||||
// 3. User creates a new subscription B
|
||||
// 4. We process the new subscription B before the cancellation of subscription A
|
||||
// 5. User ends up with no subscriptions
|
||||
//
|
||||
// In theory this situation shouldn't arise as we try to process the events in the order they occur.
|
||||
|
||||
log::info!(
|
||||
"user {user_id} already has an active subscription, skipping creation of subscription {subscription_id}",
|
||||
user_id = billing_customer.user_id,
|
||||
subscription_id = subscription.id
|
||||
);
|
||||
return Ok(billing_customer);
|
||||
}
|
||||
}
|
||||
|
||||
app.db
|
||||
@@ -1185,6 +1170,27 @@ async fn sync_subscription(
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(stripe_billing) = app.stripe_billing.as_ref() {
|
||||
if subscription.status == SubscriptionStatus::Canceled
|
||||
|| subscription.status == SubscriptionStatus::Paused
|
||||
{
|
||||
let already_has_active_billing_subscription = app
|
||||
.db
|
||||
.has_active_billing_subscription(billing_customer.user_id)
|
||||
.await?;
|
||||
if !already_has_active_billing_subscription {
|
||||
let stripe_customer_id = billing_customer
|
||||
.stripe_customer_id
|
||||
.parse::<stripe::CustomerId>()
|
||||
.context("failed to parse Stripe customer ID from database")?;
|
||||
|
||||
stripe_billing
|
||||
.subscribe_to_zed_free(stripe_customer_id)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(billing_customer)
|
||||
}
|
||||
|
||||
@@ -1447,7 +1453,7 @@ impl From<CancellationDetailsReason> for StripeCancellationReason {
|
||||
}
|
||||
|
||||
/// Finds or creates a billing customer using the provided customer.
|
||||
async fn find_or_create_billing_customer(
|
||||
pub async fn find_or_create_billing_customer(
|
||||
app: &Arc<AppState>,
|
||||
stripe_client: &stripe::Client,
|
||||
customer_or_id: Expandable<Customer>,
|
||||
|
||||
@@ -32,9 +32,9 @@ impl Database {
|
||||
pub async fn create_billing_subscription(
|
||||
&self,
|
||||
params: &CreateBillingSubscriptionParams,
|
||||
) -> Result<()> {
|
||||
) -> Result<billing_subscription::Model> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
let id = billing_subscription::Entity::insert(billing_subscription::ActiveModel {
|
||||
billing_customer_id: ActiveValue::set(params.billing_customer_id),
|
||||
kind: ActiveValue::set(params.kind),
|
||||
stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()),
|
||||
@@ -44,10 +44,14 @@ impl Database {
|
||||
stripe_current_period_end: ActiveValue::set(params.stripe_current_period_end),
|
||||
..Default::default()
|
||||
})
|
||||
.exec_without_returning(&*tx)
|
||||
.await?;
|
||||
.exec(&*tx)
|
||||
.await?
|
||||
.last_insert_id;
|
||||
|
||||
Ok(())
|
||||
Ok(billing_subscription::Entity::find_by_id(id)
|
||||
.one(&*tx)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("failed to retrieve inserted billing subscription"))?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
@@ -236,7 +240,9 @@ impl Database {
|
||||
.filter(
|
||||
billing_customer::Column::UserId.eq(user_id).and(
|
||||
billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Active),
|
||||
.eq(StripeSubscriptionStatus::Active)
|
||||
.or(billing_subscription::Column::StripeSubscriptionStatus
|
||||
.eq(StripeSubscriptionStatus::Trialing)),
|
||||
),
|
||||
)
|
||||
.count(&*tx)
|
||||
|
||||
@@ -42,7 +42,7 @@ impl LlmTokenClaims {
|
||||
is_staff: bool,
|
||||
billing_preferences: Option<billing_preference::Model>,
|
||||
feature_flags: &Vec<String>,
|
||||
subscription: Option<billing_subscription::Model>,
|
||||
subscription: billing_subscription::Model,
|
||||
system_id: Option<String>,
|
||||
config: &Config,
|
||||
) -> Result<String> {
|
||||
@@ -54,17 +54,14 @@ impl LlmTokenClaims {
|
||||
let plan = if is_staff {
|
||||
Plan::ZedPro
|
||||
} else {
|
||||
subscription
|
||||
.as_ref()
|
||||
.and_then(|subscription| subscription.kind)
|
||||
.map_or(Plan::ZedFree, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
subscription.kind.map_or(Plan::ZedFree, |kind| match kind {
|
||||
SubscriptionKind::ZedFree => Plan::ZedFree,
|
||||
SubscriptionKind::ZedPro => Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => Plan::ZedProTrial,
|
||||
})
|
||||
};
|
||||
let subscription_period =
|
||||
billing_subscription::Model::current_period(subscription, is_staff)
|
||||
billing_subscription::Model::current_period(Some(subscription), is_staff)
|
||||
.map(|(start, end)| (start.naive_utc(), end.naive_utc()))
|
||||
.ok_or_else(|| anyhow!("A plan is required to use Zed's hosted models or edit predictions. Visit https://zed.dev/account to get started."))?;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod connection_pool;
|
||||
|
||||
use crate::api::billing::find_or_create_billing_customer;
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::db::LlmDatabase;
|
||||
@@ -4024,7 +4025,56 @@ async fn get_llm_api_token(
|
||||
Err(anyhow!("terms of service not accepted"))?
|
||||
}
|
||||
|
||||
let billing_subscription = db.get_active_billing_subscription(user.id).await?;
|
||||
let Some(stripe_client) = session.app_state.stripe_client.as_ref() else {
|
||||
Err(anyhow!("failed to retrieve Stripe client"))?
|
||||
};
|
||||
|
||||
let Some(stripe_billing) = session.app_state.stripe_billing.as_ref() else {
|
||||
Err(anyhow!("failed to retrieve Stripe billing object"))?
|
||||
};
|
||||
|
||||
let billing_customer =
|
||||
if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
|
||||
billing_customer
|
||||
} else {
|
||||
let customer_id = stripe_billing
|
||||
.find_or_create_customer_by_email(user.email_address.as_deref())
|
||||
.await?;
|
||||
|
||||
find_or_create_billing_customer(
|
||||
&session.app_state,
|
||||
&stripe_client,
|
||||
stripe::Expandable::Id(customer_id),
|
||||
)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("billing customer not found"))?
|
||||
};
|
||||
|
||||
let billing_subscription =
|
||||
if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
|
||||
billing_subscription
|
||||
} else {
|
||||
let stripe_customer_id = billing_customer
|
||||
.stripe_customer_id
|
||||
.parse::<stripe::CustomerId>()
|
||||
.context("failed to parse Stripe customer ID from database")?;
|
||||
|
||||
let stripe_subscription = stripe_billing
|
||||
.subscribe_to_zed_free(stripe_customer_id)
|
||||
.await?;
|
||||
|
||||
db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
|
||||
billing_customer_id: billing_customer.id,
|
||||
kind: Some(SubscriptionKind::ZedFree),
|
||||
stripe_subscription_id: stripe_subscription.id.to_string(),
|
||||
stripe_subscription_status: stripe_subscription.status.into(),
|
||||
stripe_cancellation_reason: None,
|
||||
stripe_current_period_start: Some(stripe_subscription.current_period_start),
|
||||
stripe_current_period_end: Some(stripe_subscription.current_period_end),
|
||||
})
|
||||
.await?
|
||||
};
|
||||
|
||||
let billing_preferences = db.get_billing_preferences(user.id).await?;
|
||||
|
||||
let token = LlmTokenClaims::create(
|
||||
|
||||
@@ -7,7 +7,7 @@ use anyhow::{Context as _, anyhow};
|
||||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::{PriceId, SubscriptionStatus};
|
||||
use stripe::{CreateCustomer, Customer, CustomerId, PriceId, SubscriptionStatus};
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -122,6 +122,47 @@ impl StripeBilling {
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the Stripe customer associated with the provided email address, or creates a new customer, if one does
|
||||
/// not already exist.
|
||||
///
|
||||
/// Always returns a new Stripe customer if the email address is `None`.
|
||||
pub async fn find_or_create_customer_by_email(
|
||||
&self,
|
||||
email_address: Option<&str>,
|
||||
) -> Result<CustomerId> {
|
||||
let existing_customer = if let Some(email) = email_address {
|
||||
let customers = Customer::list(
|
||||
&self.client,
|
||||
&stripe::ListCustomers {
|
||||
email: Some(email),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customers.data.first().cloned()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let customer_id = if let Some(existing_customer) = existing_customer {
|
||||
existing_customer.id
|
||||
} else {
|
||||
let customer = Customer::create(
|
||||
&self.client,
|
||||
CreateCustomer {
|
||||
email: email_address,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
customer.id
|
||||
};
|
||||
|
||||
Ok(customer_id)
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_price(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
@@ -203,9 +244,6 @@ impl StripeBilling {
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
// Should be based on location: https://docs.stripe.com/tax/checkout/tax-ids
|
||||
params.tax_id_collection =
|
||||
Some(stripe::CreateCheckoutSessionTaxIdCollection { enabled: true });
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
@@ -240,7 +278,7 @@ impl StripeBilling {
|
||||
trial_period_days: Some(trial_period_days),
|
||||
trial_settings: Some(stripe::CreateCheckoutSessionSubscriptionDataTrialSettings {
|
||||
end_behavior: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehavior {
|
||||
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Pause,
|
||||
missing_payment_method: stripe::CreateCheckoutSessionSubscriptionDataTrialSettingsEndBehaviorMissingPaymentMethod::Cancel,
|
||||
}
|
||||
}),
|
||||
metadata: if !subscription_metadata.is_empty() {
|
||||
@@ -260,15 +298,52 @@ impl StripeBilling {
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
// Should be based on location: https://docs.stripe.com/tax/checkout/tax-ids
|
||||
params.tax_id_collection =
|
||||
Some(stripe::CreateCheckoutSessionTaxIdCollection { enabled: true });
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_zed_free(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
) -> Result<stripe::Subscription> {
|
||||
let zed_free_price_id = self.zed_free_price_id().await?;
|
||||
|
||||
let existing_subscriptions = stripe::Subscription::list(
|
||||
&self.client,
|
||||
&stripe::ListSubscriptions {
|
||||
customer: Some(customer_id.clone()),
|
||||
status: None,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let existing_active_subscription =
|
||||
existing_subscriptions
|
||||
.data
|
||||
.into_iter()
|
||||
.find(|subscription| {
|
||||
subscription.status == SubscriptionStatus::Active
|
||||
|| subscription.status == SubscriptionStatus::Trialing
|
||||
});
|
||||
if let Some(subscription) = existing_active_subscription {
|
||||
return Ok(subscription);
|
||||
}
|
||||
|
||||
let mut params = stripe::CreateSubscription::new(customer_id);
|
||||
params.items = Some(vec![stripe::CreateSubscriptionItems {
|
||||
price: Some(zed_free_price_id.to_string()),
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
|
||||
let subscription = stripe::Subscription::create(&self.client, params).await?;
|
||||
|
||||
Ok(subscription)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_free(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
@@ -288,9 +363,6 @@ impl StripeBilling {
|
||||
quantity: Some(1),
|
||||
..Default::default()
|
||||
}]);
|
||||
// Should be based on location: https://docs.stripe.com/tax/checkout/tax-ids
|
||||
params.tax_id_collection =
|
||||
Some(stripe::CreateCheckoutSessionTaxIdCollection { enabled: true });
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
|
||||
@@ -2517,7 +2517,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
let breakpoints_b = editor_b.update(cx_b, |editor, cx| {
|
||||
@@ -2526,7 +2526,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -2550,7 +2550,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
let breakpoints_b = editor_b.update(cx_b, |editor, cx| {
|
||||
@@ -2559,7 +2559,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -2583,7 +2583,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
let breakpoints_b = editor_b.update(cx_b, |editor, cx| {
|
||||
@@ -2592,7 +2592,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -2616,7 +2616,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
let breakpoints_b = editor_b.update(cx_b, |editor, cx| {
|
||||
@@ -2625,7 +2625,7 @@ async fn test_add_breakpoints(cx_a: &mut TestAppContext, cx_b: &mut TestAppConte
|
||||
.clone()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
|
||||
@@ -32,15 +32,17 @@ pub enum DapStatus {
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
pub trait DapDelegate {
|
||||
#[async_trait]
|
||||
pub trait DapDelegate: Send + Sync + 'static {
|
||||
fn worktree_id(&self) -> WorktreeId;
|
||||
fn worktree_root_path(&self) -> &Path;
|
||||
fn http_client(&self) -> Arc<dyn HttpClient>;
|
||||
fn node_runtime(&self) -> NodeRuntime;
|
||||
fn toolchain_store(&self) -> Arc<dyn LanguageToolchainStore>;
|
||||
fn fs(&self) -> Arc<dyn Fs>;
|
||||
fn output_to_console(&self, msg: String);
|
||||
fn which(&self, command: &OsStr) -> Option<PathBuf>;
|
||||
async fn which(&self, command: &OsStr) -> Option<PathBuf>;
|
||||
async fn read_text_file(&self, path: PathBuf) -> Result<String>;
|
||||
async fn shell_env(&self) -> collections::HashMap<String, String>;
|
||||
}
|
||||
|
||||
@@ -413,7 +415,7 @@ pub trait DebugAdapter: 'static + Send + Sync {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
cx: &mut AsyncApp,
|
||||
@@ -472,7 +474,7 @@ impl DebugAdapter for FakeAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
_: &dyn DapDelegate,
|
||||
_: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
_: Option<PathBuf>,
|
||||
_: &mut AsyncApp,
|
||||
|
||||
@@ -6,6 +6,8 @@ pub mod proto_conversions;
|
||||
mod registry;
|
||||
pub mod transport;
|
||||
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
pub use dap_types::*;
|
||||
pub use registry::{DapLocator, DapRegistry};
|
||||
pub use task::DebugRequest;
|
||||
@@ -16,3 +18,19 @@ pub type StackFrameId = u64;
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub use adapters::FakeAdapter;
|
||||
use task::TcpArgumentsTemplate;
|
||||
|
||||
pub async fn configure_tcp_connection(
|
||||
tcp_connection: TcpArgumentsTemplate,
|
||||
) -> anyhow::Result<(Ipv4Addr, u16, Option<u64>)> {
|
||||
let host = tcp_connection.host();
|
||||
let timeout = tcp_connection.timeout;
|
||||
|
||||
let port = if let Some(port) = tcp_connection.port {
|
||||
port
|
||||
} else {
|
||||
transport::TcpTransport::port(&tcp_connection).await?
|
||||
};
|
||||
|
||||
Ok((host, port, timeout))
|
||||
}
|
||||
|
||||
@@ -54,10 +54,6 @@ impl DapRegistry {
|
||||
pub fn add_adapter(&self, adapter: Arc<dyn DebugAdapter>) {
|
||||
let name = adapter.name();
|
||||
let _previous_value = self.0.write().adapters.insert(name, adapter);
|
||||
debug_assert!(
|
||||
_previous_value.is_none(),
|
||||
"Attempted to insert a new debug adapter when one is already registered"
|
||||
);
|
||||
}
|
||||
|
||||
pub fn adapter_language(&self, adapter_name: &str) -> Option<LanguageName> {
|
||||
|
||||
@@ -61,7 +61,7 @@ impl CodeLldbDebugAdapter {
|
||||
|
||||
async fn fetch_latest_adapter_version(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
) -> Result<AdapterVersion> {
|
||||
let release =
|
||||
latest_github_release("vadimcn/codelldb", true, false, delegate.http_client()).await?;
|
||||
@@ -111,7 +111,7 @@ impl DebugAdapter for CodeLldbDebugAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
_: &mut AsyncApp,
|
||||
@@ -129,7 +129,7 @@ impl DebugAdapter for CodeLldbDebugAdapter {
|
||||
self.name(),
|
||||
version.clone(),
|
||||
adapters::DownloadedFileType::Vsix,
|
||||
delegate,
|
||||
delegate.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
let version_path =
|
||||
|
||||
@@ -6,7 +6,7 @@ mod php;
|
||||
mod python;
|
||||
mod ruby;
|
||||
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
@@ -17,6 +17,7 @@ use dap::{
|
||||
self, AdapterVersion, DapDelegate, DebugAdapter, DebugAdapterBinary, DebugAdapterName,
|
||||
GithubRepo,
|
||||
},
|
||||
configure_tcp_connection,
|
||||
inline_value::{PythonInlineValueProvider, RustInlineValueProvider},
|
||||
};
|
||||
use gdb::GdbDebugAdapter;
|
||||
@@ -27,7 +28,6 @@ use php::PhpDebugAdapter;
|
||||
use python::PythonDebugAdapter;
|
||||
use ruby::RubyDebugAdapter;
|
||||
use serde_json::{Value, json};
|
||||
use task::TcpArgumentsTemplate;
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
cx.update_default_global(|registry: &mut DapRegistry, _cx| {
|
||||
@@ -45,21 +45,6 @@ pub fn init(cx: &mut App) {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn configure_tcp_connection(
|
||||
tcp_connection: TcpArgumentsTemplate,
|
||||
) -> Result<(Ipv4Addr, u16, Option<u64>)> {
|
||||
let host = tcp_connection.host();
|
||||
let timeout = tcp_connection.timeout;
|
||||
|
||||
let port = if let Some(port) = tcp_connection.port {
|
||||
port
|
||||
} else {
|
||||
dap::transport::TcpTransport::port(&tcp_connection).await?
|
||||
};
|
||||
|
||||
Ok((host, port, timeout))
|
||||
}
|
||||
|
||||
trait ToDap {
|
||||
fn to_dap(&self) -> dap::StartDebuggingRequestArgumentsRequest;
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ impl DebugAdapter for GdbDebugAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<std::path::PathBuf>,
|
||||
_: &mut AsyncApp,
|
||||
@@ -76,6 +76,7 @@ impl DebugAdapter for GdbDebugAdapter {
|
||||
|
||||
let gdb_path = delegate
|
||||
.which(OsStr::new("gdb"))
|
||||
.await
|
||||
.and_then(|p| p.to_str().map(|s| s.to_string()))
|
||||
.ok_or(anyhow!("Could not find gdb in path"));
|
||||
|
||||
|
||||
@@ -50,13 +50,14 @@ impl DebugAdapter for GoDebugAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
_user_installed_path: Option<PathBuf>,
|
||||
_cx: &mut AsyncApp,
|
||||
) -> Result<DebugAdapterBinary> {
|
||||
let delve_path = delegate
|
||||
.which(OsStr::new("dlv"))
|
||||
.await
|
||||
.and_then(|p| p.to_str().map(|p| p.to_string()))
|
||||
.ok_or(anyhow!("Dlv not found in path"))?;
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ impl JsDebugAdapter {
|
||||
|
||||
async fn fetch_latest_adapter_version(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
) -> Result<AdapterVersion> {
|
||||
let release = latest_github_release(
|
||||
&format!("{}/{}", "microsoft", Self::ADAPTER_NPM_NAME),
|
||||
@@ -82,7 +82,7 @@ impl JsDebugAdapter {
|
||||
|
||||
async fn get_installed_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
_: &mut AsyncApp,
|
||||
@@ -139,7 +139,7 @@ impl DebugAdapter for JsDebugAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
cx: &mut AsyncApp,
|
||||
@@ -151,7 +151,7 @@ impl DebugAdapter for JsDebugAdapter {
|
||||
self.name(),
|
||||
version,
|
||||
adapters::DownloadedFileType::GzipTar,
|
||||
delegate,
|
||||
delegate.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ impl PhpDebugAdapter {
|
||||
|
||||
async fn fetch_latest_adapter_version(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
) -> Result<AdapterVersion> {
|
||||
let release = latest_github_release(
|
||||
&format!("{}/{}", "xdebug", Self::ADAPTER_PACKAGE_NAME),
|
||||
@@ -66,7 +66,7 @@ impl PhpDebugAdapter {
|
||||
|
||||
async fn get_installed_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
_: &mut AsyncApp,
|
||||
@@ -126,7 +126,7 @@ impl DebugAdapter for PhpDebugAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
cx: &mut AsyncApp,
|
||||
@@ -138,7 +138,7 @@ impl DebugAdapter for PhpDebugAdapter {
|
||||
self.name(),
|
||||
version,
|
||||
adapters::DownloadedFileType::Vsix,
|
||||
delegate,
|
||||
delegate.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
@@ -52,26 +52,26 @@ impl PythonDebugAdapter {
|
||||
}
|
||||
async fn fetch_latest_adapter_version(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
) -> Result<AdapterVersion> {
|
||||
let github_repo = GithubRepo {
|
||||
repo_name: Self::ADAPTER_PACKAGE_NAME.into(),
|
||||
repo_owner: "microsoft".into(),
|
||||
};
|
||||
|
||||
adapters::fetch_latest_adapter_version_from_github(github_repo, delegate).await
|
||||
adapters::fetch_latest_adapter_version_from_github(github_repo, delegate.as_ref()).await
|
||||
}
|
||||
|
||||
async fn install_binary(
|
||||
&self,
|
||||
version: AdapterVersion,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
) -> Result<()> {
|
||||
let version_path = adapters::download_adapter_from_github(
|
||||
self.name(),
|
||||
version,
|
||||
adapters::DownloadedFileType::Zip,
|
||||
delegate,
|
||||
delegate.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -93,7 +93,7 @@ impl PythonDebugAdapter {
|
||||
|
||||
async fn get_installed_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
cx: &mut AsyncApp,
|
||||
@@ -128,14 +128,18 @@ impl PythonDebugAdapter {
|
||||
let python_path = if let Some(toolchain) = toolchain {
|
||||
Some(toolchain.path.to_string())
|
||||
} else {
|
||||
BINARY_NAMES
|
||||
.iter()
|
||||
.filter_map(|cmd| {
|
||||
delegate
|
||||
.which(OsStr::new(cmd))
|
||||
.map(|path| path.to_string_lossy().to_string())
|
||||
})
|
||||
.find(|_| true)
|
||||
let mut name = None;
|
||||
|
||||
for cmd in BINARY_NAMES {
|
||||
name = delegate
|
||||
.which(OsStr::new(cmd))
|
||||
.await
|
||||
.map(|path| path.to_string_lossy().to_string());
|
||||
if name.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
name
|
||||
};
|
||||
|
||||
Ok(DebugAdapterBinary {
|
||||
@@ -172,7 +176,7 @@ impl DebugAdapter for PythonDebugAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
cx: &mut AsyncApp,
|
||||
|
||||
@@ -8,7 +8,7 @@ use dap::{
|
||||
};
|
||||
use gpui::{AsyncApp, SharedString};
|
||||
use language::LanguageName;
|
||||
use std::path::PathBuf;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::command::new_smol_command;
|
||||
|
||||
use crate::ToDap;
|
||||
@@ -32,7 +32,7 @@ impl DebugAdapter for RubyDebugAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
definition: &DebugTaskDefinition,
|
||||
_user_installed_path: Option<PathBuf>,
|
||||
_cx: &mut AsyncApp,
|
||||
@@ -40,7 +40,7 @@ impl DebugAdapter for RubyDebugAdapter {
|
||||
let adapter_path = paths::debug_adapters_dir().join(self.name().as_ref());
|
||||
let mut rdbg_path = adapter_path.join("rdbg");
|
||||
if !delegate.fs().is_file(&rdbg_path).await {
|
||||
match delegate.which("rdbg".as_ref()) {
|
||||
match delegate.which("rdbg".as_ref()).await {
|
||||
Some(path) => rdbg_path = path,
|
||||
None => {
|
||||
delegate.output_to_console(
|
||||
@@ -76,7 +76,7 @@ impl DebugAdapter for RubyDebugAdapter {
|
||||
format!("--port={}", port),
|
||||
format!("--host={}", host),
|
||||
];
|
||||
if delegate.which(launch.program.as_ref()).is_some() {
|
||||
if delegate.which(launch.program.as_ref()).await.is_some() {
|
||||
arguments.push("--command".to_string())
|
||||
}
|
||||
arguments.push(launch.program);
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use gpui::App;
|
||||
use sqlez_macros::sql;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{define_connection, query};
|
||||
use crate::{define_connection, query, write_and_log};
|
||||
|
||||
define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
|
||||
&[sql!(
|
||||
@@ -11,6 +13,29 @@ define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> =
|
||||
)];
|
||||
);
|
||||
|
||||
pub trait Dismissable {
|
||||
const KEY: &'static str;
|
||||
|
||||
fn dismissed() -> bool {
|
||||
KEY_VALUE_STORE
|
||||
.read_kvp(Self::KEY)
|
||||
.log_err()
|
||||
.map_or(false, |s| s.is_some())
|
||||
}
|
||||
|
||||
fn set_dismissed(is_dismissed: bool, cx: &mut App) {
|
||||
write_and_log(cx, move || async move {
|
||||
if is_dismissed {
|
||||
KEY_VALUE_STORE
|
||||
.write_kvp(Self::KEY.into(), "1".into())
|
||||
.await
|
||||
} else {
|
||||
KEY_VALUE_STORE.delete_kvp(Self::KEY.into()).await
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl KeyValueStore {
|
||||
query! {
|
||||
pub fn read_kvp(key: &str) -> Result<Option<String>> {
|
||||
|
||||
@@ -5,7 +5,7 @@ use async_trait::async_trait;
|
||||
use dap::adapters::{
|
||||
DapDelegate, DebugAdapter, DebugAdapterBinary, DebugAdapterName, DebugTaskDefinition,
|
||||
};
|
||||
use extension::Extension;
|
||||
use extension::{Extension, WorktreeDelegate};
|
||||
use gpui::AsyncApp;
|
||||
|
||||
pub(crate) struct ExtensionDapAdapter {
|
||||
@@ -25,6 +25,35 @@ impl ExtensionDapAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
/// An adapter that allows an [`dap::adapters::DapDelegate`] to be used as a [`WorktreeDelegate`].
|
||||
struct WorktreeDelegateAdapter(pub Arc<dyn DapDelegate>);
|
||||
|
||||
#[async_trait]
|
||||
impl WorktreeDelegate for WorktreeDelegateAdapter {
|
||||
fn id(&self) -> u64 {
|
||||
self.0.worktree_id().to_proto()
|
||||
}
|
||||
|
||||
fn root_path(&self) -> String {
|
||||
self.0.worktree_root_path().to_string_lossy().to_string()
|
||||
}
|
||||
|
||||
async fn read_text_file(&self, path: PathBuf) -> Result<String> {
|
||||
self.0.read_text_file(path).await
|
||||
}
|
||||
|
||||
async fn which(&self, binary_name: String) -> Option<String> {
|
||||
self.0
|
||||
.which(binary_name.as_ref())
|
||||
.await
|
||||
.map(|path| path.to_string_lossy().to_string())
|
||||
}
|
||||
|
||||
async fn shell_env(&self) -> Vec<(String, String)> {
|
||||
self.0.shell_env().await.into_iter().collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl DebugAdapter for ExtensionDapAdapter {
|
||||
fn name(&self) -> DebugAdapterName {
|
||||
@@ -33,7 +62,7 @@ impl DebugAdapter for ExtensionDapAdapter {
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
_: &dyn DapDelegate,
|
||||
delegate: &Arc<dyn DapDelegate>,
|
||||
config: &DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
_cx: &mut AsyncApp,
|
||||
@@ -43,6 +72,7 @@ impl DebugAdapter for ExtensionDapAdapter {
|
||||
self.debug_adapter_name.clone(),
|
||||
config.clone(),
|
||||
user_installed_path,
|
||||
Arc::new(WorktreeDelegateAdapter(delegate.clone())),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ impl Render for BreakpointList {
|
||||
cx: &mut ui::Context<Self>,
|
||||
) -> impl ui::IntoElement {
|
||||
let old_len = self.breakpoints.len();
|
||||
let breakpoints = self.breakpoint_store.read(cx).all_breakpoints(cx);
|
||||
let breakpoints = self.breakpoint_store.read(cx).all_source_breakpoints(cx);
|
||||
self.breakpoints.clear();
|
||||
let weak = cx.weak_entity();
|
||||
let breakpoints = breakpoints.into_iter().flat_map(|(path, mut breakpoints)| {
|
||||
|
||||
@@ -122,10 +122,11 @@ use markdown::Markdown;
|
||||
use mouse_context_menu::MouseContextMenu;
|
||||
use persistence::DB;
|
||||
use project::{
|
||||
ProjectPath,
|
||||
BreakpointWithPosition, ProjectPath,
|
||||
debugger::{
|
||||
breakpoint_store::{
|
||||
BreakpointEditAction, BreakpointState, BreakpointStore, BreakpointStoreEvent,
|
||||
BreakpointEditAction, BreakpointSessionState, BreakpointState, BreakpointStore,
|
||||
BreakpointStoreEvent,
|
||||
},
|
||||
session::{Session, SessionEvent},
|
||||
},
|
||||
@@ -198,7 +199,7 @@ use theme::{
|
||||
};
|
||||
use ui::{
|
||||
ButtonSize, ButtonStyle, ContextMenu, Disclosure, IconButton, IconButtonShape, IconName,
|
||||
IconSize, Key, Tooltip, h_flex, prelude::*,
|
||||
IconSize, Indicator, Key, Tooltip, h_flex, prelude::*,
|
||||
};
|
||||
use util::{RangeExt, ResultExt, TryFutureExt, maybe, post_inc};
|
||||
use workspace::{
|
||||
@@ -6997,7 +6998,7 @@ impl Editor {
|
||||
range: Range<DisplayRow>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> HashMap<DisplayRow, (Anchor, Breakpoint)> {
|
||||
) -> HashMap<DisplayRow, (Anchor, Breakpoint, Option<BreakpointSessionState>)> {
|
||||
let mut breakpoint_display_points = HashMap::default();
|
||||
|
||||
let Some(breakpoint_store) = self.breakpoint_store.clone() else {
|
||||
@@ -7031,15 +7032,17 @@ impl Editor {
|
||||
buffer_snapshot,
|
||||
cx,
|
||||
);
|
||||
for (anchor, breakpoint) in breakpoints {
|
||||
for (breakpoint, state) in breakpoints {
|
||||
let multi_buffer_anchor =
|
||||
Anchor::in_buffer(excerpt_id, buffer_snapshot.remote_id(), *anchor);
|
||||
Anchor::in_buffer(excerpt_id, buffer_snapshot.remote_id(), breakpoint.position);
|
||||
let position = multi_buffer_anchor
|
||||
.to_point(&multi_buffer_snapshot)
|
||||
.to_display_point(&snapshot);
|
||||
|
||||
breakpoint_display_points
|
||||
.insert(position.row(), (multi_buffer_anchor, breakpoint.clone()));
|
||||
breakpoint_display_points.insert(
|
||||
position.row(),
|
||||
(multi_buffer_anchor, breakpoint.bp.clone(), state),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7214,8 +7217,10 @@ impl Editor {
|
||||
position: Anchor,
|
||||
row: DisplayRow,
|
||||
breakpoint: &Breakpoint,
|
||||
state: Option<BreakpointSessionState>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> IconButton {
|
||||
let is_rejected = state.is_some_and(|s| !s.verified);
|
||||
// Is it a breakpoint that shows up when hovering over gutter?
|
||||
let (is_phantom, collides_with_existing) = self.gutter_breakpoint_indicator.0.map_or(
|
||||
(false, false),
|
||||
@@ -7241,6 +7246,8 @@ impl Editor {
|
||||
|
||||
let color = if is_phantom {
|
||||
Color::Hint
|
||||
} else if is_rejected {
|
||||
Color::Disabled
|
||||
} else {
|
||||
Color::Debugger
|
||||
};
|
||||
@@ -7268,9 +7275,18 @@ impl Editor {
|
||||
}
|
||||
let primary_text = SharedString::from(primary_text);
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
let meta = if is_rejected {
|
||||
"No executable code is associated with this line."
|
||||
} else {
|
||||
"Right-click for more options."
|
||||
};
|
||||
IconButton::new(("breakpoint_indicator", row.0 as usize), icon)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.size(ui::ButtonSize::None)
|
||||
.when(is_rejected, |this| {
|
||||
this.indicator(Indicator::icon(Icon::new(IconName::Warning)).color(Color::Warning))
|
||||
})
|
||||
.icon_color(color)
|
||||
.style(ButtonStyle::Transparent)
|
||||
.on_click(cx.listener({
|
||||
@@ -7302,14 +7318,7 @@ impl Editor {
|
||||
);
|
||||
}))
|
||||
.tooltip(move |window, cx| {
|
||||
Tooltip::with_meta_in(
|
||||
primary_text.clone(),
|
||||
None,
|
||||
"Right-click for more options",
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
Tooltip::with_meta_in(primary_text.clone(), None, meta, &focus_handle, window, cx)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7449,11 +7458,11 @@ impl Editor {
|
||||
_style: &EditorStyle,
|
||||
is_active: bool,
|
||||
row: DisplayRow,
|
||||
breakpoint: Option<(Anchor, Breakpoint)>,
|
||||
breakpoint: Option<(Anchor, Breakpoint, Option<BreakpointSessionState>)>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> IconButton {
|
||||
let color = Color::Muted;
|
||||
let position = breakpoint.as_ref().map(|(anchor, _)| *anchor);
|
||||
let position = breakpoint.as_ref().map(|(anchor, _, _)| *anchor);
|
||||
|
||||
IconButton::new(("run_indicator", row.0 as usize), ui::IconName::Play)
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
@@ -9633,16 +9642,16 @@ impl Editor {
|
||||
cx,
|
||||
)
|
||||
.next()
|
||||
.and_then(|(anchor, bp)| {
|
||||
.and_then(|(bp, _)| {
|
||||
let breakpoint_row = buffer_snapshot
|
||||
.summary_for_anchor::<text::PointUtf16>(anchor)
|
||||
.summary_for_anchor::<text::PointUtf16>(&bp.position)
|
||||
.row;
|
||||
|
||||
if breakpoint_row == row {
|
||||
snapshot
|
||||
.buffer_snapshot
|
||||
.anchor_in_excerpt(enclosing_excerpt, *anchor)
|
||||
.map(|anchor| (anchor, bp.clone()))
|
||||
.anchor_in_excerpt(enclosing_excerpt, bp.position)
|
||||
.map(|position| (position, bp.bp.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -9805,7 +9814,10 @@ impl Editor {
|
||||
breakpoint_store.update(cx, |breakpoint_store, cx| {
|
||||
breakpoint_store.toggle_breakpoint(
|
||||
buffer,
|
||||
(breakpoint_position.text_anchor, breakpoint),
|
||||
BreakpointWithPosition {
|
||||
position: breakpoint_position.text_anchor,
|
||||
bp: breakpoint,
|
||||
},
|
||||
edit_action,
|
||||
cx,
|
||||
);
|
||||
|
||||
@@ -6,6 +6,8 @@ use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources, VsCodeSettings};
|
||||
use util::serde::default_true;
|
||||
|
||||
/// Imports from the VSCode settings at
|
||||
/// https://code.visualstudio.com/docs/reference/default-settings
|
||||
#[derive(Deserialize, Clone)]
|
||||
pub struct EditorSettings {
|
||||
pub cursor_blink: bool,
|
||||
@@ -539,7 +541,7 @@ pub struct ScrollbarContent {
|
||||
}
|
||||
|
||||
/// Minimap related settings
|
||||
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
pub struct MinimapContent {
|
||||
/// When to show the minimap in the editor.
|
||||
///
|
||||
@@ -770,5 +772,32 @@ impl Settings for EditorSettings {
|
||||
let search = current.search.get_or_insert_default();
|
||||
search.include_ignored = use_ignored;
|
||||
}
|
||||
|
||||
let mut minimap = MinimapContent::default();
|
||||
let minimap_enabled = vscode.read_bool("editor.minimap.enabled").unwrap_or(true);
|
||||
let autohide = vscode.read_bool("editor.minimap.autohide");
|
||||
if minimap_enabled {
|
||||
if let Some(false) = autohide {
|
||||
minimap.show = Some(ShowMinimap::Always);
|
||||
} else {
|
||||
minimap.show = Some(ShowMinimap::Auto);
|
||||
}
|
||||
} else {
|
||||
minimap.show = Some(ShowMinimap::Never);
|
||||
}
|
||||
|
||||
vscode.enum_setting(
|
||||
"editor.minimap.showSlider",
|
||||
&mut minimap.thumb,
|
||||
|s| match s {
|
||||
"always" => Some(MinimapThumb::Always),
|
||||
"mouseover" => Some(MinimapThumb::Hover),
|
||||
_ => None,
|
||||
},
|
||||
);
|
||||
|
||||
if minimap != MinimapContent::default() {
|
||||
current.minimap = Some(minimap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18716,7 +18716,7 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -18741,7 +18741,7 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -18763,7 +18763,7 @@ async fn test_breakpoint_toggling(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -18830,7 +18830,7 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -18851,7 +18851,7 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -18871,7 +18871,7 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -18894,7 +18894,7 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -18917,7 +18917,7 @@ async fn test_log_breakpoint_editing(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -19010,7 +19010,7 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -19042,7 +19042,7 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
@@ -19078,7 +19078,7 @@ async fn test_breakpoint_enabling_and_disabling(cx: &mut TestAppContext) {
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.read(cx)
|
||||
.all_breakpoints(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
.clone()
|
||||
});
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ use multi_buffer::{
|
||||
|
||||
use project::{
|
||||
ProjectPath,
|
||||
debugger::breakpoint_store::Breakpoint,
|
||||
debugger::breakpoint_store::{Breakpoint, BreakpointSessionState},
|
||||
project_settings::{GitGutterSetting, GitHunkStyleSetting, ProjectSettings},
|
||||
};
|
||||
use settings::Settings;
|
||||
@@ -2317,7 +2317,7 @@ impl EditorElement {
|
||||
gutter_hitbox: &Hitbox,
|
||||
display_hunks: &[(DisplayDiffHunk, Option<Hitbox>)],
|
||||
snapshot: &EditorSnapshot,
|
||||
breakpoints: HashMap<DisplayRow, (Anchor, Breakpoint)>,
|
||||
breakpoints: HashMap<DisplayRow, (Anchor, Breakpoint, Option<BreakpointSessionState>)>,
|
||||
row_infos: &[RowInfo],
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -2325,7 +2325,7 @@ impl EditorElement {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
breakpoints
|
||||
.into_iter()
|
||||
.filter_map(|(display_row, (text_anchor, bp))| {
|
||||
.filter_map(|(display_row, (text_anchor, bp, state))| {
|
||||
if row_infos
|
||||
.get((display_row.0.saturating_sub(range.start.0)) as usize)
|
||||
.is_some_and(|row_info| {
|
||||
@@ -2348,7 +2348,7 @@ impl EditorElement {
|
||||
return None;
|
||||
}
|
||||
|
||||
let button = editor.render_breakpoint(text_anchor, display_row, &bp, cx);
|
||||
let button = editor.render_breakpoint(text_anchor, display_row, &bp, state, cx);
|
||||
|
||||
let button = prepaint_gutter_button(
|
||||
button,
|
||||
@@ -2378,7 +2378,7 @@ impl EditorElement {
|
||||
gutter_hitbox: &Hitbox,
|
||||
display_hunks: &[(DisplayDiffHunk, Option<Hitbox>)],
|
||||
snapshot: &EditorSnapshot,
|
||||
breakpoints: &mut HashMap<DisplayRow, (Anchor, Breakpoint)>,
|
||||
breakpoints: &mut HashMap<DisplayRow, (Anchor, Breakpoint, Option<BreakpointSessionState>)>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Vec<AnyElement> {
|
||||
@@ -7437,8 +7437,10 @@ impl Element for EditorElement {
|
||||
editor.active_breakpoints(start_row..end_row, window, cx)
|
||||
});
|
||||
if cx.has_flag::<DebuggerFeatureFlag>() {
|
||||
for display_row in breakpoint_rows.keys() {
|
||||
active_rows.entry(*display_row).or_default().breakpoint = true;
|
||||
for (display_row, (_, bp, state)) in &breakpoint_rows {
|
||||
if bp.is_enabled() && state.is_none_or(|s| s.verified) {
|
||||
active_rows.entry(*display_row).or_default().breakpoint = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7478,7 +7480,7 @@ impl Element for EditorElement {
|
||||
let breakpoint = Breakpoint::new_standard();
|
||||
phantom_breakpoint.collides_with_existing_breakpoint =
|
||||
false;
|
||||
(position, breakpoint)
|
||||
(position, breakpoint, None)
|
||||
});
|
||||
}
|
||||
})
|
||||
|
||||
@@ -30,6 +30,7 @@ chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
dirs.workspace = true
|
||||
dotenv.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
||||
@@ -422,6 +422,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
|
||||
let extension_host_proxy = ExtensionHostProxy::global(cx);
|
||||
|
||||
language::init(cx);
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language_extension::init(extension_host_proxy.clone(), languages.clone());
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
|
||||
@@ -141,6 +141,7 @@ pub trait Extension: Send + Sync + 'static {
|
||||
dap_name: Arc<str>,
|
||||
config: DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
worktree: Arc<dyn WorktreeDelegate>,
|
||||
) -> Result<DebugAdapterBinary>;
|
||||
}
|
||||
|
||||
|
||||
@@ -87,6 +87,8 @@ pub struct ExtensionManifest {
|
||||
pub snippets: Option<PathBuf>,
|
||||
#[serde(default)]
|
||||
pub capabilities: Vec<ExtensionCapability>,
|
||||
#[serde(default)]
|
||||
pub debug_adapters: Vec<Arc<str>>,
|
||||
}
|
||||
|
||||
impl ExtensionManifest {
|
||||
@@ -274,6 +276,7 @@ fn manifest_from_old_manifest(
|
||||
indexed_docs_providers: BTreeMap::default(),
|
||||
snippets: None,
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,6 +304,7 @@ mod tests {
|
||||
indexed_docs_providers: BTreeMap::default(),
|
||||
snippets: None,
|
||||
capabilities: vec![],
|
||||
debug_adapters: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,11 @@ pub use wit::{
|
||||
KeyValueStore, LanguageServerInstallationStatus, Project, Range, Worktree, download_file,
|
||||
make_file_executable,
|
||||
zed::extension::context_server::ContextServerConfiguration,
|
||||
zed::extension::dap::{
|
||||
DebugAdapterBinary, DebugRequest, DebugTaskDefinition, StartDebuggingRequestArguments,
|
||||
StartDebuggingRequestArgumentsRequest, TcpArguments, TcpArgumentsTemplate,
|
||||
resolve_tcp_template,
|
||||
},
|
||||
zed::extension::github::{
|
||||
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
|
||||
latest_github_release,
|
||||
@@ -194,6 +199,7 @@ pub trait Extension: Send + Sync {
|
||||
_adapter_name: String,
|
||||
_config: DebugTaskDefinition,
|
||||
_user_provided_path: Option<String>,
|
||||
_worktree: &Worktree,
|
||||
) -> Result<DebugAdapterBinary, String> {
|
||||
Err("`get_dap_binary` not implemented".to_string())
|
||||
}
|
||||
@@ -386,8 +392,9 @@ impl wit::Guest for Component {
|
||||
adapter_name: String,
|
||||
config: DebugTaskDefinition,
|
||||
user_installed_path: Option<String>,
|
||||
) -> Result<DebugAdapterBinary, String> {
|
||||
extension().get_dap_binary(adapter_name, config, user_installed_path)
|
||||
worktree: &Worktree,
|
||||
) -> Result<wit::DebugAdapterBinary, String> {
|
||||
extension().get_dap_binary(adapter_name, config, user_installed_path, worktree)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
interface dap {
|
||||
use common.{env-vars};
|
||||
|
||||
/// Resolves a specified TcpArgumentsTemplate into TcpArguments
|
||||
resolve-tcp-template: func(template: tcp-arguments-template) -> result<tcp-arguments, string>;
|
||||
|
||||
record launch-request {
|
||||
program: string,
|
||||
cwd: option<string>,
|
||||
|
||||
@@ -11,7 +11,7 @@ world extension {
|
||||
|
||||
use common.{env-vars, range};
|
||||
use context-server.{context-server-configuration};
|
||||
use dap.{debug-adapter-binary, debug-task-definition};
|
||||
use dap.{debug-adapter-binary, debug-task-definition, debug-request};
|
||||
use lsp.{completion, symbol};
|
||||
use process.{command};
|
||||
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
|
||||
@@ -157,5 +157,5 @@ world extension {
|
||||
export index-docs: func(provider-name: string, package-name: string, database: borrow<key-value-store>) -> result<_, string>;
|
||||
|
||||
/// Returns a configured debug adapter binary for a given debug task.
|
||||
export get-dap-binary: func(adapter-name: string, config: debug-task-definition, user-installed-path: option<string>) -> result<debug-adapter-binary, string>;
|
||||
export get-dap-binary: func(adapter-name: string, config: debug-task-definition, user-installed-path: option<string>, worktree: borrow<worktree>) -> result<debug-adapter-binary, string>;
|
||||
}
|
||||
|
||||
@@ -138,6 +138,7 @@ fn manifest() -> ExtensionManifest {
|
||||
command: "echo".into(),
|
||||
args: vec!["hello!".into()],
|
||||
}],
|
||||
debug_adapters: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,9 +14,10 @@ use collections::{BTreeMap, BTreeSet, HashMap, HashSet, btree_map};
|
||||
pub use extension::ExtensionManifest;
|
||||
use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder};
|
||||
use extension::{
|
||||
ExtensionContextServerProxy, ExtensionEvents, ExtensionGrammarProxy, ExtensionHostProxy,
|
||||
ExtensionIndexedDocsProviderProxy, ExtensionLanguageProxy, ExtensionLanguageServerProxy,
|
||||
ExtensionSlashCommandProxy, ExtensionSnippetProxy, ExtensionThemeProxy,
|
||||
ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents,
|
||||
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionIndexedDocsProviderProxy,
|
||||
ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy,
|
||||
ExtensionSnippetProxy, ExtensionThemeProxy,
|
||||
};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::{
|
||||
@@ -1328,6 +1329,11 @@ impl ExtensionStore {
|
||||
this.proxy
|
||||
.register_indexed_docs_provider(extension.clone(), provider_id.clone());
|
||||
}
|
||||
|
||||
for debug_adapter in &manifest.debug_adapters {
|
||||
this.proxy
|
||||
.register_debug_adapter(extension.clone(), debug_adapter.clone());
|
||||
}
|
||||
}
|
||||
|
||||
this.wasm_extensions.extend(wasm_extensions);
|
||||
|
||||
@@ -164,6 +164,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
indexed_docs_providers: BTreeMap::default(),
|
||||
snippets: None,
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
@@ -193,6 +194,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
indexed_docs_providers: BTreeMap::default(),
|
||||
snippets: None,
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
@@ -367,6 +369,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
indexed_docs_providers: BTreeMap::default(),
|
||||
snippets: None,
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
|
||||
@@ -379,11 +379,13 @@ impl extension::Extension for WasmExtension {
|
||||
dap_name: Arc<str>,
|
||||
config: DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
worktree: Arc<dyn WorktreeDelegate>,
|
||||
) -> Result<DebugAdapterBinary> {
|
||||
self.call(|extension, store| {
|
||||
async move {
|
||||
let resource = store.data_mut().table().push(worktree)?;
|
||||
let dap_binary = extension
|
||||
.call_get_dap_binary(store, dap_name, config, user_installed_path)
|
||||
.call_get_dap_binary(store, dap_name, config, user_installed_path, resource)
|
||||
.await?
|
||||
.map_err(|err| anyhow!("{err:?}"))?;
|
||||
let dap_binary = dap_binary.try_into()?;
|
||||
|
||||
@@ -903,6 +903,7 @@ impl Extension {
|
||||
adapter_name: Arc<str>,
|
||||
task: DebugTaskDefinition,
|
||||
user_installed_path: Option<PathBuf>,
|
||||
resource: Resource<Arc<dyn WorktreeDelegate>>,
|
||||
) -> Result<Result<DebugAdapterBinary, String>> {
|
||||
match self {
|
||||
Extension::V0_6_0(ext) => {
|
||||
@@ -912,6 +913,7 @@ impl Extension {
|
||||
&adapter_name,
|
||||
&task.try_into()?,
|
||||
user_installed_path.as_ref().and_then(|p| p.to_str()),
|
||||
resource,
|
||||
)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
@@ -48,7 +48,7 @@ wasmtime::component::bindgen!({
|
||||
pub use self::zed::extension::*;
|
||||
|
||||
mod settings {
|
||||
include!(concat!(env!("OUT_DIR"), "/since_v0.5.0/settings.rs"));
|
||||
include!(concat!(env!("OUT_DIR"), "/since_v0.6.0/settings.rs"));
|
||||
}
|
||||
|
||||
pub type ExtensionWorktree = Arc<dyn WorktreeDelegate>;
|
||||
@@ -729,8 +729,29 @@ impl slash_command::Host for WasmState {}
|
||||
#[async_trait]
|
||||
impl context_server::Host for WasmState {}
|
||||
|
||||
#[async_trait]
|
||||
impl dap::Host for WasmState {}
|
||||
impl dap::Host for WasmState {
|
||||
async fn resolve_tcp_template(
|
||||
&mut self,
|
||||
template: TcpArgumentsTemplate,
|
||||
) -> wasmtime::Result<Result<TcpArguments, String>> {
|
||||
maybe!(async {
|
||||
let (host, port, timeout) =
|
||||
::dap::configure_tcp_connection(task::TcpArgumentsTemplate {
|
||||
port: template.port,
|
||||
host: template.host.map(Ipv4Addr::from_bits),
|
||||
timeout: template.timeout,
|
||||
})
|
||||
.await?;
|
||||
Ok(TcpArguments {
|
||||
port,
|
||||
host: host.to_bits(),
|
||||
timeout,
|
||||
})
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionImports for WasmState {
|
||||
async fn get_settings(
|
||||
|
||||
@@ -5,16 +5,17 @@ use editor::{
|
||||
display_map::{BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId},
|
||||
};
|
||||
use gpui::{
|
||||
App, Context, Entity, InteractiveElement as _, ParentElement as _, Subscription, WeakEntity,
|
||||
App, Context, Entity, InteractiveElement as _, ParentElement as _, Subscription, Task,
|
||||
WeakEntity,
|
||||
};
|
||||
use language::{Anchor, Buffer, BufferId};
|
||||
use project::{ConflictRegion, ConflictSet, ConflictSetUpdate};
|
||||
use project::{ConflictRegion, ConflictSet, ConflictSetUpdate, ProjectItem as _};
|
||||
use std::{ops::Range, sync::Arc};
|
||||
use ui::{
|
||||
ActiveTheme, AnyElement, Element as _, StatefulInteractiveElement, Styled,
|
||||
StyledTypography as _, div, h_flex, rems,
|
||||
StyledTypography as _, Window, div, h_flex, rems,
|
||||
};
|
||||
use util::{debug_panic, maybe};
|
||||
use util::{ResultExt as _, debug_panic, maybe};
|
||||
|
||||
pub(crate) struct ConflictAddon {
|
||||
buffers: HashMap<BufferId, BufferConflicts>,
|
||||
@@ -404,8 +405,16 @@ fn render_conflict_buttons(
|
||||
let editor = editor.clone();
|
||||
let conflict = conflict.clone();
|
||||
let ours = conflict.ours.clone();
|
||||
move |_, _, cx| {
|
||||
resolve_conflict(editor.clone(), excerpt_id, &conflict, &[ours.clone()], cx)
|
||||
move |_, window, cx| {
|
||||
resolve_conflict(
|
||||
editor.clone(),
|
||||
excerpt_id,
|
||||
conflict.clone(),
|
||||
vec![ours.clone()],
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach()
|
||||
}
|
||||
}),
|
||||
)
|
||||
@@ -422,14 +431,16 @@ fn render_conflict_buttons(
|
||||
let editor = editor.clone();
|
||||
let conflict = conflict.clone();
|
||||
let theirs = conflict.theirs.clone();
|
||||
move |_, _, cx| {
|
||||
move |_, window, cx| {
|
||||
resolve_conflict(
|
||||
editor.clone(),
|
||||
excerpt_id,
|
||||
&conflict,
|
||||
&[theirs.clone()],
|
||||
conflict.clone(),
|
||||
vec![theirs.clone()],
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach()
|
||||
}
|
||||
}),
|
||||
)
|
||||
@@ -447,69 +458,101 @@ fn render_conflict_buttons(
|
||||
let conflict = conflict.clone();
|
||||
let ours = conflict.ours.clone();
|
||||
let theirs = conflict.theirs.clone();
|
||||
move |_, _, cx| {
|
||||
move |_, window, cx| {
|
||||
resolve_conflict(
|
||||
editor.clone(),
|
||||
excerpt_id,
|
||||
&conflict,
|
||||
&[ours.clone(), theirs.clone()],
|
||||
conflict.clone(),
|
||||
vec![ours.clone(), theirs.clone()],
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.detach()
|
||||
}
|
||||
}),
|
||||
)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn resolve_conflict(
|
||||
pub(crate) fn resolve_conflict(
|
||||
editor: WeakEntity<Editor>,
|
||||
excerpt_id: ExcerptId,
|
||||
resolved_conflict: &ConflictRegion,
|
||||
ranges: &[Range<Anchor>],
|
||||
resolved_conflict: ConflictRegion,
|
||||
ranges: Vec<Range<Anchor>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(editor) = editor.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let multibuffer = editor.read(cx).buffer().read(cx);
|
||||
let snapshot = multibuffer.snapshot(cx);
|
||||
let Some(buffer) = resolved_conflict
|
||||
.ours
|
||||
.end
|
||||
.buffer_id
|
||||
.and_then(|buffer_id| multibuffer.buffer(buffer_id))
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let buffer_snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
resolved_conflict.resolve(buffer, ranges, cx);
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
let conflict_addon = editor.addon_mut::<ConflictAddon>().unwrap();
|
||||
let Some(state) = conflict_addon.buffers.get_mut(&buffer_snapshot.remote_id()) else {
|
||||
) -> Task<()> {
|
||||
window.spawn(cx, async move |cx| {
|
||||
let Some((workspace, project, multibuffer, buffer)) = editor
|
||||
.update(cx, |editor, cx| {
|
||||
let workspace = editor.workspace()?;
|
||||
let project = editor.project.clone()?;
|
||||
let multibuffer = editor.buffer().clone();
|
||||
let buffer_id = resolved_conflict.ours.end.buffer_id?;
|
||||
let buffer = multibuffer.read(cx).buffer(buffer_id)?;
|
||||
resolved_conflict.resolve(buffer.clone(), &ranges, cx);
|
||||
let conflict_addon = editor.addon_mut::<ConflictAddon>().unwrap();
|
||||
let snapshot = multibuffer.read(cx).snapshot(cx);
|
||||
let buffer_snapshot = buffer.read(cx).snapshot();
|
||||
let state = conflict_addon
|
||||
.buffers
|
||||
.get_mut(&buffer_snapshot.remote_id())?;
|
||||
let ix = state
|
||||
.block_ids
|
||||
.binary_search_by(|(range, _)| {
|
||||
range
|
||||
.start
|
||||
.cmp(&resolved_conflict.range.start, &buffer_snapshot)
|
||||
})
|
||||
.ok()?;
|
||||
let &(_, block_id) = &state.block_ids[ix];
|
||||
let start = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, resolved_conflict.range.start)
|
||||
.unwrap();
|
||||
let end = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, resolved_conflict.range.end)
|
||||
.unwrap();
|
||||
editor.remove_highlighted_rows::<ConflictsOuter>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsOurs>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsTheirs>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsOursMarker>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsTheirsMarker>(vec![start..end], cx);
|
||||
editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
Some((workspace, project, multibuffer, buffer))
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let Ok(ix) = state.block_ids.binary_search_by(|(range, _)| {
|
||||
range
|
||||
.start
|
||||
.cmp(&resolved_conflict.range.start, &buffer_snapshot)
|
||||
}) else {
|
||||
let Some(save) = project
|
||||
.update(cx, |project, cx| {
|
||||
if multibuffer.read(cx).all_diff_hunks_expanded() {
|
||||
project.save_buffer(buffer.clone(), cx)
|
||||
} else {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let &(_, block_id) = &state.block_ids[ix];
|
||||
let start = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, resolved_conflict.range.start)
|
||||
.unwrap();
|
||||
let end = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, resolved_conflict.range.end)
|
||||
.unwrap();
|
||||
editor.remove_highlighted_rows::<ConflictsOuter>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsOurs>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsTheirs>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsOursMarker>(vec![start..end], cx);
|
||||
editor.remove_highlighted_rows::<ConflictsTheirsMarker>(vec![start..end], cx);
|
||||
editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
if save.await.log_err().is_none() {
|
||||
let open_path = maybe!({
|
||||
let path = buffer
|
||||
.read_with(cx, |buffer, cx| buffer.project_path(cx))
|
||||
.ok()
|
||||
.flatten()?;
|
||||
workspace
|
||||
.update_in(cx, |workspace, window, cx| {
|
||||
workspace.open_path_preview(path, None, false, false, false, window, cx)
|
||||
})
|
||||
.ok()
|
||||
});
|
||||
|
||||
if let Some(open_path) = open_path {
|
||||
open_path.await.log_err();
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -148,6 +148,17 @@ impl ProjectDiff {
|
||||
});
|
||||
diff_display_editor
|
||||
});
|
||||
window.defer(cx, {
|
||||
let workspace = workspace.clone();
|
||||
let editor = editor.clone();
|
||||
move |window, cx| {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.added_to_workspace(workspace, window, cx);
|
||||
})
|
||||
});
|
||||
}
|
||||
});
|
||||
cx.subscribe_in(&editor, window, Self::handle_editor_event)
|
||||
.detach();
|
||||
|
||||
@@ -1323,6 +1334,7 @@ fn merge_anchor_ranges<'a>(
|
||||
mod tests {
|
||||
use db::indoc;
|
||||
use editor::test::editor_test_context::{EditorTestContext, assert_state_with_diff};
|
||||
use git::status::{UnmergedStatus, UnmergedStatusCode};
|
||||
use gpui::TestAppContext;
|
||||
use project::FakeFs;
|
||||
use serde_json::json;
|
||||
@@ -1583,7 +1595,10 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
use crate::project_diff::{self, ProjectDiff};
|
||||
use crate::{
|
||||
conflict_view::resolve_conflict,
|
||||
project_diff::{self, ProjectDiff},
|
||||
};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_go_to_prev_hunk_multibuffer(cx: &mut TestAppContext) {
|
||||
@@ -1754,4 +1769,80 @@ mod tests {
|
||||
|
||||
cx.assert_excerpts_with_selections(&format!("[EXCERPT]\nˇ{git_contents}"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_saving_resolved_conflicts(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/project"),
|
||||
json!({
|
||||
".git": {},
|
||||
"foo": "<<<<<<< x\nours\n=======\ntheirs\n>>>>>>> y\n",
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
fs.set_status_for_repo(
|
||||
Path::new(path!("/project/.git")),
|
||||
&[(
|
||||
Path::new("foo"),
|
||||
UnmergedStatus {
|
||||
first_head: UnmergedStatusCode::Updated,
|
||||
second_head: UnmergedStatusCode::Updated,
|
||||
}
|
||||
.into(),
|
||||
)],
|
||||
);
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
let diff = cx.new_window_entity(|window, cx| {
|
||||
ProjectDiff::new(project.clone(), workspace, window, cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
cx.update(|window, cx| {
|
||||
let editor = diff.read(cx).editor.clone();
|
||||
let excerpt_ids = editor.read(cx).buffer().read(cx).excerpt_ids();
|
||||
assert_eq!(excerpt_ids.len(), 1);
|
||||
let excerpt_id = excerpt_ids[0];
|
||||
let buffer = editor
|
||||
.read(cx)
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.all_buffers()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap();
|
||||
let buffer_id = buffer.read(cx).remote_id();
|
||||
let conflict_set = diff
|
||||
.read(cx)
|
||||
.editor
|
||||
.read(cx)
|
||||
.addon::<ConflictAddon>()
|
||||
.unwrap()
|
||||
.conflict_set(buffer_id)
|
||||
.unwrap();
|
||||
assert!(conflict_set.read(cx).has_conflict);
|
||||
let snapshot = conflict_set.read(cx).snapshot();
|
||||
assert_eq!(snapshot.conflicts.len(), 1);
|
||||
|
||||
let ours_range = snapshot.conflicts[0].ours.clone();
|
||||
|
||||
resolve_conflict(
|
||||
editor.downgrade(),
|
||||
excerpt_id,
|
||||
snapshot.conflicts[0].clone(),
|
||||
vec![ours_range],
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await;
|
||||
|
||||
let contents = fs.read_file_sync(path!("/project/foo")).unwrap();
|
||||
let contents = String::from_utf8(contents).unwrap();
|
||||
assert_eq!(contents, "ours\n");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,7 +177,7 @@ impl Render for ImageLoadingExample {
|
||||
)
|
||||
.to_path_buf();
|
||||
img(image_source.clone())
|
||||
.id("image-1")
|
||||
.id("image-4")
|
||||
.border_1()
|
||||
.size_12()
|
||||
.with_fallback(|| Self::fallback_element().into_any_element())
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{
|
||||
any::{TypeId, type_name},
|
||||
cell::{Ref, RefCell, RefMut},
|
||||
cell::{BorrowMutError, Ref, RefCell, RefMut},
|
||||
marker::PhantomData,
|
||||
mem,
|
||||
ops::{Deref, DerefMut},
|
||||
@@ -79,6 +79,16 @@ impl AppCell {
|
||||
}
|
||||
AppRefMut(self.app.borrow_mut())
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[track_caller]
|
||||
pub fn try_borrow_mut(&self) -> Result<AppRefMut, BorrowMutError> {
|
||||
if option_env!("TRACK_THREAD_BORROWS").is_some() {
|
||||
let thread_id = std::thread::current().id();
|
||||
eprintln!("borrowed {thread_id:?}");
|
||||
}
|
||||
Ok(AppRefMut(self.app.try_borrow_mut()?))
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
||||
@@ -88,7 +88,7 @@ impl AppContext for AsyncApp {
|
||||
F: FnOnce(AnyView, &mut Window, &mut App) -> T,
|
||||
{
|
||||
let app = self.app.upgrade().context("app was released")?;
|
||||
let mut lock = app.borrow_mut();
|
||||
let mut lock = app.try_borrow_mut()?;
|
||||
lock.update_window(window, f)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,19 @@ pub trait FluentBuilder {
|
||||
self.map(|this| if condition { then(this) } else { this })
|
||||
}
|
||||
|
||||
/// Conditionally modify self with the given closure.
|
||||
fn when_else(
|
||||
self,
|
||||
condition: bool,
|
||||
then: impl FnOnce(Self) -> Self,
|
||||
else_fn: impl FnOnce(Self) -> Self,
|
||||
) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.map(|this| if condition { then(this) } else { else_fn(this) })
|
||||
}
|
||||
|
||||
/// Conditionally unwrap and modify self with the given closure, if the given option is Some.
|
||||
fn when_some<T>(self, option: Option<T>, then: impl FnOnce(Self, T) -> Self) -> Self
|
||||
where
|
||||
|
||||
@@ -21,6 +21,7 @@ db.workspace = true
|
||||
editor.workspace = true
|
||||
file_icons.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
|
||||
@@ -11,6 +11,7 @@ use gpui::{
|
||||
InteractiveElement, IntoElement, ObjectFit, ParentElement, Render, Styled, Task, WeakEntity,
|
||||
Window, canvas, div, fill, img, opaque_grey, point, size,
|
||||
};
|
||||
use language::File as _;
|
||||
use persistence::IMAGE_VIEWER;
|
||||
use project::{ImageItem, Project, ProjectPath, image_store::ImageItemEvent};
|
||||
use settings::Settings;
|
||||
@@ -104,7 +105,7 @@ impl Item for ImageView {
|
||||
}
|
||||
|
||||
fn tab_tooltip_text(&self, cx: &App) -> Option<SharedString> {
|
||||
let abs_path = self.image_item.read(cx).file.as_local()?.abs_path(cx);
|
||||
let abs_path = self.image_item.read(cx).abs_path(cx)?;
|
||||
let file_path = abs_path.compact().to_string_lossy().to_string();
|
||||
Some(file_path.into())
|
||||
}
|
||||
@@ -149,10 +150,10 @@ impl Item for ImageView {
|
||||
}
|
||||
|
||||
fn tab_icon(&self, _: &Window, cx: &App) -> Option<Icon> {
|
||||
let path = self.image_item.read(cx).path();
|
||||
let path = self.image_item.read(cx).abs_path(cx)?;
|
||||
ItemSettings::get_global(cx)
|
||||
.file_icons
|
||||
.then(|| FileIcons::get_icon(path, cx))
|
||||
.then(|| FileIcons::get_icon(&path, cx))
|
||||
.flatten()
|
||||
.map(Icon::from_path)
|
||||
}
|
||||
@@ -274,7 +275,7 @@ impl SerializableItem for ImageView {
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<gpui::Result<()>>> {
|
||||
let workspace_id = workspace.database_id()?;
|
||||
let image_path = self.image_item.read(cx).file.as_local()?.abs_path(cx);
|
||||
let image_path = self.image_item.read(cx).abs_path(cx)?;
|
||||
|
||||
Some(cx.background_spawn({
|
||||
async move {
|
||||
|
||||
@@ -83,6 +83,13 @@ impl EditPredictionUsage {
|
||||
|
||||
Ok(Self { limit, amount })
|
||||
}
|
||||
|
||||
pub fn over_limit(&self) -> bool {
|
||||
match self.limit {
|
||||
UsageLimit::Limited(limit) => self.amount >= limit,
|
||||
UsageLimit::Unlimited => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait EditPredictionProvider: 'static + Sized {
|
||||
|
||||
@@ -33,7 +33,7 @@ use workspace::{
|
||||
StatusItemView, Toast, Workspace, create_and_open_local_file, item::ItemHandle,
|
||||
notifications::NotificationId,
|
||||
};
|
||||
use zed_actions::OpenBrowser;
|
||||
use zed_actions::{OpenBrowser, OpenZedUrl};
|
||||
use zed_llm_client::UsageLimit;
|
||||
use zeta::RateCompletions;
|
||||
|
||||
@@ -277,14 +277,31 @@ impl Render for InlineCompletionButton {
|
||||
);
|
||||
}
|
||||
|
||||
let mut over_limit = false;
|
||||
|
||||
if let Some(usage) = self
|
||||
.edit_prediction_provider
|
||||
.as_ref()
|
||||
.and_then(|provider| provider.usage(cx))
|
||||
{
|
||||
over_limit = usage.over_limit()
|
||||
}
|
||||
|
||||
let show_editor_predictions = self.editor_show_predictions;
|
||||
|
||||
let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon)
|
||||
.shape(IconButtonShape::Square)
|
||||
.when(enabled && !show_editor_predictions, |this| {
|
||||
this.indicator(Indicator::dot().color(Color::Muted))
|
||||
.when(
|
||||
enabled && (!show_editor_predictions || over_limit),
|
||||
|this| {
|
||||
this.indicator(Indicator::dot().when_else(
|
||||
over_limit,
|
||||
|dot| dot.color(Color::Error),
|
||||
|dot| dot.color(Color::Muted),
|
||||
))
|
||||
.indicator_border_color(Some(cx.theme().colors().status_bar_background))
|
||||
})
|
||||
},
|
||||
)
|
||||
.when(!self.popover_menu_handle.is_deployed(), |element| {
|
||||
element.tooltip(move |window, cx| {
|
||||
if enabled {
|
||||
@@ -440,6 +457,16 @@ impl InlineCompletionButton {
|
||||
},
|
||||
move |_, cx| cx.open_url(&zed_urls::account_url(cx)),
|
||||
)
|
||||
.when(usage.over_limit(), |menu| -> ContextMenu {
|
||||
menu.entry("Subscribe to increase your limit", None, |window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(OpenZedUrl {
|
||||
url: zed_urls::account_url(cx),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})
|
||||
})
|
||||
.separator();
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ pub trait ToolchainLister: Send + Sync {
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
pub trait LanguageToolchainStore {
|
||||
pub trait LanguageToolchainStore: Send + Sync + 'static {
|
||||
async fn active_toolchain(
|
||||
self: Arc<Self>,
|
||||
worktree_id: WorktreeId,
|
||||
|
||||
@@ -2,6 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use editor::{Editor, EditorElement, EditorStyle};
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{
|
||||
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
|
||||
@@ -11,13 +12,13 @@ use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, RateLimiter, Role,
|
||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason,
|
||||
};
|
||||
|
||||
use futures::stream::BoxStream;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
@@ -26,6 +27,9 @@ use util::ResultExt;
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
const PROVIDER_ID: &str = "mistral";
|
||||
const PROVIDER_NAME: &str = "Mistral";
|
||||
|
||||
@@ -43,6 +47,7 @@ pub struct AvailableModel {
|
||||
pub max_tokens: usize,
|
||||
pub max_output_tokens: Option<u32>,
|
||||
pub max_completion_tokens: Option<u32>,
|
||||
pub supports_tools: Option<bool>,
|
||||
}
|
||||
|
||||
pub struct MistralLanguageModelProvider {
|
||||
@@ -209,6 +214,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
|
||||
max_tokens: model.max_tokens,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
supports_tools: model.supports_tools,
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -300,14 +306,14 @@ impl LanguageModel for MistralLanguageModel {
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
self.model.supports_tools()
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
|
||||
self.model.supports_tools()
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
@@ -368,26 +374,8 @@ impl LanguageModel for MistralLanguageModel {
|
||||
|
||||
async move {
|
||||
let stream = stream.await?;
|
||||
Ok(stream
|
||||
.map(|result| {
|
||||
result
|
||||
.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
let mapper = MistralEventMapper::new();
|
||||
Ok(mapper.map_stream(stream).boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
@@ -398,33 +386,87 @@ pub fn into_mistral(
|
||||
model: String,
|
||||
max_output_tokens: Option<u32>,
|
||||
) -> mistral::Request {
|
||||
let len = request.messages.len();
|
||||
let merged_messages =
|
||||
request
|
||||
.messages
|
||||
.into_iter()
|
||||
.fold(Vec::with_capacity(len), |mut acc, msg| {
|
||||
let role = msg.role;
|
||||
let content = msg.string_contents();
|
||||
let stream = true;
|
||||
|
||||
acc.push(match role {
|
||||
Role::User => mistral::RequestMessage::User { content },
|
||||
Role::Assistant => mistral::RequestMessage::Assistant {
|
||||
content: Some(content),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => mistral::RequestMessage::System { content },
|
||||
});
|
||||
acc
|
||||
});
|
||||
let mut messages = Vec::new();
|
||||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
|
||||
.push(match message.role {
|
||||
Role::User => mistral::RequestMessage::User { content: text },
|
||||
Role::Assistant => mistral::RequestMessage::Assistant {
|
||||
content: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => mistral::RequestMessage::System { content: text },
|
||||
}),
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(_) => {}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
let tool_call = mistral::ToolCall {
|
||||
id: tool_use.id.to_string(),
|
||||
content: mistral::ToolCallContent::Function {
|
||||
function: mistral::FunctionContent {
|
||||
name: tool_use.name.to_string(),
|
||||
arguments: serde_json::to_string(&tool_use.input)
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
|
||||
messages.last_mut()
|
||||
{
|
||||
tool_calls.push(tool_call);
|
||||
} else {
|
||||
messages.push(mistral::RequestMessage::Assistant {
|
||||
content: None,
|
||||
tool_calls: vec![tool_call],
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => text.to_string(),
|
||||
LanguageModelToolResultContent::Image(_) => {
|
||||
// TODO: Mistral image support
|
||||
"[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(mistral::RequestMessage::Tool {
|
||||
content,
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mistral::Request {
|
||||
model,
|
||||
messages: merged_messages,
|
||||
stream: true,
|
||||
messages,
|
||||
stream,
|
||||
max_tokens: max_output_tokens,
|
||||
temperature: request.temperature,
|
||||
response_format: None,
|
||||
tool_choice: match request.tool_choice {
|
||||
Some(LanguageModelToolChoice::Auto) if !request.tools.is_empty() => {
|
||||
Some(mistral::ToolChoice::Auto)
|
||||
}
|
||||
Some(LanguageModelToolChoice::Any) if !request.tools.is_empty() => {
|
||||
Some(mistral::ToolChoice::Any)
|
||||
}
|
||||
Some(LanguageModelToolChoice::None) => Some(mistral::ToolChoice::None),
|
||||
_ if !request.tools.is_empty() => Some(mistral::ToolChoice::Auto),
|
||||
_ => None,
|
||||
},
|
||||
parallel_tool_calls: if !request.tools.is_empty() {
|
||||
Some(false)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
@@ -439,6 +481,127 @@ pub fn into_mistral(
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MistralEventMapper {
|
||||
tool_calls_by_index: HashMap<usize, RawToolCall>,
|
||||
}
|
||||
|
||||
impl MistralEventMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_calls_by_index: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
|
||||
) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: mistral::StreamResponse,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"Response contained no choices"
|
||||
)))];
|
||||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
if let Some(content) = choice.delta.content.clone() {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
|
||||
for tool_call in tool_calls {
|
||||
let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
|
||||
|
||||
if let Some(tool_id) = tool_call.id.clone() {
|
||||
entry.id = tool_id;
|
||||
}
|
||||
|
||||
if let Some(function) = tool_call.function.as_ref() {
|
||||
if let Some(name) = function.name.clone() {
|
||||
entry.name = name;
|
||||
}
|
||||
|
||||
if let Some(arguments) = function.arguments.clone() {
|
||||
entry.arguments.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(finish_reason) = choice.finish_reason.as_deref() {
|
||||
match finish_reason {
|
||||
"stop" => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
|
||||
}
|
||||
"tool_calls" => {
|
||||
events.extend(self.process_tool_calls());
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
|
||||
}
|
||||
unexpected => {
|
||||
log::error!("Unexpected Mistral stop_reason: {unexpected:?}");
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
fn process_tool_calls(
|
||||
&mut self,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (_, tool_call) in self.tool_calls_by_index.drain() {
|
||||
if tool_call.id.is_empty() || tool_call.name.is_empty() {
|
||||
results.push(Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"Received incomplete tool call: missing id or name"
|
||||
))));
|
||||
continue;
|
||||
}
|
||||
|
||||
match serde_json::Value::from_str(&tool_call.arguments) {
|
||||
Ok(input) => results.push(Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.into(),
|
||||
name: tool_call.name.into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments,
|
||||
},
|
||||
))),
|
||||
Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<Editor>,
|
||||
state: gpui::Entity<State>,
|
||||
@@ -623,3 +786,65 @@ impl Render for ConfigurationView {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use language_model;
|
||||
|
||||
#[test]
|
||||
fn test_into_mistral_conversion() {
|
||||
let request = language_model::LanguageModelRequest {
|
||||
messages: vec![
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: language_model::Role::System,
|
||||
content: vec![language_model::MessageContent::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)],
|
||||
cache: false,
|
||||
},
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![language_model::MessageContent::Text(
|
||||
"Hello, how are you?".to_string(),
|
||||
)],
|
||||
cache: false,
|
||||
},
|
||||
],
|
||||
temperature: Some(0.7),
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
mode: None,
|
||||
stop: Vec::new(),
|
||||
};
|
||||
|
||||
let model_name = "mistral-medium-latest".to_string();
|
||||
let max_output_tokens = Some(1000);
|
||||
let mistral_request = into_mistral(request, model_name, max_output_tokens);
|
||||
|
||||
assert_eq!(mistral_request.model, "mistral-medium-latest");
|
||||
assert_eq!(mistral_request.temperature, Some(0.7));
|
||||
assert_eq!(mistral_request.max_tokens, Some(1000));
|
||||
assert!(mistral_request.stream);
|
||||
assert!(mistral_request.tools.is_empty());
|
||||
assert!(mistral_request.tool_choice.is_none());
|
||||
|
||||
assert_eq!(mistral_request.messages.len(), 2);
|
||||
|
||||
match &mistral_request.messages[0] {
|
||||
mistral::RequestMessage::System { content } => {
|
||||
assert_eq!(content, "You are a helpful assistant.");
|
||||
}
|
||||
_ => panic!("Expected System message"),
|
||||
}
|
||||
|
||||
match &mistral_request.messages[1] {
|
||||
mistral::RequestMessage::User { content } => {
|
||||
assert_eq!(content, "Hello, how are you?");
|
||||
}
|
||||
_ => panic!("Expected User message"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -583,6 +583,7 @@ impl MarkdownElement {
|
||||
if phase.bubble() {
|
||||
if let Some(link) = rendered_text.link_for_position(event.position) {
|
||||
markdown.pressed_link = Some(link.clone());
|
||||
window.prevent_default();
|
||||
} else {
|
||||
let source_index =
|
||||
match rendered_text.source_index_for_position(event.position) {
|
||||
@@ -601,10 +602,10 @@ impl MarkdownElement {
|
||||
reversed: false,
|
||||
pending: true,
|
||||
};
|
||||
window.prevent_default();
|
||||
window.focus(&markdown.focus_handle);
|
||||
}
|
||||
|
||||
window.prevent_default();
|
||||
cx.notify();
|
||||
}
|
||||
} else if phase.capture() {
|
||||
|
||||
@@ -67,6 +67,7 @@ pub enum Model {
|
||||
max_tokens: usize,
|
||||
max_output_tokens: Option<u32>,
|
||||
max_completion_tokens: Option<u32>,
|
||||
supports_tools: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -133,6 +134,18 @@ impl Model {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_tools(&self) -> bool {
|
||||
match self {
|
||||
Self::CodestralLatest
|
||||
| Self::MistralLargeLatest
|
||||
| Self::MistralMediumLatest
|
||||
| Self::MistralSmallLatest
|
||||
| Self::OpenMistralNemo
|
||||
| Self::OpenCodestralMamba => true,
|
||||
Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -146,6 +159,10 @@ pub struct Request {
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<ToolDefinition>,
|
||||
}
|
||||
@@ -190,12 +207,13 @@ pub enum Prediction {
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Required,
|
||||
None,
|
||||
Other(ToolDefinition),
|
||||
Any,
|
||||
Function(ToolDefinition),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
//!
|
||||
//! Breakpoints are separate from a session because they're not associated with any particular debug session. They can also be set up without a session running.
|
||||
use anyhow::{Result, anyhow};
|
||||
use breakpoints_in_file::BreakpointsInFile;
|
||||
use collections::BTreeMap;
|
||||
pub use breakpoints_in_file::{BreakpointSessionState, BreakpointWithPosition};
|
||||
use breakpoints_in_file::{BreakpointsInFile, StatefulBreakpoint};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use dap::{StackFrameId, client::SessionId};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use itertools::Itertools;
|
||||
@@ -14,21 +15,54 @@ use rpc::{
|
||||
};
|
||||
use std::{hash::Hash, ops::Range, path::Path, sync::Arc, u32};
|
||||
use text::{Point, PointUtf16};
|
||||
use util::maybe;
|
||||
|
||||
use crate::{Project, ProjectPath, buffer_store::BufferStore, worktree_store::WorktreeStore};
|
||||
|
||||
use super::session::ThreadId;
|
||||
|
||||
mod breakpoints_in_file {
|
||||
use collections::HashMap;
|
||||
use language::{BufferEvent, DiskState};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub struct BreakpointWithPosition {
|
||||
pub position: text::Anchor,
|
||||
pub bp: Breakpoint,
|
||||
}
|
||||
|
||||
/// A breakpoint with per-session data about it's state (as seen by the Debug Adapter).
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StatefulBreakpoint {
|
||||
pub bp: BreakpointWithPosition,
|
||||
pub session_state: HashMap<SessionId, BreakpointSessionState>,
|
||||
}
|
||||
|
||||
impl StatefulBreakpoint {
|
||||
pub(super) fn new(bp: BreakpointWithPosition) -> Self {
|
||||
Self {
|
||||
bp,
|
||||
session_state: Default::default(),
|
||||
}
|
||||
}
|
||||
pub(super) fn position(&self) -> &text::Anchor {
|
||||
&self.bp.position
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
|
||||
pub struct BreakpointSessionState {
|
||||
/// Session-specific identifier for the breakpoint, as assigned by Debug Adapter.
|
||||
pub id: u64,
|
||||
pub verified: bool,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
pub(super) struct BreakpointsInFile {
|
||||
pub(super) buffer: Entity<Buffer>,
|
||||
// TODO: This is.. less than ideal, as it's O(n) and does not return entries in order. We'll have to change TreeMap to support passing in the context for comparisons
|
||||
pub(super) breakpoints: Vec<(text::Anchor, Breakpoint)>,
|
||||
pub(super) breakpoints: Vec<StatefulBreakpoint>,
|
||||
_subscription: Arc<Subscription>,
|
||||
}
|
||||
|
||||
@@ -199,9 +233,26 @@ impl BreakpointStore {
|
||||
.breakpoints
|
||||
.into_iter()
|
||||
.filter_map(|breakpoint| {
|
||||
let anchor = language::proto::deserialize_anchor(breakpoint.position.clone()?)?;
|
||||
let position =
|
||||
language::proto::deserialize_anchor(breakpoint.position.clone()?)?;
|
||||
let session_state = breakpoint
|
||||
.session_state
|
||||
.iter()
|
||||
.map(|(session_id, state)| {
|
||||
let state = BreakpointSessionState {
|
||||
id: state.id,
|
||||
verified: state.verified,
|
||||
};
|
||||
(SessionId::from_proto(*session_id), state)
|
||||
})
|
||||
.collect();
|
||||
let breakpoint = Breakpoint::from_proto(breakpoint)?;
|
||||
Some((anchor, breakpoint))
|
||||
let bp = BreakpointWithPosition {
|
||||
position,
|
||||
bp: breakpoint,
|
||||
};
|
||||
|
||||
Some(StatefulBreakpoint { bp, session_state })
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -231,7 +282,7 @@ impl BreakpointStore {
|
||||
.payload
|
||||
.breakpoint
|
||||
.ok_or_else(|| anyhow!("Breakpoint not present in RPC payload"))?;
|
||||
let anchor = language::proto::deserialize_anchor(
|
||||
let position = language::proto::deserialize_anchor(
|
||||
breakpoint
|
||||
.position
|
||||
.clone()
|
||||
@@ -244,7 +295,10 @@ impl BreakpointStore {
|
||||
breakpoints.update(&mut cx, |this, cx| {
|
||||
this.toggle_breakpoint(
|
||||
buffer,
|
||||
(anchor, breakpoint),
|
||||
BreakpointWithPosition {
|
||||
position,
|
||||
bp: breakpoint,
|
||||
},
|
||||
BreakpointEditAction::Toggle,
|
||||
cx,
|
||||
);
|
||||
@@ -261,13 +315,76 @@ impl BreakpointStore {
|
||||
breakpoints: breakpoint_set
|
||||
.breakpoints
|
||||
.iter()
|
||||
.filter_map(|(anchor, bp)| bp.to_proto(&path, anchor))
|
||||
.filter_map(|breakpoint| {
|
||||
breakpoint.bp.bp.to_proto(
|
||||
&path,
|
||||
&breakpoint.position(),
|
||||
&breakpoint.session_state,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn update_session_breakpoint(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
_: dap::BreakpointEventReason,
|
||||
breakpoint: dap::Breakpoint,
|
||||
) {
|
||||
maybe!({
|
||||
let event_id = breakpoint.id?;
|
||||
|
||||
let state = self
|
||||
.breakpoints
|
||||
.values_mut()
|
||||
.find_map(|breakpoints_in_file| {
|
||||
breakpoints_in_file
|
||||
.breakpoints
|
||||
.iter_mut()
|
||||
.find_map(|state| {
|
||||
let state = state.session_state.get_mut(&session_id)?;
|
||||
|
||||
if state.id == event_id {
|
||||
Some(state)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})?;
|
||||
|
||||
state.verified = breakpoint.verified;
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
|
||||
pub(super) fn mark_breakpoints_verified(
|
||||
&mut self,
|
||||
session_id: SessionId,
|
||||
abs_path: &Path,
|
||||
|
||||
it: impl Iterator<Item = (BreakpointWithPosition, BreakpointSessionState)>,
|
||||
) {
|
||||
maybe!({
|
||||
let breakpoints = self.breakpoints.get_mut(abs_path)?;
|
||||
for (breakpoint, state) in it {
|
||||
if let Some(to_update) = breakpoints
|
||||
.breakpoints
|
||||
.iter_mut()
|
||||
.find(|bp| *bp.position() == breakpoint.position)
|
||||
{
|
||||
to_update
|
||||
.session_state
|
||||
.entry(session_id)
|
||||
.insert_entry(state);
|
||||
}
|
||||
}
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
|
||||
pub fn abs_path_from_buffer(buffer: &Entity<Buffer>, cx: &App) -> Option<Arc<Path>> {
|
||||
worktree::File::from_dyn(buffer.read(cx).file())
|
||||
.and_then(|file| file.worktree.read(cx).absolutize(&file.path).ok())
|
||||
@@ -277,7 +394,7 @@ impl BreakpointStore {
|
||||
pub fn toggle_breakpoint(
|
||||
&mut self,
|
||||
buffer: Entity<Buffer>,
|
||||
mut breakpoint: (text::Anchor, Breakpoint),
|
||||
mut breakpoint: BreakpointWithPosition,
|
||||
edit_action: BreakpointEditAction,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
@@ -295,54 +412,57 @@ impl BreakpointStore {
|
||||
let len_before = breakpoint_set.breakpoints.len();
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.retain(|value| &breakpoint != value);
|
||||
.retain(|value| breakpoint != value.bp);
|
||||
if len_before == breakpoint_set.breakpoints.len() {
|
||||
// We did not remove any breakpoint, hence let's toggle one.
|
||||
breakpoint_set.breakpoints.push(breakpoint.clone());
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.push(StatefulBreakpoint::new(breakpoint.clone()));
|
||||
}
|
||||
}
|
||||
BreakpointEditAction::InvertState => {
|
||||
if let Some((_, bp)) = breakpoint_set
|
||||
if let Some(bp) = breakpoint_set
|
||||
.breakpoints
|
||||
.iter_mut()
|
||||
.find(|value| breakpoint == **value)
|
||||
.find(|value| breakpoint == value.bp)
|
||||
{
|
||||
let bp = &mut bp.bp.bp;
|
||||
if bp.is_enabled() {
|
||||
bp.state = BreakpointState::Disabled;
|
||||
} else {
|
||||
bp.state = BreakpointState::Enabled;
|
||||
}
|
||||
} else {
|
||||
breakpoint.1.state = BreakpointState::Disabled;
|
||||
breakpoint_set.breakpoints.push(breakpoint.clone());
|
||||
breakpoint.bp.state = BreakpointState::Disabled;
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.push(StatefulBreakpoint::new(breakpoint.clone()));
|
||||
}
|
||||
}
|
||||
BreakpointEditAction::EditLogMessage(log_message) => {
|
||||
if !log_message.is_empty() {
|
||||
let found_bp =
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.iter_mut()
|
||||
.find_map(|(other_pos, other_bp)| {
|
||||
if breakpoint.0 == *other_pos {
|
||||
Some(other_bp)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
let found_bp = breakpoint_set.breakpoints.iter_mut().find_map(|bp| {
|
||||
if breakpoint.position == *bp.position() {
|
||||
Some(&mut bp.bp.bp)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(found_bp) = found_bp {
|
||||
found_bp.message = Some(log_message.clone());
|
||||
} else {
|
||||
breakpoint.1.message = Some(log_message.clone());
|
||||
breakpoint.bp.message = Some(log_message.clone());
|
||||
// We did not remove any breakpoint, hence let's toggle one.
|
||||
breakpoint_set.breakpoints.push(breakpoint.clone());
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.push(StatefulBreakpoint::new(breakpoint.clone()));
|
||||
}
|
||||
} else if breakpoint.1.message.is_some() {
|
||||
} else if breakpoint.bp.message.is_some() {
|
||||
if let Some(position) = breakpoint_set
|
||||
.breakpoints
|
||||
.iter()
|
||||
.find_position(|(pos, bp)| &breakpoint.0 == pos && bp == &breakpoint.1)
|
||||
.find_position(|other| breakpoint == other.bp)
|
||||
.map(|res| res.0)
|
||||
{
|
||||
breakpoint_set.breakpoints.remove(position);
|
||||
@@ -353,30 +473,28 @@ impl BreakpointStore {
|
||||
}
|
||||
BreakpointEditAction::EditHitCondition(hit_condition) => {
|
||||
if !hit_condition.is_empty() {
|
||||
let found_bp =
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.iter_mut()
|
||||
.find_map(|(other_pos, other_bp)| {
|
||||
if breakpoint.0 == *other_pos {
|
||||
Some(other_bp)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
let found_bp = breakpoint_set.breakpoints.iter_mut().find_map(|other| {
|
||||
if breakpoint.position == *other.position() {
|
||||
Some(&mut other.bp.bp)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(found_bp) = found_bp {
|
||||
found_bp.hit_condition = Some(hit_condition.clone());
|
||||
} else {
|
||||
breakpoint.1.hit_condition = Some(hit_condition.clone());
|
||||
breakpoint.bp.hit_condition = Some(hit_condition.clone());
|
||||
// We did not remove any breakpoint, hence let's toggle one.
|
||||
breakpoint_set.breakpoints.push(breakpoint.clone());
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.push(StatefulBreakpoint::new(breakpoint.clone()))
|
||||
}
|
||||
} else if breakpoint.1.hit_condition.is_some() {
|
||||
} else if breakpoint.bp.hit_condition.is_some() {
|
||||
if let Some(position) = breakpoint_set
|
||||
.breakpoints
|
||||
.iter()
|
||||
.find_position(|(pos, bp)| &breakpoint.0 == pos && bp == &breakpoint.1)
|
||||
.find_position(|bp| breakpoint == bp.bp)
|
||||
.map(|res| res.0)
|
||||
{
|
||||
breakpoint_set.breakpoints.remove(position);
|
||||
@@ -387,30 +505,28 @@ impl BreakpointStore {
|
||||
}
|
||||
BreakpointEditAction::EditCondition(condition) => {
|
||||
if !condition.is_empty() {
|
||||
let found_bp =
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.iter_mut()
|
||||
.find_map(|(other_pos, other_bp)| {
|
||||
if breakpoint.0 == *other_pos {
|
||||
Some(other_bp)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
let found_bp = breakpoint_set.breakpoints.iter_mut().find_map(|other| {
|
||||
if breakpoint.position == *other.position() {
|
||||
Some(&mut other.bp.bp)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(found_bp) = found_bp {
|
||||
found_bp.condition = Some(condition.clone());
|
||||
} else {
|
||||
breakpoint.1.condition = Some(condition.clone());
|
||||
breakpoint.bp.condition = Some(condition.clone());
|
||||
// We did not remove any breakpoint, hence let's toggle one.
|
||||
breakpoint_set.breakpoints.push(breakpoint.clone());
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.push(StatefulBreakpoint::new(breakpoint.clone()));
|
||||
}
|
||||
} else if breakpoint.1.condition.is_some() {
|
||||
} else if breakpoint.bp.condition.is_some() {
|
||||
if let Some(position) = breakpoint_set
|
||||
.breakpoints
|
||||
.iter()
|
||||
.find_position(|(pos, bp)| &breakpoint.0 == pos && bp == &breakpoint.1)
|
||||
.find_position(|bp| breakpoint == bp.bp)
|
||||
.map(|res| res.0)
|
||||
{
|
||||
breakpoint_set.breakpoints.remove(position);
|
||||
@@ -425,7 +541,11 @@ impl BreakpointStore {
|
||||
self.breakpoints.remove(&abs_path);
|
||||
}
|
||||
if let BreakpointStoreMode::Remote(remote) = &self.mode {
|
||||
if let Some(breakpoint) = breakpoint.1.to_proto(&abs_path, &breakpoint.0) {
|
||||
if let Some(breakpoint) =
|
||||
breakpoint
|
||||
.bp
|
||||
.to_proto(&abs_path, &breakpoint.position, &HashMap::default())
|
||||
{
|
||||
cx.background_spawn(remote.upstream_client.request(proto::ToggleBreakpoint {
|
||||
project_id: remote._upstream_project_id,
|
||||
path: abs_path.to_str().map(ToOwned::to_owned).unwrap(),
|
||||
@@ -441,7 +561,11 @@ impl BreakpointStore {
|
||||
breakpoint_set
|
||||
.breakpoints
|
||||
.iter()
|
||||
.filter_map(|(anchor, bp)| bp.to_proto(&abs_path, anchor))
|
||||
.filter_map(|bp| {
|
||||
bp.bp
|
||||
.bp
|
||||
.to_proto(&abs_path, bp.position(), &bp.session_state)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
@@ -485,21 +609,31 @@ impl BreakpointStore {
|
||||
range: Option<Range<text::Anchor>>,
|
||||
buffer_snapshot: &'a BufferSnapshot,
|
||||
cx: &App,
|
||||
) -> impl Iterator<Item = &'a (text::Anchor, Breakpoint)> + 'a {
|
||||
) -> impl Iterator<Item = (&'a BreakpointWithPosition, Option<BreakpointSessionState>)> + 'a
|
||||
{
|
||||
let abs_path = Self::abs_path_from_buffer(buffer, cx);
|
||||
let active_session_id = self
|
||||
.active_stack_frame
|
||||
.as_ref()
|
||||
.map(|frame| frame.session_id);
|
||||
abs_path
|
||||
.and_then(|path| self.breakpoints.get(&path))
|
||||
.into_iter()
|
||||
.flat_map(move |file_breakpoints| {
|
||||
file_breakpoints.breakpoints.iter().filter({
|
||||
file_breakpoints.breakpoints.iter().filter_map({
|
||||
let range = range.clone();
|
||||
move |(position, _)| {
|
||||
move |bp| {
|
||||
if let Some(range) = &range {
|
||||
position.cmp(&range.start, buffer_snapshot).is_ge()
|
||||
&& position.cmp(&range.end, buffer_snapshot).is_le()
|
||||
} else {
|
||||
true
|
||||
if bp.position().cmp(&range.start, buffer_snapshot).is_lt()
|
||||
|| bp.position().cmp(&range.end, buffer_snapshot).is_gt()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
}
|
||||
let session_state = active_session_id
|
||||
.and_then(|id| bp.session_state.get(&id))
|
||||
.copied();
|
||||
Some((&bp.bp, session_state))
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -549,34 +683,46 @@ impl BreakpointStore {
|
||||
path: &Path,
|
||||
row: u32,
|
||||
cx: &App,
|
||||
) -> Option<(Entity<Buffer>, (text::Anchor, Breakpoint))> {
|
||||
) -> Option<(Entity<Buffer>, BreakpointWithPosition)> {
|
||||
self.breakpoints.get(path).and_then(|breakpoints| {
|
||||
let snapshot = breakpoints.buffer.read(cx).text_snapshot();
|
||||
|
||||
breakpoints
|
||||
.breakpoints
|
||||
.iter()
|
||||
.find(|(anchor, _)| anchor.summary::<Point>(&snapshot).row == row)
|
||||
.map(|breakpoint| (breakpoints.buffer.clone(), breakpoint.clone()))
|
||||
.find(|bp| bp.position().summary::<Point>(&snapshot).row == row)
|
||||
.map(|breakpoint| (breakpoints.buffer.clone(), breakpoint.bp.clone()))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn breakpoints_from_path(&self, path: &Arc<Path>, cx: &App) -> Vec<SourceBreakpoint> {
|
||||
pub fn breakpoints_from_path(&self, path: &Arc<Path>) -> Vec<BreakpointWithPosition> {
|
||||
self.breakpoints
|
||||
.get(path)
|
||||
.map(|bp| bp.breakpoints.iter().map(|bp| bp.bp.clone()).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn source_breakpoints_from_path(
|
||||
&self,
|
||||
path: &Arc<Path>,
|
||||
cx: &App,
|
||||
) -> Vec<SourceBreakpoint> {
|
||||
self.breakpoints
|
||||
.get(path)
|
||||
.map(|bp| {
|
||||
let snapshot = bp.buffer.read(cx).snapshot();
|
||||
bp.breakpoints
|
||||
.iter()
|
||||
.map(|(position, breakpoint)| {
|
||||
let position = snapshot.summary_for_anchor::<PointUtf16>(position).row;
|
||||
.map(|bp| {
|
||||
let position = snapshot.summary_for_anchor::<PointUtf16>(bp.position()).row;
|
||||
let bp = &bp.bp;
|
||||
SourceBreakpoint {
|
||||
row: position,
|
||||
path: path.clone(),
|
||||
state: breakpoint.state,
|
||||
message: breakpoint.message.clone(),
|
||||
condition: breakpoint.condition.clone(),
|
||||
hit_condition: breakpoint.hit_condition.clone(),
|
||||
state: bp.bp.state,
|
||||
message: bp.bp.message.clone(),
|
||||
condition: bp.bp.condition.clone(),
|
||||
hit_condition: bp.bp.hit_condition.clone(),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
@@ -584,7 +730,18 @@ impl BreakpointStore {
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn all_breakpoints(&self, cx: &App) -> BTreeMap<Arc<Path>, Vec<SourceBreakpoint>> {
|
||||
pub fn all_breakpoints(&self) -> BTreeMap<Arc<Path>, Vec<BreakpointWithPosition>> {
|
||||
self.breakpoints
|
||||
.iter()
|
||||
.map(|(path, bp)| {
|
||||
(
|
||||
path.clone(),
|
||||
bp.breakpoints.iter().map(|bp| bp.bp.clone()).collect(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
pub fn all_source_breakpoints(&self, cx: &App) -> BTreeMap<Arc<Path>, Vec<SourceBreakpoint>> {
|
||||
self.breakpoints
|
||||
.iter()
|
||||
.map(|(path, bp)| {
|
||||
@@ -593,15 +750,18 @@ impl BreakpointStore {
|
||||
path.clone(),
|
||||
bp.breakpoints
|
||||
.iter()
|
||||
.map(|(position, breakpoint)| {
|
||||
let position = snapshot.summary_for_anchor::<PointUtf16>(position).row;
|
||||
.map(|breakpoint| {
|
||||
let position = snapshot
|
||||
.summary_for_anchor::<PointUtf16>(&breakpoint.position())
|
||||
.row;
|
||||
let breakpoint = &breakpoint.bp;
|
||||
SourceBreakpoint {
|
||||
row: position,
|
||||
path: path.clone(),
|
||||
message: breakpoint.message.clone(),
|
||||
state: breakpoint.state,
|
||||
hit_condition: breakpoint.hit_condition.clone(),
|
||||
condition: breakpoint.condition.clone(),
|
||||
message: breakpoint.bp.message.clone(),
|
||||
state: breakpoint.bp.state,
|
||||
hit_condition: breakpoint.bp.hit_condition.clone(),
|
||||
condition: breakpoint.bp.condition.clone(),
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
@@ -656,15 +816,17 @@ impl BreakpointStore {
|
||||
continue;
|
||||
}
|
||||
let position = snapshot.anchor_after(point);
|
||||
breakpoints_for_file.breakpoints.push((
|
||||
position,
|
||||
Breakpoint {
|
||||
message: bp.message,
|
||||
state: bp.state,
|
||||
condition: bp.condition,
|
||||
hit_condition: bp.hit_condition,
|
||||
},
|
||||
))
|
||||
breakpoints_for_file
|
||||
.breakpoints
|
||||
.push(StatefulBreakpoint::new(BreakpointWithPosition {
|
||||
position,
|
||||
bp: Breakpoint {
|
||||
message: bp.message,
|
||||
state: bp.state,
|
||||
condition: bp.condition,
|
||||
hit_condition: bp.hit_condition,
|
||||
},
|
||||
}))
|
||||
}
|
||||
new_breakpoints.insert(path, breakpoints_for_file);
|
||||
}
|
||||
@@ -755,7 +917,7 @@ impl BreakpointState {
|
||||
pub struct Breakpoint {
|
||||
pub message: Option<BreakpointMessage>,
|
||||
/// How many times do we hit the breakpoint until we actually stop at it e.g. (2 = 2 times of the breakpoint action)
|
||||
pub hit_condition: Option<BreakpointMessage>,
|
||||
pub hit_condition: Option<Arc<str>>,
|
||||
pub condition: Option<BreakpointMessage>,
|
||||
pub state: BreakpointState,
|
||||
}
|
||||
@@ -788,7 +950,12 @@ impl Breakpoint {
|
||||
}
|
||||
}
|
||||
|
||||
fn to_proto(&self, _path: &Path, position: &text::Anchor) -> Option<client::proto::Breakpoint> {
|
||||
fn to_proto(
|
||||
&self,
|
||||
_path: &Path,
|
||||
position: &text::Anchor,
|
||||
session_states: &HashMap<SessionId, BreakpointSessionState>,
|
||||
) -> Option<client::proto::Breakpoint> {
|
||||
Some(client::proto::Breakpoint {
|
||||
position: Some(serialize_text_anchor(position)),
|
||||
state: match self.state {
|
||||
@@ -801,6 +968,18 @@ impl Breakpoint {
|
||||
.hit_condition
|
||||
.as_ref()
|
||||
.map(|s| String::from(s.as_ref())),
|
||||
session_state: session_states
|
||||
.iter()
|
||||
.map(|(session_id, state)| {
|
||||
(
|
||||
session_id.to_proto(),
|
||||
proto::BreakpointSessionState {
|
||||
id: state.id,
|
||||
verified: state.verified,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -10,13 +10,15 @@ use crate::{
|
||||
terminals::{SshCommand, wrap_for_ssh},
|
||||
worktree_store::WorktreeStore,
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use collections::HashMap;
|
||||
use dap::{
|
||||
Capabilities, CompletionItem, CompletionsArguments, DapRegistry, DebugRequest,
|
||||
EvaluateArguments, EvaluateArgumentsContext, EvaluateResponse, Source, StackFrameId,
|
||||
adapters::{DebugAdapterBinary, DebugAdapterName, DebugTaskDefinition, TcpArguments},
|
||||
adapters::{
|
||||
DapDelegate, DebugAdapterBinary, DebugAdapterName, DebugTaskDefinition, TcpArguments,
|
||||
},
|
||||
client::SessionId,
|
||||
inline_value::VariableLookupKind,
|
||||
messages::Message,
|
||||
@@ -488,14 +490,14 @@ impl DapStore {
|
||||
worktree: &Entity<Worktree>,
|
||||
console: UnboundedSender<String>,
|
||||
cx: &mut App,
|
||||
) -> DapAdapterDelegate {
|
||||
) -> Arc<dyn DapDelegate> {
|
||||
let Some(local_store) = self.as_local() else {
|
||||
unimplemented!("Starting session on remote side");
|
||||
};
|
||||
|
||||
DapAdapterDelegate::new(
|
||||
Arc::new(DapAdapterDelegate::new(
|
||||
local_store.fs.clone(),
|
||||
worktree.read(cx).id(),
|
||||
worktree.read(cx).snapshot(),
|
||||
console,
|
||||
local_store.node_runtime.clone(),
|
||||
local_store.http_client.clone(),
|
||||
@@ -503,7 +505,7 @@ impl DapStore {
|
||||
local_store.environment.update(cx, |env, cx| {
|
||||
env.get_worktree_environment(worktree.clone(), cx)
|
||||
}),
|
||||
)
|
||||
))
|
||||
}
|
||||
|
||||
pub fn evaluate(
|
||||
@@ -811,7 +813,7 @@ impl DapStore {
|
||||
pub struct DapAdapterDelegate {
|
||||
fs: Arc<dyn Fs>,
|
||||
console: mpsc::UnboundedSender<String>,
|
||||
worktree_id: WorktreeId,
|
||||
worktree: worktree::Snapshot,
|
||||
node_runtime: NodeRuntime,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
toolchain_store: Arc<dyn LanguageToolchainStore>,
|
||||
@@ -821,7 +823,7 @@ pub struct DapAdapterDelegate {
|
||||
impl DapAdapterDelegate {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
worktree_id: WorktreeId,
|
||||
worktree: worktree::Snapshot,
|
||||
status: mpsc::UnboundedSender<String>,
|
||||
node_runtime: NodeRuntime,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
@@ -831,7 +833,7 @@ impl DapAdapterDelegate {
|
||||
Self {
|
||||
fs,
|
||||
console: status,
|
||||
worktree_id,
|
||||
worktree,
|
||||
http_client,
|
||||
node_runtime,
|
||||
toolchain_store,
|
||||
@@ -840,12 +842,15 @@ impl DapAdapterDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
#[async_trait]
|
||||
impl dap::adapters::DapDelegate for DapAdapterDelegate {
|
||||
fn worktree_id(&self) -> WorktreeId {
|
||||
self.worktree_id
|
||||
self.worktree.id()
|
||||
}
|
||||
|
||||
fn worktree_root_path(&self) -> &Path {
|
||||
&self.worktree.abs_path()
|
||||
}
|
||||
fn http_client(&self) -> Arc<dyn HttpClient> {
|
||||
self.http_client.clone()
|
||||
}
|
||||
@@ -862,7 +867,7 @@ impl dap::adapters::DapDelegate for DapAdapterDelegate {
|
||||
self.console.unbounded_send(msg).ok();
|
||||
}
|
||||
|
||||
fn which(&self, command: &OsStr) -> Option<PathBuf> {
|
||||
async fn which(&self, command: &OsStr) -> Option<PathBuf> {
|
||||
which::which(command).ok()
|
||||
}
|
||||
|
||||
@@ -874,4 +879,16 @@ impl dap::adapters::DapDelegate for DapAdapterDelegate {
|
||||
fn toolchain_store(&self) -> Arc<dyn LanguageToolchainStore> {
|
||||
self.toolchain_store.clone()
|
||||
}
|
||||
async fn read_text_file(&self, path: PathBuf) -> Result<String> {
|
||||
let entry = self
|
||||
.worktree
|
||||
.entry_for_path(&path)
|
||||
.with_context(|| format!("no worktree entry for path {path:?}"))?;
|
||||
let abs_path = self
|
||||
.worktree
|
||||
.absolutize(&entry.path)
|
||||
.with_context(|| format!("cannot absolutize path {path:?}"))?;
|
||||
|
||||
self.fs.load(&abs_path).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use crate::debugger::breakpoint_store::BreakpointSessionState;
|
||||
|
||||
use super::breakpoint_store::{
|
||||
BreakpointStore, BreakpointStoreEvent, BreakpointUpdatedReason, SourceBreakpoint,
|
||||
};
|
||||
@@ -218,25 +220,55 @@ impl LocalMode {
|
||||
breakpoint_store: &Entity<BreakpointStore>,
|
||||
cx: &mut App,
|
||||
) -> Task<()> {
|
||||
let breakpoints = breakpoint_store
|
||||
.read_with(cx, |store, cx| store.breakpoints_from_path(&abs_path, cx))
|
||||
let breakpoints =
|
||||
breakpoint_store
|
||||
.read_with(cx, |store, cx| {
|
||||
store.source_breakpoints_from_path(&abs_path, cx)
|
||||
})
|
||||
.into_iter()
|
||||
.filter(|bp| bp.state.is_enabled())
|
||||
.chain(self.tmp_breakpoint.iter().filter_map(|breakpoint| {
|
||||
breakpoint.path.eq(&abs_path).then(|| breakpoint.clone())
|
||||
}))
|
||||
.map(Into::into)
|
||||
.collect();
|
||||
|
||||
let raw_breakpoints = breakpoint_store
|
||||
.read(cx)
|
||||
.breakpoints_from_path(&abs_path)
|
||||
.into_iter()
|
||||
.filter(|bp| bp.state.is_enabled())
|
||||
.chain(self.tmp_breakpoint.clone())
|
||||
.map(Into::into)
|
||||
.collect();
|
||||
.filter(|bp| bp.bp.state.is_enabled())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let task = self.request(dap_command::SetBreakpoints {
|
||||
source: client_source(&abs_path),
|
||||
source_modified: Some(matches!(reason, BreakpointUpdatedReason::FileSaved)),
|
||||
breakpoints,
|
||||
});
|
||||
|
||||
cx.background_spawn(async move {
|
||||
match task.await {
|
||||
Ok(_) => {}
|
||||
Err(err) => log::warn!("Set breakpoints request failed for path: {}", err),
|
||||
let session_id = self.client.id();
|
||||
let breakpoint_store = breakpoint_store.downgrade();
|
||||
cx.spawn(async move |cx| match cx.background_spawn(task).await {
|
||||
Ok(breakpoints) => {
|
||||
let breakpoints =
|
||||
breakpoints
|
||||
.into_iter()
|
||||
.zip(raw_breakpoints)
|
||||
.filter_map(|(dap_bp, zed_bp)| {
|
||||
Some((
|
||||
zed_bp,
|
||||
BreakpointSessionState {
|
||||
id: dap_bp.id?,
|
||||
verified: dap_bp.verified,
|
||||
},
|
||||
))
|
||||
});
|
||||
breakpoint_store
|
||||
.update(cx, |this, _| {
|
||||
this.mark_breakpoints_verified(session_id, &abs_path, breakpoints);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
Err(err) => log::warn!("Set breakpoints request failed for path: {}", err),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -271,8 +303,11 @@ impl LocalMode {
|
||||
cx: &App,
|
||||
) -> Task<HashMap<Arc<Path>, anyhow::Error>> {
|
||||
let mut breakpoint_tasks = Vec::new();
|
||||
let breakpoints = breakpoint_store.read_with(cx, |store, cx| store.all_breakpoints(cx));
|
||||
|
||||
let breakpoints =
|
||||
breakpoint_store.read_with(cx, |store, cx| store.all_source_breakpoints(cx));
|
||||
let mut raw_breakpoints = breakpoint_store.read_with(cx, |this, _| this.all_breakpoints());
|
||||
debug_assert_eq!(raw_breakpoints.len(), breakpoints.len());
|
||||
let session_id = self.client.id();
|
||||
for (path, breakpoints) in breakpoints {
|
||||
let breakpoints = if ignore_breakpoints {
|
||||
vec![]
|
||||
@@ -284,14 +319,46 @@ impl LocalMode {
|
||||
.collect()
|
||||
};
|
||||
|
||||
breakpoint_tasks.push(
|
||||
self.request(dap_command::SetBreakpoints {
|
||||
let raw_breakpoints = raw_breakpoints
|
||||
.remove(&path)
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter(|bp| bp.bp.state.is_enabled());
|
||||
let error_path = path.clone();
|
||||
let send_request = self
|
||||
.request(dap_command::SetBreakpoints {
|
||||
source: client_source(&path),
|
||||
source_modified: Some(false),
|
||||
breakpoints,
|
||||
})
|
||||
.map(|result| result.map_err(|e| (path, e))),
|
||||
);
|
||||
.map(|result| result.map_err(move |e| (error_path, e)));
|
||||
|
||||
let task = cx.spawn({
|
||||
let breakpoint_store = breakpoint_store.downgrade();
|
||||
async move |cx| {
|
||||
let breakpoints = cx.background_spawn(send_request).await?;
|
||||
|
||||
let breakpoints = breakpoints.into_iter().zip(raw_breakpoints).filter_map(
|
||||
|(dap_bp, zed_bp)| {
|
||||
Some((
|
||||
zed_bp,
|
||||
BreakpointSessionState {
|
||||
id: dap_bp.id?,
|
||||
verified: dap_bp.verified,
|
||||
},
|
||||
))
|
||||
},
|
||||
);
|
||||
breakpoint_store
|
||||
.update(cx, |this, _| {
|
||||
this.mark_breakpoints_verified(session_id, &path, breakpoints);
|
||||
})
|
||||
.ok();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
breakpoint_tasks.push(task);
|
||||
}
|
||||
|
||||
cx.background_spawn(async move {
|
||||
@@ -1204,7 +1271,9 @@ impl Session {
|
||||
self.output_token.0 += 1;
|
||||
cx.notify();
|
||||
}
|
||||
Events::Breakpoint(_) => {}
|
||||
Events::Breakpoint(event) => self.breakpoint_store.update(cx, |store, _| {
|
||||
store.update_session_breakpoint(self.session_id(), event.reason, event.breakpoint);
|
||||
}),
|
||||
Events::Module(event) => {
|
||||
match event.reason {
|
||||
dap::ModuleEventReason::New => {
|
||||
|
||||
@@ -12,10 +12,10 @@ pub use image::ImageFormat;
|
||||
use image::{ExtendedColorType, GenericImageView, ImageReader};
|
||||
use language::{DiskState, File};
|
||||
use rpc::{AnyProtoClient, ErrorExt as _};
|
||||
use std::ffi::OsStr;
|
||||
use std::num::NonZeroU64;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::{ffi::OsStr, path::PathBuf};
|
||||
use util::ResultExt;
|
||||
use worktree::{LoadedBinaryFile, PathChange, Worktree};
|
||||
|
||||
@@ -96,7 +96,7 @@ impl ImageColorInfo {
|
||||
|
||||
pub struct ImageItem {
|
||||
pub id: ImageId,
|
||||
pub file: Arc<dyn File>,
|
||||
pub file: Arc<worktree::File>,
|
||||
pub image: Arc<gpui::Image>,
|
||||
reload_task: Option<Task<()>>,
|
||||
pub image_metadata: Option<ImageMetadata>,
|
||||
@@ -109,22 +109,11 @@ impl ImageItem {
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ImageMetadata> {
|
||||
let (fs, image_path) = cx.update(|cx| {
|
||||
let project_path = image.read(cx).project_path(cx);
|
||||
|
||||
let worktree = project
|
||||
.read(cx)
|
||||
.worktree_for_id(project_path.worktree_id, cx)
|
||||
.ok_or_else(|| anyhow!("worktree not found"))?;
|
||||
let worktree_root = worktree.read(cx).abs_path();
|
||||
let image_path = image.read(cx).path();
|
||||
let image_path = if image_path.is_absolute() {
|
||||
image_path.to_path_buf()
|
||||
} else {
|
||||
worktree_root.join(image_path)
|
||||
};
|
||||
|
||||
let fs = project.read(cx).fs().clone();
|
||||
|
||||
let image_path = image
|
||||
.read(cx)
|
||||
.abs_path(cx)
|
||||
.context("absolutizing image file path")?;
|
||||
anyhow::Ok((fs, image_path))
|
||||
})??;
|
||||
|
||||
@@ -157,14 +146,14 @@ impl ImageItem {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &Arc<Path> {
|
||||
self.file.path()
|
||||
pub fn abs_path(&self, cx: &App) -> Option<PathBuf> {
|
||||
Some(self.file.as_local()?.abs_path(cx))
|
||||
}
|
||||
|
||||
fn file_updated(&mut self, new_file: Arc<dyn File>, cx: &mut Context<Self>) {
|
||||
fn file_updated(&mut self, new_file: Arc<worktree::File>, cx: &mut Context<Self>) {
|
||||
let mut file_changed = false;
|
||||
|
||||
let old_file = self.file.as_ref();
|
||||
let old_file = &self.file;
|
||||
if new_file.path() != old_file.path() {
|
||||
file_changed = true;
|
||||
}
|
||||
@@ -251,7 +240,7 @@ impl ProjectItem for ImageItem {
|
||||
}
|
||||
|
||||
fn entry_id(&self, _: &App) -> Option<ProjectEntryId> {
|
||||
worktree::File::from_dyn(Some(&self.file))?.entry_id
|
||||
self.file.entry_id
|
||||
}
|
||||
|
||||
fn project_path(&self, cx: &App) -> Option<ProjectPath> {
|
||||
@@ -387,6 +376,12 @@ impl ImageStore {
|
||||
entry.insert(rx.clone());
|
||||
|
||||
let project_path = project_path.clone();
|
||||
// TODO kb this is causing another error, and we also pass a worktree nearby — seems ok to pass "" here?
|
||||
// let image_path = worktree
|
||||
// .read(cx)
|
||||
// .absolutize(&project_path.path)
|
||||
// .map(Arc::from)
|
||||
// .unwrap_or_else(|_| project_path.path.clone());
|
||||
let load_image = self
|
||||
.state
|
||||
.open_image(project_path.path.clone(), worktree, cx);
|
||||
@@ -604,9 +599,7 @@ impl LocalImageStore {
|
||||
};
|
||||
|
||||
image.update(cx, |image, cx| {
|
||||
let Some(old_file) = worktree::File::from_dyn(Some(&image.file)) else {
|
||||
return;
|
||||
};
|
||||
let old_file = &image.file;
|
||||
if old_file.worktree != *worktree {
|
||||
return;
|
||||
}
|
||||
@@ -639,7 +632,7 @@ impl LocalImageStore {
|
||||
}
|
||||
};
|
||||
|
||||
if new_file == *old_file {
|
||||
if new_file == **old_file {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -672,9 +665,10 @@ impl LocalImageStore {
|
||||
}
|
||||
|
||||
fn image_changed_file(&mut self, image: Entity<ImageItem>, cx: &mut App) -> Option<()> {
|
||||
let file = worktree::File::from_dyn(Some(&image.read(cx).file))?;
|
||||
let image = image.read(cx);
|
||||
let file = &image.file;
|
||||
|
||||
let image_id = image.read(cx).id;
|
||||
let image_id = image.id;
|
||||
if let Some(entry_id) = file.entry_id {
|
||||
match self.local_image_ids_by_entry_id.get(&entry_id) {
|
||||
Some(_) => {
|
||||
|
||||
@@ -47,6 +47,7 @@ use dap::{DapRegistry, client::DebugAdapterClient};
|
||||
|
||||
use collections::{BTreeSet, HashMap, HashSet};
|
||||
use debounced_delay::DebouncedDelay;
|
||||
pub use debugger::breakpoint_store::BreakpointWithPosition;
|
||||
use debugger::{
|
||||
breakpoint_store::{ActiveStackFrame, BreakpointStore},
|
||||
dap_store::{DapStore, DapStoreEvent},
|
||||
|
||||
@@ -3,7 +3,6 @@ fn main() {
|
||||
build
|
||||
.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]")
|
||||
.type_attribute("ProjectPath", "#[derive(Hash, Eq)]")
|
||||
.type_attribute("Breakpoint", "#[derive(Hash, Eq)]")
|
||||
.type_attribute("Anchor", "#[derive(Hash, Eq)]")
|
||||
.compile_protos(&["proto/zed.proto"], &["proto"])
|
||||
.unwrap();
|
||||
|
||||
@@ -16,6 +16,12 @@ message Breakpoint {
|
||||
optional string message = 4;
|
||||
optional string condition = 5;
|
||||
optional string hit_condition = 6;
|
||||
map<uint64, BreakpointSessionState> session_state = 7;
|
||||
}
|
||||
|
||||
message BreakpointSessionState {
|
||||
uint64 id = 1;
|
||||
bool verified = 2;
|
||||
}
|
||||
|
||||
message BreakpointsForFile {
|
||||
@@ -30,63 +36,6 @@ message ToggleBreakpoint {
|
||||
Breakpoint breakpoint = 3;
|
||||
}
|
||||
|
||||
enum DebuggerThreadItem {
|
||||
Console = 0;
|
||||
LoadedSource = 1;
|
||||
Modules = 2;
|
||||
Variables = 3;
|
||||
}
|
||||
|
||||
message DebuggerSetVariableState {
|
||||
string name = 1;
|
||||
DapScope scope = 2;
|
||||
string value = 3;
|
||||
uint64 stack_frame_id = 4;
|
||||
optional string evaluate_name = 5;
|
||||
uint64 parent_variables_reference = 6;
|
||||
}
|
||||
|
||||
message VariableListOpenEntry {
|
||||
oneof entry {
|
||||
DebuggerOpenEntryScope scope = 1;
|
||||
DebuggerOpenEntryVariable variable = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message DebuggerOpenEntryScope {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message DebuggerOpenEntryVariable {
|
||||
string scope_name = 1;
|
||||
string name = 2;
|
||||
uint64 depth = 3;
|
||||
}
|
||||
|
||||
message VariableListEntrySetState {
|
||||
uint64 depth = 1;
|
||||
DebuggerSetVariableState state = 2;
|
||||
}
|
||||
|
||||
message VariableListEntryVariable {
|
||||
uint64 depth = 1;
|
||||
DapScope scope = 2;
|
||||
DapVariable variable = 3;
|
||||
bool has_children = 4;
|
||||
uint64 container_reference = 5;
|
||||
}
|
||||
|
||||
message DebuggerScopeVariableIndex {
|
||||
repeated uint64 fetched_ids = 1;
|
||||
repeated DebuggerVariableContainer variables = 2;
|
||||
}
|
||||
|
||||
message DebuggerVariableContainer {
|
||||
uint64 container_reference = 1;
|
||||
DapVariable variable = 2;
|
||||
uint64 depth = 3;
|
||||
}
|
||||
|
||||
enum DapThreadStatus {
|
||||
Running = 0;
|
||||
Stopped = 1;
|
||||
@@ -94,18 +43,6 @@ enum DapThreadStatus {
|
||||
Ended = 3;
|
||||
}
|
||||
|
||||
message VariableListScopes {
|
||||
uint64 stack_frame_id = 1;
|
||||
repeated DapScope scopes = 2;
|
||||
}
|
||||
|
||||
message VariableListVariables {
|
||||
uint64 stack_frame_id = 1;
|
||||
uint64 scope_id = 2;
|
||||
DebuggerScopeVariableIndex variables = 3;
|
||||
}
|
||||
|
||||
|
||||
enum VariablesArgumentsFilter {
|
||||
Indexed = 0;
|
||||
Named = 1;
|
||||
|
||||
@@ -30,6 +30,7 @@ chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
dap_adapters.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
env_logger.workspace = true
|
||||
extension.workspace = true
|
||||
extension_host.workspace = true
|
||||
|
||||
@@ -76,6 +76,7 @@ impl HeadlessProject {
|
||||
}: HeadlessAppState,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
debug_adapter_extension::init(proxy.clone(), cx);
|
||||
language_extension::init(proxy.clone(), languages.clone());
|
||||
languages::init(languages.clone(), node_runtime.clone(), cx);
|
||||
|
||||
|
||||
@@ -1055,10 +1055,17 @@ impl ProjectSearchView {
|
||||
|
||||
let is_dirty = self.is_dirty(cx);
|
||||
|
||||
let should_confirm_save = !will_autosave && is_dirty;
|
||||
let skip_save_on_close = self
|
||||
.workspace
|
||||
.read_with(cx, |workspace, cx| {
|
||||
workspace::Pane::skip_save_on_close(&self.results_editor, workspace, cx)
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
let should_prompt_to_save = !skip_save_on_close && !will_autosave && is_dirty;
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let should_search = if should_confirm_save {
|
||||
let should_search = if should_prompt_to_save {
|
||||
let options = &["Save", "Don't Save", "Cancel"];
|
||||
let result_channel = this.update_in(cx, |_, window, cx| {
|
||||
window.prompt(
|
||||
|
||||
@@ -3,52 +3,48 @@ use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
|
||||
#[derive(Copy, Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
#[serde(default)]
|
||||
#[derive(Copy, Clone, Deserialize, Debug)]
|
||||
pub struct TitleBarSettings {
|
||||
/// Whether to show the branch icon beside branch switcher in the title bar.
|
||||
///
|
||||
/// Default: false
|
||||
pub show_branch_icon: bool,
|
||||
/// Whether to show onboarding banners in the title bar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_onboarding_banner: bool,
|
||||
/// Whether to show user avatar in the title bar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_user_picture: bool,
|
||||
/// Whether to show the branch name button in the titlebar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_branch_name: bool,
|
||||
/// Whether to show the project host and name in the titlebar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_project_items: bool,
|
||||
/// Whether to show the sign in button in the title bar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_sign_in: bool,
|
||||
}
|
||||
|
||||
impl Default for TitleBarSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
show_branch_icon: false,
|
||||
show_onboarding_banner: true,
|
||||
show_user_picture: true,
|
||||
show_branch_name: true,
|
||||
show_project_items: true,
|
||||
show_sign_in: true,
|
||||
}
|
||||
}
|
||||
#[derive(Copy, Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct TitleBarSettingsContent {
|
||||
/// Whether to show the branch icon beside branch switcher in the title bar.
|
||||
///
|
||||
/// Default: false
|
||||
pub show_branch_icon: Option<bool>,
|
||||
/// Whether to show onboarding banners in the title bar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_onboarding_banner: Option<bool>,
|
||||
/// Whether to show user avatar in the title bar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_user_picture: Option<bool>,
|
||||
/// Whether to show the branch name button in the titlebar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_branch_name: Option<bool>,
|
||||
/// Whether to show the project host and name in the titlebar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_project_items: Option<bool>,
|
||||
/// Whether to show the sign in button in the title bar.
|
||||
///
|
||||
/// Default: true
|
||||
pub show_sign_in: Option<bool>,
|
||||
}
|
||||
|
||||
impl Settings for TitleBarSettings {
|
||||
const KEY: Option<&'static str> = Some("title_bar");
|
||||
|
||||
type FileContent = Self;
|
||||
type FileContent = TitleBarSettingsContent;
|
||||
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut gpui::App) -> anyhow::Result<Self>
|
||||
where
|
||||
|
||||
@@ -13,6 +13,7 @@ pub struct ProgressBar {
|
||||
value: f32,
|
||||
max_value: f32,
|
||||
bg_color: Hsla,
|
||||
over_color: Hsla,
|
||||
fg_color: Hsla,
|
||||
}
|
||||
|
||||
@@ -23,6 +24,7 @@ impl ProgressBar {
|
||||
value,
|
||||
max_value,
|
||||
bg_color: cx.theme().colors().background,
|
||||
over_color: cx.theme().status().error,
|
||||
fg_color: cx.theme().status().info,
|
||||
}
|
||||
}
|
||||
@@ -50,6 +52,12 @@ impl ProgressBar {
|
||||
self.fg_color = color;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the over limit color of the progress bar.
|
||||
pub fn over_color(mut self, color: Hsla) -> Self {
|
||||
self.over_color = color;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for ProgressBar {
|
||||
@@ -74,7 +82,8 @@ impl RenderOnce for ProgressBar {
|
||||
div()
|
||||
.h_full()
|
||||
.rounded_full()
|
||||
.bg(self.fg_color)
|
||||
.when(self.value > self.max_value, |div| div.bg(self.over_color))
|
||||
.when(self.value <= self.max_value, |div| div.bg(self.fg_color))
|
||||
.w(relative(fill_width)),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1449,10 +1449,7 @@ impl Pane {
|
||||
}
|
||||
});
|
||||
if dirty_project_item_ids.is_empty() {
|
||||
if item.is_singleton(cx) && item.is_dirty(cx) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return !(item.is_singleton(cx) && item.is_dirty(cx));
|
||||
}
|
||||
|
||||
for open_item in workspace.items(cx) {
|
||||
@@ -1465,11 +1462,7 @@ impl Pane {
|
||||
let other_project_item_ids = open_item.project_item_model_ids(cx);
|
||||
dirty_project_item_ids.retain(|id| !other_project_item_ids.contains(id));
|
||||
}
|
||||
if dirty_project_item_ids.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
return dirty_project_item_ids.is_empty();
|
||||
}
|
||||
|
||||
pub(super) fn file_names_for_prompt(
|
||||
|
||||
@@ -4999,7 +4999,10 @@ impl Workspace {
|
||||
|
||||
if let Some(location) = self.serialize_workspace_location(cx) {
|
||||
let breakpoints = self.project.update(cx, |project, cx| {
|
||||
project.breakpoint_store().read(cx).all_breakpoints(cx)
|
||||
project
|
||||
.breakpoint_store()
|
||||
.read(cx)
|
||||
.all_source_breakpoints(cx)
|
||||
});
|
||||
|
||||
let center_group = build_serialized_pane_group(&self.center.root, window, cx);
|
||||
|
||||
@@ -107,6 +107,15 @@ pub struct LoadedBinaryFile {
|
||||
pub content: Vec<u8>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for LoadedBinaryFile {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("LoadedBinaryFile")
|
||||
.field("file", &self.file)
|
||||
.field("content_bytes", &self.content.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LocalWorktree {
|
||||
snapshot: LocalSnapshot,
|
||||
scan_requests_tx: channel::Sender<ScanRequest>,
|
||||
@@ -3293,7 +3302,7 @@ impl fmt::Debug for Snapshot {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct File {
|
||||
pub worktree: Entity<Worktree>,
|
||||
pub path: Arc<Path>,
|
||||
|
||||
@@ -45,6 +45,7 @@ dap_adapters.workspace = true
|
||||
debugger_ui.workspace = true
|
||||
debugger_tools.workspace = true
|
||||
db.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
diagnostics.workspace = true
|
||||
editor.workspace = true
|
||||
env_logger.workspace = true
|
||||
|
||||
@@ -419,6 +419,7 @@ fn main() {
|
||||
.detach();
|
||||
let node_runtime = NodeRuntime::new(client.http_client(), Some(shell_env_loaded_rx), rx);
|
||||
|
||||
debug_adapter_extension::init(extension_host_proxy.clone(), cx);
|
||||
language::init(cx);
|
||||
language_extension::init(extension_host_proxy.clone(), languages.clone());
|
||||
languages::init(languages.clone(), node_runtime.clone(), cx);
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use crate::{ZED_PREDICT_DATA_COLLECTION_CHOICE, onboarding_event};
|
||||
use anyhow::Context as _;
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use client::{Client, UserStore};
|
||||
use db::kvp::KEY_VALUE_STORE;
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
@@ -384,47 +384,29 @@ impl Render for ZedPredictModal {
|
||||
} else {
|
||||
(IconName::ChevronDown, IconName::ChevronUp)
|
||||
};
|
||||
let plan = plan.unwrap_or(proto::Plan::Free);
|
||||
|
||||
base.child(Label::new(copy).color(Color::Muted))
|
||||
.child(h_flex().map(|parent| {
|
||||
if let Some(plan) = plan {
|
||||
parent.child(
|
||||
Checkbox::new("plan", ToggleState::Selected)
|
||||
.fill()
|
||||
.disabled(true)
|
||||
.label(format!(
|
||||
"You get {} edit predictions through your {}.",
|
||||
if plan == proto::Plan::Free {
|
||||
"2,000"
|
||||
} else {
|
||||
"unlimited"
|
||||
},
|
||||
match plan {
|
||||
proto::Plan::Free => "Zed Free plan",
|
||||
proto::Plan::ZedPro => "Zed Pro plan",
|
||||
proto::Plan::ZedProTrial => "Zed Pro trial",
|
||||
}
|
||||
)),
|
||||
)
|
||||
} else {
|
||||
parent
|
||||
.child(
|
||||
Checkbox::new("plan-required", ToggleState::Unselected)
|
||||
.fill()
|
||||
.disabled(true)
|
||||
.label("To get started with edit prediction"),
|
||||
)
|
||||
.child(
|
||||
Button::new("subscribe", "choose a plan")
|
||||
.icon(IconName::ArrowUpRight)
|
||||
.icon_size(IconSize::Indicator)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(|_event, _window, cx| {
|
||||
cx.open_url(&zed_urls::account_url(cx));
|
||||
}),
|
||||
)
|
||||
}
|
||||
}))
|
||||
.child(
|
||||
h_flex().child(
|
||||
Checkbox::new("plan", ToggleState::Selected)
|
||||
.fill()
|
||||
.disabled(true)
|
||||
.label(format!(
|
||||
"You get {} edit predictions through your {}.",
|
||||
if plan == proto::Plan::Free {
|
||||
"2,000"
|
||||
} else {
|
||||
"unlimited"
|
||||
},
|
||||
match plan {
|
||||
proto::Plan::Free => "Zed Free plan",
|
||||
proto::Plan::ZedPro => "Zed Pro plan",
|
||||
proto::Plan::ZedProTrial => "Zed Pro trial",
|
||||
}
|
||||
)),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.child(
|
||||
@@ -495,7 +477,7 @@ impl Render for ZedPredictModal {
|
||||
.w_full()
|
||||
.child(
|
||||
Button::new("accept-tos", "Enable Edit Prediction")
|
||||
.disabled(plan.is_none() || !self.terms_of_service)
|
||||
.disabled(!self.terms_of_service)
|
||||
.style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.full_width()
|
||||
.on_click(cx.listener(Self::accept_and_enable)),
|
||||
|
||||
@@ -14,6 +14,7 @@ Here's an overview of the supported providers and tool call support:
|
||||
| [Anthropic](#anthropic) | ✅ |
|
||||
| [GitHub Copilot Chat](#github-copilot-chat) | In Some Cases |
|
||||
| [Google AI](#google-ai) | ✅ |
|
||||
| [Mistral](#mistral) | ✅ |
|
||||
| [Ollama](#ollama) | ✅ |
|
||||
| [OpenAI](#openai) | ✅ |
|
||||
| [DeepSeek](#deepseek) | 🚫 |
|
||||
@@ -128,6 +129,44 @@ By default Zed will use `stable` versions of models, but you can use specific ve
|
||||
|
||||
Custom models will be listed in the model dropdown in the Agent Panel.
|
||||
|
||||
### Mistral {#mistral}
|
||||
|
||||
> 🔨Supports tool use
|
||||
|
||||
1. Visit the Mistral platform and [create an API key](https://console.mistral.ai/api-keys/)
|
||||
2. Open the configuration view (`assistant: show configuration`) and navigate to the Mistral section
|
||||
3. Enter your Mistral API key
|
||||
|
||||
The Mistral API key will be saved in your keychain.
|
||||
|
||||
Zed will also use the `MISTRAL_API_KEY` environment variable if it's defined.
|
||||
|
||||
#### Mistral Custom Models {#mistral-custom-models}
|
||||
|
||||
The Zed Assistant comes pre-configured with several Mistral models (codestral-latest, mistral-large-latest, mistral-medium-latest, mistral-small-latest, open-mistral-nemo, and open-codestral-mamba). All the default models support tool use. If you wish to use alternate models or customize their parameters, you can do so by adding the following to your Zed `settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"language_models": {
|
||||
"mistral": {
|
||||
"api_url": "https://api.mistral.ai/v1",
|
||||
"available_models": [
|
||||
{
|
||||
"name": "mistral-tiny-latest",
|
||||
"display_name": "Mistral Tiny",
|
||||
"max_tokens": 32000,
|
||||
"max_output_tokens": 4096,
|
||||
"max_completion_tokens": 1024,
|
||||
"supports_tools": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Custom models will be listed in the model dropdown in the assistant panel.
|
||||
|
||||
### Ollama {#ollama}
|
||||
|
||||
> ✅ Supports tool use
|
||||
|
||||
Reference in New Issue
Block a user