Compare commits
76 Commits
v0.186.3-p
...
ask-profil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b31a6f4d2 | ||
|
|
85fda90993 | ||
|
|
b343a8aa22 | ||
|
|
3a3d3c05e8 | ||
|
|
ee56706d15 | ||
|
|
3cc8850a58 | ||
|
|
9f6809a28d | ||
|
|
032022e37b | ||
|
|
b091581e4b | ||
|
|
20387f24aa | ||
|
|
4b5158b168 | ||
|
|
a61958e886 | ||
|
|
d06d0e6a94 | ||
|
|
e8b67872ed | ||
|
|
fcf066aff5 | ||
|
|
b4109a2376 | ||
|
|
6565c091e4 | ||
|
|
d39c220f26 | ||
|
|
1ec466b728 | ||
|
|
a127ff4a4f | ||
|
|
f16f4303f4 | ||
|
|
3615d6d96c | ||
|
|
02ed4aefb8 | ||
|
|
6cc6e4d4b3 | ||
|
|
d6c7cdd60f | ||
|
|
37010aac6b | ||
|
|
6ac2f4e6a5 | ||
|
|
011aa715cf | ||
|
|
3339c84cdd | ||
|
|
9c1b2afa49 | ||
|
|
607a9445fc | ||
|
|
902931fdfc | ||
|
|
3c128ef83f | ||
|
|
0d726603ce | ||
|
|
466a53b51e | ||
|
|
358c324e26 | ||
|
|
6e19c9b141 | ||
|
|
77ac82587a | ||
|
|
7c76cee16d | ||
|
|
22ad207baf | ||
|
|
4469b7339f | ||
|
|
ea769455e4 | ||
|
|
89430a019c | ||
|
|
582ad845b9 | ||
|
|
1b3140d4ab | ||
|
|
625e45bac0 | ||
|
|
d50562ed81 | ||
|
|
a34fb6f6b1 | ||
|
|
5ca114be24 | ||
|
|
fcb9706022 | ||
|
|
0a44048af8 | ||
|
|
a4aa446a20 | ||
|
|
a4e26e0710 | ||
|
|
542c4a3d35 | ||
|
|
e44d167e56 | ||
|
|
c1d4f0873d | ||
|
|
48dfdc416b | ||
|
|
2618191785 | ||
|
|
1a520990cc | ||
|
|
7bc3f74cab | ||
|
|
02765947e0 | ||
|
|
f7e77123cc | ||
|
|
c19a5c2fd6 | ||
|
|
9711fb49dc | ||
|
|
c15e5d275a | ||
|
|
858d61a65e | ||
|
|
1fcd2647ed | ||
|
|
60d51d56cd | ||
|
|
03b635bb27 | ||
|
|
f7511c3f65 | ||
|
|
264097e253 | ||
|
|
795fadc0bc | ||
|
|
38975586d4 | ||
|
|
5539d82ea6 | ||
|
|
0cdd8bdded | ||
|
|
ab3e5cdc6c |
72
Cargo.lock
generated
72
Cargo.lock
generated
@@ -56,6 +56,7 @@ dependencies = [
|
||||
"assistant_context_editor",
|
||||
"assistant_settings",
|
||||
"assistant_slash_command",
|
||||
"assistant_slash_commands",
|
||||
"assistant_tool",
|
||||
"async-watch",
|
||||
"buffer_diff",
|
||||
@@ -78,6 +79,7 @@ dependencies = [
|
||||
"heed",
|
||||
"html_to_markdown",
|
||||
"http_client",
|
||||
"indexed_docs",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"jsonschema",
|
||||
@@ -470,68 +472,6 @@ dependencies = [
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assistant"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant_context_editor",
|
||||
"assistant_settings",
|
||||
"assistant_slash_command",
|
||||
"assistant_slash_commands",
|
||||
"assistant_tool",
|
||||
"async-watch",
|
||||
"client",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"ctor",
|
||||
"db",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"indexed_docs",
|
||||
"indoc",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_model_selector",
|
||||
"languages",
|
||||
"log",
|
||||
"lsp",
|
||||
"menu",
|
||||
"multi_buffer",
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"prompt_store",
|
||||
"proto",
|
||||
"rand 0.8.5",
|
||||
"rope",
|
||||
"rules_library",
|
||||
"schemars",
|
||||
"search",
|
||||
"serde",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smol",
|
||||
"streaming_diff",
|
||||
"telemetry",
|
||||
"telemetry_events",
|
||||
"terminal",
|
||||
"terminal_view",
|
||||
"text",
|
||||
"theme",
|
||||
"tree-sitter-md",
|
||||
"ui",
|
||||
"unindent",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "assistant_context_editor"
|
||||
version = "0.1.0"
|
||||
@@ -3002,7 +2942,6 @@ name = "collab"
|
||||
version = "0.44.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"assistant",
|
||||
"assistant_context_editor",
|
||||
"assistant_settings",
|
||||
"assistant_slash_command",
|
||||
@@ -4068,7 +4007,6 @@ dependencies = [
|
||||
"http_client",
|
||||
"language",
|
||||
"log",
|
||||
"lsp-types",
|
||||
"node_runtime",
|
||||
"parking_lot",
|
||||
"paths",
|
||||
@@ -4104,7 +4042,6 @@ dependencies = [
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"language",
|
||||
"lsp-types",
|
||||
"paths",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -4249,6 +4186,7 @@ dependencies = [
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"dap",
|
||||
"dap_adapters",
|
||||
"db",
|
||||
"debugger_tools",
|
||||
"editor",
|
||||
@@ -7797,6 +7735,7 @@ dependencies = [
|
||||
"tree-sitter-html",
|
||||
"tree-sitter-json",
|
||||
"tree-sitter-md",
|
||||
"tree-sitter-python",
|
||||
"tree-sitter-ruby",
|
||||
"tree-sitter-rust",
|
||||
"tree-sitter-typescript",
|
||||
@@ -18694,7 +18633,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.186.3"
|
||||
version = "0.187.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -18702,7 +18641,6 @@ dependencies = [
|
||||
"ashpd",
|
||||
"askpass",
|
||||
"assets",
|
||||
"assistant",
|
||||
"assistant_context_editor",
|
||||
"assistant_settings",
|
||||
"assistant_tools",
|
||||
|
||||
@@ -6,7 +6,6 @@ members = [
|
||||
"crates/anthropic",
|
||||
"crates/askpass",
|
||||
"crates/assets",
|
||||
"crates/assistant",
|
||||
"crates/assistant_context_editor",
|
||||
"crates/assistant_settings",
|
||||
"crates/assistant_slash_command",
|
||||
@@ -214,7 +213,6 @@ ai = { path = "crates/ai" }
|
||||
anthropic = { path = "crates/anthropic" }
|
||||
askpass = { path = "crates/askpass" }
|
||||
assets = { path = "crates/assets" }
|
||||
assistant = { path = "crates/assistant" }
|
||||
assistant_context_editor = { path = "crates/assistant_context_editor" }
|
||||
assistant_settings = { path = "crates/assistant_settings" }
|
||||
assistant_slash_command = { path = "crates/assistant_slash_command" }
|
||||
|
||||
@@ -213,21 +213,6 @@
|
||||
"ctrl-shift-n": "agent::RejectAll"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "AssistantPanel",
|
||||
"bindings": {
|
||||
"ctrl-k c": "assistant::CopyCode",
|
||||
"ctrl-shift-e": "project_panel::ToggleFocus",
|
||||
"ctrl-g": "search::SelectNextMatch",
|
||||
"ctrl-shift-g": "search::SelectPreviousMatch",
|
||||
"ctrl-alt-/": "agent::ToggleModelSelector",
|
||||
"ctrl-k h": "assistant::DeployHistory",
|
||||
"ctrl-k l": "agent::OpenRulesLibrary",
|
||||
"new": "assistant::NewChat",
|
||||
"ctrl-t": "assistant::NewChat",
|
||||
"ctrl-n": "assistant::NewChat"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "ContextEditor > Editor",
|
||||
"bindings": {
|
||||
@@ -713,8 +698,8 @@
|
||||
{
|
||||
"context": "PromptEditor",
|
||||
"bindings": {
|
||||
"ctrl-[": "assistant::CyclePreviousInlineAssist",
|
||||
"ctrl-]": "assistant::CycleNextInlineAssist",
|
||||
"ctrl-[": "agent::CyclePreviousInlineAssist",
|
||||
"ctrl-]": "agent::CycleNextInlineAssist",
|
||||
"ctrl-alt-e": "agent::RemoveAllContext"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -258,21 +258,6 @@
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "AssistantPanel",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-k c": "assistant::CopyCode",
|
||||
"cmd-shift-e": "project_panel::ToggleFocus",
|
||||
"cmd-g": "search::SelectNextMatch",
|
||||
"cmd-shift-g": "search::SelectPreviousMatch",
|
||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||
"cmd-k h": "assistant::DeployHistory",
|
||||
"cmd-k l": "agent::OpenRulesLibrary",
|
||||
"cmd-t": "assistant::NewChat",
|
||||
"cmd-n": "assistant::NewChat"
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "ContextEditor > Editor",
|
||||
"use_key_equivalents": true,
|
||||
@@ -780,8 +765,8 @@
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||
"cmd-alt-e": "agent::RemoveAllContext",
|
||||
"ctrl-[": "assistant::CyclePreviousInlineAssist",
|
||||
"ctrl-]": "assistant::CycleNextInlineAssist"
|
||||
"ctrl-[": "agent::CyclePreviousInlineAssist",
|
||||
"ctrl-]": "agent::CycleNextInlineAssist"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -356,6 +356,49 @@
|
||||
"vertical": true
|
||||
}
|
||||
},
|
||||
// Minimap related settings
|
||||
"minimap": {
|
||||
// When to show the minimap in the editor.
|
||||
// This setting can take three values:
|
||||
// 1. Show the minimap if the editor's scrollbar is visible:
|
||||
// "auto"
|
||||
// 2. Always show the minimap:
|
||||
// "always"
|
||||
// 3. Never show the minimap:
|
||||
// "never" (default)
|
||||
"show": "never",
|
||||
// When to show the minimap thumb.
|
||||
// This setting can take two values:
|
||||
// 1. Show the minimap thumb if the mouse is over the minimap:
|
||||
// "hover"
|
||||
// 2. Always show the minimap thumb:
|
||||
// "always" (default)
|
||||
"thumb": "always",
|
||||
// How the minimap thumb border should look.
|
||||
// This setting can take five values:
|
||||
// 1. Display a border on all sides of the thumb:
|
||||
// "thumb_border": "full"
|
||||
// 2. Display a border on all sides except the left side of the thumb:
|
||||
// "thumb_border": "left_open" (default)
|
||||
// 3. Display a border on all sides except the right side of the thumb:
|
||||
// "thumb_border": "right_open"
|
||||
// 4. Display a border only on the left side of the thumb:
|
||||
// "thumb_border": "left_only"
|
||||
// 5. Display the thumb without any border:
|
||||
// "thumb_border": "none"
|
||||
"thumb_border": "left_open",
|
||||
// How to highlight the current line in the minimap.
|
||||
// This setting can take the following values:
|
||||
//
|
||||
// 1. `null` to inherit the editor `current_line_highlight` setting (default)
|
||||
// 2. "line" or "all" to highlight the current line in the minimap.
|
||||
// 3. "gutter" or "none" to not highlight the current line in the minimap.
|
||||
"current_line_highlight": null,
|
||||
// The width of the minimap in pixels.
|
||||
"width": 100,
|
||||
// The font size of the minimap in pixels.
|
||||
"font_size": 2
|
||||
},
|
||||
// Enable middle-click paste on Linux.
|
||||
"middle_click_paste": true,
|
||||
// What to do when multibuffer is double clicked in some of its excerpts
|
||||
@@ -370,8 +413,6 @@
|
||||
"gutter": {
|
||||
// Whether to show line numbers in the gutter.
|
||||
"line_numbers": true,
|
||||
// Whether to show code action buttons in the gutter.
|
||||
"code_actions": true,
|
||||
// Whether to show runnables buttons in the gutter.
|
||||
"runnables": true,
|
||||
// Whether to show breakpoints in the gutter.
|
||||
|
||||
@@ -43,6 +43,7 @@ pub struct ActivityIndicator {
|
||||
context_menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ServerStatus {
|
||||
name: SharedString,
|
||||
status: BinaryStatus,
|
||||
@@ -70,6 +71,7 @@ impl ActivityIndicator {
|
||||
) -> Entity<ActivityIndicator> {
|
||||
let project = workspace.project().clone();
|
||||
let auto_updater = AutoUpdater::get(cx);
|
||||
let workspace_handle = cx.entity();
|
||||
let this = cx.new(|cx| {
|
||||
let mut status_events = languages.language_server_binary_statuses();
|
||||
cx.spawn(async move |this, cx| {
|
||||
@@ -84,6 +86,25 @@ impl ActivityIndicator {
|
||||
})
|
||||
.detach();
|
||||
|
||||
cx.subscribe_in(
|
||||
&workspace_handle,
|
||||
window,
|
||||
|activity_indicator, _, event, window, cx| match event {
|
||||
workspace::Event::ClearActivityIndicator { .. } => {
|
||||
if activity_indicator.statuses.pop().is_some() {
|
||||
activity_indicator.dismiss_error_message(
|
||||
&DismissErrorMessage,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
|
||||
cx.subscribe(
|
||||
&project.read(cx).lsp_store(),
|
||||
|_, _, event, cx| match event {
|
||||
@@ -115,7 +136,7 @@ impl ActivityIndicator {
|
||||
}
|
||||
|
||||
Self {
|
||||
statuses: Default::default(),
|
||||
statuses: Vec::new(),
|
||||
project: project.clone(),
|
||||
auto_updater,
|
||||
context_menu_handle: Default::default(),
|
||||
@@ -185,11 +206,8 @@ impl ActivityIndicator {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(updater) = &self.auto_updater {
|
||||
updater.update(cx, |updater, cx| {
|
||||
updater.dismiss_error(cx);
|
||||
});
|
||||
updater.update(cx, |updater, cx| updater.dismiss_error(cx));
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn pending_language_server_work<'a>(
|
||||
|
||||
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant.rs"
|
||||
path = "src/agent.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
@@ -23,6 +23,7 @@ anyhow.workspace = true
|
||||
assistant_context_editor.workspace = true
|
||||
assistant_settings.workspace = true
|
||||
assistant_slash_command.workspace = true
|
||||
assistant_slash_commands.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
async-watch.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
@@ -45,6 +46,7 @@ gpui.workspace = true
|
||||
heed.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indexed_docs.workspace = true
|
||||
itertools.workspace = true
|
||||
jsonschema.workspace = true
|
||||
language.workspace = true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::AssistantPanel;
|
||||
use crate::AgentPanel;
|
||||
use crate::context::{AgentContextHandle, RULES_ICON};
|
||||
use crate::context_picker::{ContextPicker, MentionLink};
|
||||
use crate::context_store::ContextStore;
|
||||
@@ -712,7 +712,7 @@ fn open_markdown_link(
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
Some(MentionLink::Thread(thread_id)) => workspace.update(cx, |workspace, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel
|
||||
.open_thread_by_id(&thread_id, window, cx)
|
||||
@@ -721,7 +721,7 @@ fn open_markdown_link(
|
||||
}
|
||||
}),
|
||||
Some(MentionLink::TextThread(path)) => workspace.update(cx, |workspace, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel
|
||||
.open_saved_prompt_editor(path, window, cx)
|
||||
@@ -1211,8 +1211,7 @@ impl ActiveThread {
|
||||
|
||||
if let Some(workspace) = workspace_handle.upgrade() {
|
||||
workspace.update(_cx, |workspace, cx| {
|
||||
workspace
|
||||
.focus_panel::<AssistantPanel>(window, cx);
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
@@ -1282,9 +1281,6 @@ impl ActiveThread {
|
||||
return;
|
||||
};
|
||||
|
||||
// Cancel any ongoing streaming when user starts editing a previous message
|
||||
self.cancel_last_completion(window, cx);
|
||||
|
||||
let editor = crate::message_editor::create_editor(
|
||||
self.workspace.clone(),
|
||||
self.context_store.downgrade(),
|
||||
@@ -1415,6 +1411,7 @@ impl ActiveThread {
|
||||
mode: None,
|
||||
messages: vec![request_message],
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
stop: vec![],
|
||||
temperature: AssistantSettings::temperature_for_model(
|
||||
&configured_model.model,
|
||||
@@ -3260,7 +3257,7 @@ impl ActiveThread {
|
||||
c.tool_use_id.clone(),
|
||||
c.ui_text.clone(),
|
||||
c.input.clone(),
|
||||
&c.messages,
|
||||
c.request.clone(),
|
||||
c.tool.clone(),
|
||||
configured.model,
|
||||
Some(window.window_handle()),
|
||||
@@ -3527,7 +3524,7 @@ pub(crate) fn open_context(
|
||||
}
|
||||
|
||||
AgentContextHandle::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.open_thread(thread_context.thread.clone(), window, cx);
|
||||
});
|
||||
@@ -3536,7 +3533,7 @@ pub(crate) fn open_context(
|
||||
|
||||
AgentContextHandle::TextThread(text_thread_context) => {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.open_prompt_editor(text_thread_context.context.clone(), window, cx)
|
||||
});
|
||||
@@ -3580,152 +3577,3 @@ fn open_editor_at_position(
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use assistant_tool::{ToolRegistry, ToolWorkingSet};
|
||||
use editor::EditorSettings;
|
||||
use fs::FakeFs;
|
||||
use gpui::{TestAppContext, VisualTestContext};
|
||||
use language_model::{LanguageModel, fake_provider::FakeLanguageModel};
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
use crate::{ContextLoadResult, thread_store};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_current_completion_cancelled_when_message_edited(cx: &mut TestAppContext) {
|
||||
init_test_settings(cx);
|
||||
|
||||
let project = create_test_project(
|
||||
cx,
|
||||
json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let (cx, active_thread, thread, model) = setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Insert user message without any context (empty context vector)
|
||||
let message = thread.update(cx, |thread, cx| {
|
||||
let message_id = thread.insert_user_message(
|
||||
"What is the best way to learn Rust?",
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
vec![],
|
||||
cx,
|
||||
);
|
||||
thread
|
||||
.message(message_id)
|
||||
.expect("message should exist")
|
||||
.clone()
|
||||
});
|
||||
|
||||
// Stream response to user message
|
||||
thread.update(cx, |thread, cx| {
|
||||
let request = thread.to_completion_request(model.clone(), cx);
|
||||
thread.stream_completion(request, model, cx.active_window(), cx)
|
||||
});
|
||||
let generating = thread.update(cx, |thread, _cx| thread.is_generating());
|
||||
assert!(generating, "There should be one pending completion");
|
||||
|
||||
// Edit the previous message
|
||||
active_thread.update_in(cx, |active_thread, window, cx| {
|
||||
active_thread.start_editing_message(message.id, &message.segments, window, cx);
|
||||
});
|
||||
|
||||
// Check that the stream was cancelled
|
||||
let generating = thread.update(cx, |thread, _cx| thread.is_generating());
|
||||
assert!(!generating, "The completion should have been cancelled");
|
||||
}
|
||||
|
||||
fn init_test_settings(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
language::init(cx);
|
||||
Project::init_settings(cx);
|
||||
AssistantSettings::register(cx);
|
||||
prompt_store::init(cx);
|
||||
thread_store::init(cx);
|
||||
workspace::init_settings(cx);
|
||||
language_model::init_settings(cx);
|
||||
ThemeSettings::register(cx);
|
||||
EditorSettings::register(cx);
|
||||
ToolRegistry::default_global(cx);
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to create a test project with test files
|
||||
async fn create_test_project(
|
||||
cx: &mut TestAppContext,
|
||||
files: serde_json::Value,
|
||||
) -> Entity<Project> {
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), files).await;
|
||||
Project::test(fs, [path!("/test").as_ref()], cx).await
|
||||
}
|
||||
|
||||
async fn setup_test_environment(
|
||||
cx: &mut TestAppContext,
|
||||
project: Entity<Project>,
|
||||
) -> (
|
||||
&mut VisualTestContext,
|
||||
Entity<ActiveThread>,
|
||||
Entity<Thread>,
|
||||
Arc<dyn LanguageModel>,
|
||||
) {
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let thread_store = cx
|
||||
.update(|_, cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
None,
|
||||
prompt_builder.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let text_thread_store = cx
|
||||
.update(|_, cx| {
|
||||
TextThreadStore::new(project.clone(), prompt_builder, Default::default(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
|
||||
|
||||
let model = FakeLanguageModel::default();
|
||||
let model: Arc<dyn LanguageModel> = Arc::new(model);
|
||||
|
||||
let language_registry = LanguageRegistry::new(cx.executor());
|
||||
let language_registry = Arc::new(language_registry);
|
||||
|
||||
let active_thread = cx.update(|window, cx| {
|
||||
cx.new(|cx| {
|
||||
ActiveThread::new(
|
||||
thread.clone(),
|
||||
thread_store.clone(),
|
||||
text_thread_store.clone(),
|
||||
context_store.clone(),
|
||||
language_registry.clone(),
|
||||
workspace.downgrade(),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
(cx, active_thread, thread, model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,102 +1,123 @@
|
||||
#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
|
||||
|
||||
mod assistant_configuration;
|
||||
pub mod assistant_panel;
|
||||
mod active_thread;
|
||||
mod agent_configuration;
|
||||
mod agent_diff;
|
||||
mod agent_model_selector;
|
||||
mod agent_panel;
|
||||
mod buffer_codegen;
|
||||
mod context;
|
||||
mod context_picker;
|
||||
mod context_server_configuration;
|
||||
mod context_server_tool;
|
||||
mod context_store;
|
||||
mod context_strip;
|
||||
mod debug;
|
||||
mod history_store;
|
||||
mod inline_assistant;
|
||||
pub mod slash_command_settings;
|
||||
mod inline_prompt_editor;
|
||||
mod message_editor;
|
||||
mod profile_selector;
|
||||
mod slash_command_settings;
|
||||
mod terminal_codegen;
|
||||
mod terminal_inline_assistant;
|
||||
mod thread;
|
||||
mod thread_history;
|
||||
mod thread_store;
|
||||
mod tool_compatibility;
|
||||
mod tool_use;
|
||||
mod ui;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_settings::{AssistantSettings, LanguageModelSelection};
|
||||
use assistant_settings::{AgentProfileId, AssistantSettings, LanguageModelSelection};
|
||||
use assistant_slash_command::SlashCommandRegistry;
|
||||
use client::Client;
|
||||
use command_palette_hooks::CommandPaletteFilter;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
use fs::Fs;
|
||||
use gpui::{App, Global, ReadGlobal, UpdateGlobal, actions};
|
||||
use language_model::{
|
||||
LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
|
||||
};
|
||||
use gpui::{App, actions, impl_actions};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{LanguageModelId, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use thread::ThreadId;
|
||||
|
||||
pub use crate::assistant_panel::{AssistantPanel, AssistantPanelEvent};
|
||||
pub(crate) use crate::inline_assistant::*;
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
use crate::agent_configuration::{AddContextServerModal, ManageProfilesModal};
|
||||
pub use crate::agent_panel::{AgentPanel, ConcreteAssistantPanelDelegate};
|
||||
pub use crate::context::{ContextLoadResult, LoadedContext};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
use crate::slash_command_settings::SlashCommandSettings;
|
||||
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::{TextThreadStore, ThreadStore};
|
||||
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
|
||||
pub use context_store::ContextStore;
|
||||
pub use ui::preview::{all_agent_previews, get_agent_preview};
|
||||
|
||||
actions!(
|
||||
assistant,
|
||||
agent,
|
||||
[
|
||||
InsertActivePrompt,
|
||||
DeployHistory,
|
||||
NewChat,
|
||||
NewTextThread,
|
||||
ToggleContextPicker,
|
||||
ToggleNavigationMenu,
|
||||
ToggleOptionsMenu,
|
||||
DeleteRecentlyOpenThread,
|
||||
ToggleProfileSelector,
|
||||
RemoveAllContext,
|
||||
ExpandMessageEditor,
|
||||
OpenHistory,
|
||||
AddContextServer,
|
||||
RemoveSelectedThread,
|
||||
Chat,
|
||||
CycleNextInlineAssist,
|
||||
CyclePreviousInlineAssist
|
||||
CyclePreviousInlineAssist,
|
||||
FocusUp,
|
||||
FocusDown,
|
||||
FocusLeft,
|
||||
FocusRight,
|
||||
RemoveFocusedContext,
|
||||
AcceptSuggestedContext,
|
||||
OpenActiveThreadAsMarkdown,
|
||||
OpenAgentDiff,
|
||||
Keep,
|
||||
Reject,
|
||||
RejectAll,
|
||||
KeepAll,
|
||||
Follow,
|
||||
ResetTrialUpsell,
|
||||
]
|
||||
);
|
||||
|
||||
const DEFAULT_CONTEXT_LINES: usize = 50;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct LanguageModelUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema)]
|
||||
pub struct NewThread {
|
||||
#[serde(default)]
|
||||
from_thread_id: Option<ThreadId>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct LanguageModelChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: LanguageModelResponseMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema)]
|
||||
pub struct ManageProfiles {
|
||||
#[serde(default)]
|
||||
pub customize_tools: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
/// The state pertaining to the Assistant.
|
||||
#[derive(Default)]
|
||||
struct Assistant {
|
||||
/// Whether the Assistant is enabled.
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl Global for Assistant {}
|
||||
|
||||
impl Assistant {
|
||||
const NAMESPACE: &'static str = "assistant";
|
||||
|
||||
fn set_enabled(&mut self, enabled: bool, cx: &mut App) {
|
||||
if self.enabled == enabled {
|
||||
return;
|
||||
impl ManageProfiles {
|
||||
pub fn customize_tools(profile_id: AgentProfileId) -> Self {
|
||||
Self {
|
||||
customize_tools: Some(profile_id),
|
||||
}
|
||||
|
||||
self.enabled = enabled;
|
||||
|
||||
if !enabled {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_namespace(Self::NAMESPACE);
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.show_namespace(Self::NAMESPACE);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn enabled(cx: &App) -> bool {
|
||||
Self::global(cx).enabled
|
||||
}
|
||||
}
|
||||
|
||||
impl_actions!(agent, [NewThread, ManageProfiles]);
|
||||
|
||||
/// Initializes the `agent` crate.
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
client: Arc<Client>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
cx.set_global(Assistant::default());
|
||||
AssistantSettings::register(cx);
|
||||
SlashCommandSettings::register(cx);
|
||||
|
||||
@@ -104,8 +125,9 @@ pub fn init(
|
||||
rules_library::init(cx);
|
||||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_tool::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
thread_store::init(cx);
|
||||
agent_panel::init(cx);
|
||||
context_server_configuration::init(language_registry, cx);
|
||||
|
||||
register_slash_commands(cx);
|
||||
inline_assistant::init(
|
||||
@@ -121,22 +143,8 @@ pub fn init(
|
||||
cx,
|
||||
);
|
||||
indexed_docs::init(cx);
|
||||
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
filter.hide_namespace(Assistant::NAMESPACE);
|
||||
});
|
||||
Assistant::update_global(cx, |assistant, cx| {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
|
||||
assistant.set_enabled(settings.enabled, cx);
|
||||
});
|
||||
cx.observe_global::<SettingsStore>(|cx| {
|
||||
Assistant::update_global(cx, |assistant, cx| {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
assistant.set_enabled(settings.enabled, cx);
|
||||
});
|
||||
})
|
||||
.detach();
|
||||
cx.observe_new(AddContextServerModal::register).detach();
|
||||
cx.observe_new(ManageProfilesModal::register).detach();
|
||||
}
|
||||
|
||||
fn init_language_model_settings(cx: &mut App) {
|
||||
@@ -250,11 +258,3 @@ fn update_slash_commands_from_settings(cx: &mut App) {
|
||||
.unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
env_logger::init();
|
||||
}
|
||||
}
|
||||
@@ -30,7 +30,7 @@ pub(crate) use manage_profiles_modal::ManageProfilesModal;
|
||||
|
||||
use crate::AddContextServer;
|
||||
|
||||
pub struct AssistantConfiguration {
|
||||
pub struct AgentConfiguration {
|
||||
fs: Arc<dyn Fs>,
|
||||
focus_handle: FocusHandle,
|
||||
configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
|
||||
@@ -42,7 +42,7 @@ pub struct AssistantConfiguration {
|
||||
scrollbar_state: ScrollbarState,
|
||||
}
|
||||
|
||||
impl AssistantConfiguration {
|
||||
impl AgentConfiguration {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
context_server_store: Entity<ContextServerStore>,
|
||||
@@ -110,7 +110,7 @@ impl AssistantConfiguration {
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for AssistantConfiguration {
|
||||
impl Focusable for AgentConfiguration {
|
||||
fn focus_handle(&self, _: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
@@ -120,9 +120,9 @@ pub enum AssistantConfigurationEvent {
|
||||
NewThread(Arc<dyn LanguageModelProvider>),
|
||||
}
|
||||
|
||||
impl EventEmitter<AssistantConfigurationEvent> for AssistantConfiguration {}
|
||||
impl EventEmitter<AssistantConfigurationEvent> for AgentConfiguration {}
|
||||
|
||||
impl AssistantConfiguration {
|
||||
impl AgentConfiguration {
|
||||
fn render_provider_configuration_block(
|
||||
&mut self,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
@@ -571,7 +571,7 @@ impl AssistantConfiguration {
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AssistantConfiguration {
|
||||
impl Render for AgentConfiguration {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
v_flex()
|
||||
.id("assistant-configuration")
|
||||
@@ -18,9 +18,9 @@ use ui::{
|
||||
use util::ResultExt as _;
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
use crate::assistant_configuration::manage_profiles_modal::profile_modal_header::ProfileModalHeader;
|
||||
use crate::assistant_configuration::tool_picker::{ToolPicker, ToolPickerDelegate};
|
||||
use crate::{AssistantPanel, ManageProfiles, ThreadStore};
|
||||
use crate::agent_configuration::manage_profiles_modal::profile_modal_header::ProfileModalHeader;
|
||||
use crate::agent_configuration::tool_picker::{ToolPicker, ToolPickerDelegate};
|
||||
use crate::{AgentPanel, ManageProfiles, ThreadStore};
|
||||
|
||||
use super::tool_picker::ToolPickerMode;
|
||||
|
||||
@@ -115,7 +115,7 @@ impl ManageProfilesModal {
|
||||
_cx: &mut Context<Workspace>,
|
||||
) {
|
||||
workspace.register_action(|workspace, action: &ManageProfiles, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
let thread_store = panel.read(cx).thread_store();
|
||||
let tools = thread_store.read(cx).tools();
|
||||
@@ -124,7 +124,7 @@ impl ManageProfilesModal {
|
||||
let mut this = Self::new(fs, tools, thread_store, window, cx);
|
||||
|
||||
if let Some(profile_id) = action.customize_tools.clone() {
|
||||
this.configure_tools(profile_id, window, cx);
|
||||
this.configure_builtin_tools(profile_id, window, cx);
|
||||
}
|
||||
|
||||
this
|
||||
@@ -190,7 +190,7 @@ impl ManageProfilesModal {
|
||||
self.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn configure_mcps(
|
||||
fn configure_mcp_tools(
|
||||
&mut self,
|
||||
profile_id: AgentProfileId,
|
||||
window: &mut Window,
|
||||
@@ -228,7 +228,7 @@ impl ManageProfilesModal {
|
||||
self.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn configure_tools(
|
||||
fn configure_builtin_tools(
|
||||
&mut self,
|
||||
profile_id: AgentProfileId,
|
||||
window: &mut Window,
|
||||
@@ -581,16 +581,20 @@ impl ManageProfilesModal {
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("configure-tools")
|
||||
.id("configure-builtin-tools")
|
||||
.track_focus(&mode.configure_tools.focus_handle)
|
||||
.on_action({
|
||||
let profile_id = mode.profile_id.clone();
|
||||
cx.listener(move |this, _: &menu::Confirm, window, cx| {
|
||||
this.configure_tools(profile_id.clone(), window, cx);
|
||||
this.configure_builtin_tools(
|
||||
profile_id.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
})
|
||||
.child(
|
||||
ListItem::new("configure-tools")
|
||||
ListItem::new("configure-builtin-tools-item")
|
||||
.toggle_state(
|
||||
mode.configure_tools
|
||||
.focus_handle
|
||||
@@ -603,11 +607,11 @@ impl ManageProfilesModal {
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Label::new("Configure Tools"))
|
||||
.child(Label::new("Configure Built-in Tools"))
|
||||
.on_click({
|
||||
let profile_id = mode.profile_id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.configure_tools(
|
||||
this.configure_builtin_tools(
|
||||
profile_id.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -623,11 +627,11 @@ impl ManageProfilesModal {
|
||||
.on_action({
|
||||
let profile_id = mode.profile_id.clone();
|
||||
cx.listener(move |this, _: &menu::Confirm, window, cx| {
|
||||
this.configure_mcps(profile_id.clone(), window, cx);
|
||||
this.configure_mcp_tools(profile_id.clone(), window, cx);
|
||||
})
|
||||
})
|
||||
.child(
|
||||
ListItem::new("configure-mcps")
|
||||
ListItem::new("configure-mcp-tools")
|
||||
.toggle_state(
|
||||
mode.configure_mcps
|
||||
.focus_handle
|
||||
@@ -640,11 +644,15 @@ impl ManageProfilesModal {
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(Label::new("Configure MCP Servers"))
|
||||
.child(Label::new("Configure MCP Tools"))
|
||||
.on_click({
|
||||
let profile_id = mode.profile_id.clone();
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
this.configure_mcps(profile_id.clone(), window, cx);
|
||||
this.configure_mcp_tools(
|
||||
profile_id.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
}),
|
||||
),
|
||||
@@ -777,7 +785,7 @@ impl Render for ManageProfilesModal {
|
||||
v_flex()
|
||||
.pb_1()
|
||||
.child(ProfileModalHeader::new(
|
||||
format!("{profile_name} — Configure Tools"),
|
||||
format!("{profile_name} — Configure Built-in Tools"),
|
||||
Some(IconName::Cog),
|
||||
))
|
||||
.child(ListSeparator)
|
||||
@@ -800,7 +808,7 @@ impl Render for ManageProfilesModal {
|
||||
v_flex()
|
||||
.pb_1()
|
||||
.child(ProfileModalHeader::new(
|
||||
format!("{profile_name} — Configure MCP Servers"),
|
||||
format!("{profile_name} — Configure MCP Tools"),
|
||||
Some(IconName::Hammer),
|
||||
))
|
||||
.child(ListSeparator)
|
||||
@@ -176,7 +176,7 @@ impl PickerDelegate for ToolPickerDelegate {
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
|
||||
match self.mode {
|
||||
ToolPickerMode::BuiltinTools => "Search built-in tools…",
|
||||
ToolPickerMode::McpTools => "Search MCP servers…",
|
||||
ToolPickerMode::McpTools => "Search MCP tools…",
|
||||
}
|
||||
.into()
|
||||
}
|
||||
@@ -17,13 +17,13 @@ pub enum ModelType {
|
||||
InlineAssistant,
|
||||
}
|
||||
|
||||
pub struct AssistantModelSelector {
|
||||
pub struct AgentModelSelector {
|
||||
selector: Entity<LanguageModelSelector>,
|
||||
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
|
||||
impl AssistantModelSelector {
|
||||
impl AgentModelSelector {
|
||||
pub(crate) fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
menu_handle: PopoverMenuHandle<LanguageModelSelector>,
|
||||
@@ -99,7 +99,7 @@ impl AssistantModelSelector {
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AssistantModelSelector {
|
||||
impl Render for AgentModelSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_context_editor::{
|
||||
AssistantContext, AssistantPanelDelegate, ConfigurationError, ContextEditor, ContextEvent,
|
||||
AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent,
|
||||
SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate,
|
||||
render_remaining_tokens,
|
||||
};
|
||||
@@ -53,8 +53,8 @@ use zed_actions::{DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFon
|
||||
use zed_llm_client::UsageLimit;
|
||||
|
||||
use crate::active_thread::{self, ActiveThread, ActiveThreadEvent};
|
||||
use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent};
|
||||
use crate::agent_diff::AgentDiff;
|
||||
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
|
||||
use crate::history_store::{HistoryEntry, HistoryStore, RecentEntry};
|
||||
use crate::message_editor::{MessageEditor, MessageEditorEvent};
|
||||
use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio};
|
||||
@@ -71,7 +71,7 @@ use crate::{
|
||||
const AGENT_PANEL_KEY: &str = "agent_panel";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SerializedAssistantPanel {
|
||||
struct SerializedAgentPanel {
|
||||
width: Option<Pixels>,
|
||||
}
|
||||
|
||||
@@ -80,40 +80,40 @@ pub fn init(cx: &mut App) {
|
||||
|workspace: &mut Workspace, _window, _cx: &mut Context<Workspace>| {
|
||||
workspace
|
||||
.register_action(|workspace, action: &NewThread, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| panel.new_thread(action, window, cx));
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &OpenHistory, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.open_history(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &OpenConfiguration, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &NewTextThread, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &OpenAgentDiff, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
let thread = panel.read(cx).thread.read(cx).thread().clone();
|
||||
AgentDiffPane::deploy_in_workspace(thread, workspace, window, cx);
|
||||
}
|
||||
@@ -122,8 +122,8 @@ pub fn init(cx: &mut App) {
|
||||
workspace.follow(CollaboratorId::Agent, window, cx);
|
||||
})
|
||||
.register_action(|workspace, _: &ExpandMessageEditor, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.message_editor.update(cx, |editor, cx| {
|
||||
editor.expand_message_editor(&ExpandMessageEditor, window, cx);
|
||||
@@ -132,16 +132,16 @@ pub fn init(cx: &mut App) {
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &ToggleNavigationMenu, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.toggle_navigation_menu(&ToggleNavigationMenu, window, cx);
|
||||
});
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &ToggleOptionsMenu, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.toggle_options_menu(&ToggleOptionsMenu, window, cx);
|
||||
});
|
||||
@@ -335,7 +335,7 @@ impl ActiveView {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AssistantPanel {
|
||||
pub struct AgentPanel {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
user_store: Entity<UserStore>,
|
||||
project: Entity<Project>,
|
||||
@@ -349,7 +349,7 @@ pub struct AssistantPanel {
|
||||
context_store: Entity<TextThreadStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
inline_assist_context_store: Entity<crate::context_store::ContextStore>,
|
||||
configuration: Option<Entity<AssistantConfiguration>>,
|
||||
configuration: Option<Entity<AgentConfiguration>>,
|
||||
configuration_subscription: Option<Subscription>,
|
||||
local_timezone: UtcOffset,
|
||||
active_view: ActiveView,
|
||||
@@ -366,14 +366,14 @@ pub struct AssistantPanel {
|
||||
_trial_markdown: Entity<Markdown>,
|
||||
}
|
||||
|
||||
impl AssistantPanel {
|
||||
impl AgentPanel {
|
||||
fn serialize(&mut self, cx: &mut Context<Self>) {
|
||||
let width = self.width;
|
||||
self.pending_serialization = Some(cx.background_spawn(async move {
|
||||
KEY_VALUE_STORE
|
||||
.write_kvp(
|
||||
AGENT_PANEL_KEY.into(),
|
||||
serde_json::to_string(&SerializedAssistantPanel { width })?,
|
||||
serde_json::to_string(&SerializedAgentPanel { width })?,
|
||||
)
|
||||
.await?;
|
||||
anyhow::Ok(())
|
||||
@@ -423,7 +423,7 @@ impl AssistantPanel {
|
||||
.log_err()
|
||||
.flatten()
|
||||
{
|
||||
Some(serde_json::from_str::<SerializedAssistantPanel>(&panel)?)
|
||||
Some(serde_json::from_str::<SerializedAgentPanel>(&panel)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -1163,15 +1163,13 @@ impl AssistantPanel {
|
||||
|
||||
self.set_active_view(ActiveView::Configuration, window, cx);
|
||||
self.configuration =
|
||||
Some(cx.new(|cx| {
|
||||
AssistantConfiguration::new(fs, context_server_store, tools, window, cx)
|
||||
}));
|
||||
Some(cx.new(|cx| AgentConfiguration::new(fs, context_server_store, tools, window, cx)));
|
||||
|
||||
if let Some(configuration) = self.configuration.as_ref() {
|
||||
self.configuration_subscription = Some(cx.subscribe_in(
|
||||
configuration,
|
||||
window,
|
||||
Self::handle_assistant_configuration_event,
|
||||
Self::handle_agent_configuration_event,
|
||||
));
|
||||
|
||||
configuration.focus_handle(cx).focus(window);
|
||||
@@ -1201,9 +1199,9 @@ impl AssistantPanel {
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn handle_assistant_configuration_event(
|
||||
fn handle_agent_configuration_event(
|
||||
&mut self,
|
||||
_entity: &Entity<AssistantConfiguration>,
|
||||
_entity: &Entity<AgentConfiguration>,
|
||||
event: &AssistantConfigurationEvent,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -1316,7 +1314,7 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for AssistantPanel {
|
||||
impl Focusable for AgentPanel {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
match &self.active_view {
|
||||
ActiveView::Thread { .. } => self.message_editor.focus_handle(cx),
|
||||
@@ -1341,9 +1339,9 @@ fn agent_panel_dock_position(cx: &App) -> DockPosition {
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<PanelEvent> for AssistantPanel {}
|
||||
impl EventEmitter<PanelEvent> for AgentPanel {}
|
||||
|
||||
impl Panel for AssistantPanel {
|
||||
impl Panel for AgentPanel {
|
||||
fn persistent_name() -> &'static str {
|
||||
"AgentPanel"
|
||||
}
|
||||
@@ -1418,7 +1416,7 @@ impl Panel for AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantPanel {
|
||||
impl AgentPanel {
|
||||
fn render_title_view(&self, _window: &mut Window, cx: &Context<Self>) -> AnyElement {
|
||||
const LOADING_SUMMARY_PLACEHOLDER: &str = "Loading Summary…";
|
||||
|
||||
@@ -1977,9 +1975,9 @@ impl AssistantPanel {
|
||||
.style(ButtonStyle::Transparent)
|
||||
.color(Color::Muted)
|
||||
.on_click({
|
||||
let assistant_panel = cx.entity();
|
||||
let agent_panel = cx.entity();
|
||||
move |_, _, cx| {
|
||||
assistant_panel.update(
|
||||
agent_panel.update(
|
||||
cx,
|
||||
|this, cx| {
|
||||
let hidden =
|
||||
@@ -2744,7 +2742,7 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AssistantPanel {
|
||||
impl Render for AgentPanel {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let content = match &self.active_view {
|
||||
ActiveView::Thread { .. } => v_flex()
|
||||
@@ -2855,28 +2853,26 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
})
|
||||
}
|
||||
|
||||
fn focus_assistant_panel(
|
||||
fn focus_agent_panel(
|
||||
&self,
|
||||
workspace: &mut Workspace,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) -> bool {
|
||||
workspace
|
||||
.focus_panel::<AssistantPanel>(window, cx)
|
||||
.is_some()
|
||||
workspace.focus_panel::<AgentPanel>(window, cx).is_some()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ConcreteAssistantPanelDelegate;
|
||||
|
||||
impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
impl AgentPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
fn active_context_editor(
|
||||
&self,
|
||||
workspace: &mut Workspace,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) -> Option<Entity<ContextEditor>> {
|
||||
let panel = workspace.panel::<AssistantPanel>(cx)?;
|
||||
let panel = workspace.panel::<AgentPanel>(cx)?;
|
||||
panel.read(cx).active_context_editor()
|
||||
}
|
||||
|
||||
@@ -2887,7 +2883,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) -> Task<Result<()>> {
|
||||
let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
|
||||
let Some(panel) = workspace.panel::<AgentPanel>(cx) else {
|
||||
return Task::ready(Err(anyhow!("Agent panel not found")));
|
||||
};
|
||||
|
||||
@@ -2914,12 +2910,12 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
|
||||
let Some(panel) = workspace.panel::<AgentPanel>(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if !panel.focus_handle(cx).contains_focused(window, cx) {
|
||||
workspace.toggle_panel_focus::<AssistantPanel>(window, cx);
|
||||
workspace.toggle_panel_focus::<AgentPanel>(window, cx);
|
||||
}
|
||||
|
||||
panel.update(cx, |_, cx| {
|
||||
@@ -1,135 +0,0 @@
|
||||
mod active_thread;
|
||||
mod agent_diff;
|
||||
mod assistant_configuration;
|
||||
mod assistant_model_selector;
|
||||
mod assistant_panel;
|
||||
mod buffer_codegen;
|
||||
mod context;
|
||||
mod context_picker;
|
||||
mod context_server_configuration;
|
||||
mod context_server_tool;
|
||||
mod context_store;
|
||||
mod context_strip;
|
||||
mod debug;
|
||||
mod history_store;
|
||||
mod inline_assistant;
|
||||
mod inline_prompt_editor;
|
||||
mod message_editor;
|
||||
mod profile_selector;
|
||||
mod terminal_codegen;
|
||||
mod terminal_inline_assistant;
|
||||
mod thread;
|
||||
mod thread_history;
|
||||
mod thread_store;
|
||||
mod tool_compatibility;
|
||||
mod tool_use;
|
||||
mod ui;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_settings::{AgentProfileId, AssistantSettings};
|
||||
use client::Client;
|
||||
use fs::Fs;
|
||||
use gpui::{App, actions, impl_actions};
|
||||
use language::LanguageRegistry;
|
||||
use prompt_store::PromptBuilder;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use settings::Settings as _;
|
||||
use thread::ThreadId;
|
||||
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
|
||||
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
|
||||
pub use crate::context::{ContextLoadResult, LoadedContext};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::{TextThreadStore, ThreadStore};
|
||||
pub use agent_diff::{AgentDiffPane, AgentDiffToolbar};
|
||||
pub use context_store::ContextStore;
|
||||
pub use ui::preview::{all_agent_previews, get_agent_preview};
|
||||
|
||||
actions!(
|
||||
agent,
|
||||
[
|
||||
NewTextThread,
|
||||
ToggleContextPicker,
|
||||
ToggleNavigationMenu,
|
||||
ToggleOptionsMenu,
|
||||
DeleteRecentlyOpenThread,
|
||||
ToggleProfileSelector,
|
||||
RemoveAllContext,
|
||||
ExpandMessageEditor,
|
||||
OpenHistory,
|
||||
AddContextServer,
|
||||
RemoveSelectedThread,
|
||||
Chat,
|
||||
CycleNextInlineAssist,
|
||||
CyclePreviousInlineAssist,
|
||||
FocusUp,
|
||||
FocusDown,
|
||||
FocusLeft,
|
||||
FocusRight,
|
||||
RemoveFocusedContext,
|
||||
AcceptSuggestedContext,
|
||||
OpenActiveThreadAsMarkdown,
|
||||
OpenAgentDiff,
|
||||
Keep,
|
||||
Reject,
|
||||
RejectAll,
|
||||
KeepAll,
|
||||
Follow,
|
||||
ResetTrialUpsell,
|
||||
]
|
||||
);
|
||||
|
||||
#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema)]
|
||||
pub struct NewThread {
|
||||
#[serde(default)]
|
||||
from_thread_id: Option<ThreadId>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema)]
|
||||
pub struct ManageProfiles {
|
||||
#[serde(default)]
|
||||
pub customize_tools: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
impl ManageProfiles {
|
||||
pub fn customize_tools(profile_id: AgentProfileId) -> Self {
|
||||
Self {
|
||||
customize_tools: Some(profile_id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_actions!(agent, [NewThread, ManageProfiles]);
|
||||
|
||||
/// Initializes the `agent` crate.
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
client: Arc<Client>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
AssistantSettings::register(cx);
|
||||
thread_store::init(cx);
|
||||
assistant_panel::init(cx);
|
||||
context_server_configuration::init(language_registry, cx);
|
||||
|
||||
inline_assistant::init(
|
||||
fs.clone(),
|
||||
prompt_builder.clone(),
|
||||
client.telemetry().clone(),
|
||||
cx,
|
||||
);
|
||||
terminal_inline_assistant::init(
|
||||
fs.clone(),
|
||||
prompt_builder.clone(),
|
||||
client.telemetry().clone(),
|
||||
cx,
|
||||
);
|
||||
cx.observe_new(AddContextServerModal::register).detach();
|
||||
cx.observe_new(ManageProfilesModal::register).detach();
|
||||
}
|
||||
@@ -466,6 +466,7 @@ impl CodegenAlternative {
|
||||
prompt_id: None,
|
||||
mode: None,
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature,
|
||||
messages: vec![request_message],
|
||||
|
||||
@@ -36,7 +36,7 @@ use ui::{
|
||||
use uuid::Uuid;
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
|
||||
use crate::AssistantPanel;
|
||||
use crate::AgentPanel;
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::ThreadId;
|
||||
@@ -648,7 +648,7 @@ fn recent_context_picker_entries(
|
||||
let current_threads = context_store.read(cx).thread_ids();
|
||||
|
||||
let active_thread_id = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.panel::<AgentPanel>(cx)
|
||||
.and_then(|panel| Some(panel.read(cx).active_thread()?.read(cx).id()));
|
||||
|
||||
if let Some((thread_store, text_thread_store)) = thread_store
|
||||
|
||||
@@ -10,7 +10,7 @@ use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::assistant_configuration::ConfigureContextServerModal;
|
||||
use crate::agent_configuration::ConfigureContextServerModal;
|
||||
|
||||
pub(crate) fn init(language_registry: Arc<LanguageRegistry>, cx: &mut App) {
|
||||
cx.observe_new(move |_: &mut Workspace, window, cx| {
|
||||
|
||||
@@ -4,7 +4,7 @@ use anyhow::{Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, ToolSource};
|
||||
use context_server::{ContextServerId, types};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, context_server_store::ContextServerStore};
|
||||
use ui::IconName;
|
||||
|
||||
@@ -72,7 +72,7 @@ impl Tool for ContextServerTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::thread::Thread;
|
||||
use crate::thread_store::{TextThreadStore, ThreadStore};
|
||||
use crate::ui::{AddedContext, ContextPill};
|
||||
use crate::{
|
||||
AcceptSuggestedContext, AssistantPanel, FocusDown, FocusLeft, FocusRight, FocusUp,
|
||||
AcceptSuggestedContext, AgentPanel, FocusDown, FocusLeft, FocusRight, FocusUp,
|
||||
RemoveAllContext, RemoveFocusedContext, ToggleContextPicker,
|
||||
};
|
||||
|
||||
@@ -144,7 +144,7 @@ impl ContextStrip {
|
||||
}
|
||||
|
||||
let workspace = self.workspace.upgrade()?;
|
||||
let panel = workspace.read(cx).panel::<AssistantPanel>(cx)?.read(cx);
|
||||
let panel = workspace.read(cx).panel::<AgentPanel>(cx)?.read(cx);
|
||||
|
||||
if let Some(active_thread) = panel.active_thread() {
|
||||
let weak_active_thread = active_thread.downgrade();
|
||||
|
||||
@@ -8,9 +8,10 @@ use anyhow::{Context as _, Result};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use client::telemetry::Telemetry;
|
||||
use collections::{HashMap, HashSet, VecDeque, hash_map};
|
||||
use editor::display_map::EditorMargins;
|
||||
use editor::{
|
||||
Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorEvent, ExcerptId, ExcerptRange,
|
||||
GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint,
|
||||
MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint,
|
||||
actions::SelectAll,
|
||||
display_map::{
|
||||
BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
|
||||
@@ -42,7 +43,7 @@ use util::ResultExt;
|
||||
use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId};
|
||||
use zed_actions::agent::OpenConfiguration;
|
||||
|
||||
use crate::AssistantPanel;
|
||||
use crate::AgentPanel;
|
||||
use crate::buffer_codegen::{BufferCodegen, CodegenAlternative, CodegenEvent};
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::inline_prompt_editor::{CodegenStatus, InlineAssistId, PromptEditor, PromptEditorEvent};
|
||||
@@ -181,13 +182,12 @@ impl InlineAssistant {
|
||||
if let Some(editor) = item.act_as::<Editor>(cx) {
|
||||
editor.update(cx, |editor, cx| {
|
||||
if is_assistant2_enabled {
|
||||
let panel = workspace.read(cx).panel::<AssistantPanel>(cx);
|
||||
let panel = workspace.read(cx).panel::<AgentPanel>(cx);
|
||||
let thread_store = panel
|
||||
.as_ref()
|
||||
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
|
||||
let text_thread_store = panel.map(|assistant_panel| {
|
||||
assistant_panel.read(cx).text_thread_store().downgrade()
|
||||
});
|
||||
.map(|agent_panel| agent_panel.read(cx).thread_store().downgrade());
|
||||
let text_thread_store = panel
|
||||
.map(|agent_panel| agent_panel.read(cx).text_thread_store().downgrade());
|
||||
|
||||
editor.add_code_action_provider(
|
||||
Rc::new(AssistantCodeActionProvider {
|
||||
@@ -226,7 +226,7 @@ impl InlineAssistant {
|
||||
|
||||
let Some(inline_assist_target) = Self::resolve_inline_assist_target(
|
||||
workspace,
|
||||
workspace.panel::<AssistantPanel>(cx),
|
||||
workspace.panel::<AgentPanel>(cx),
|
||||
window,
|
||||
cx,
|
||||
) else {
|
||||
@@ -239,15 +239,15 @@ impl InlineAssistant {
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx))
|
||||
};
|
||||
|
||||
let Some(assistant_panel) = workspace.panel::<AssistantPanel>(cx) else {
|
||||
let Some(agent_panel) = workspace.panel::<AgentPanel>(cx) else {
|
||||
return;
|
||||
};
|
||||
let assistant_panel = assistant_panel.read(cx);
|
||||
let agent_panel = agent_panel.read(cx);
|
||||
|
||||
let prompt_store = assistant_panel.prompt_store().as_ref().cloned();
|
||||
let thread_store = Some(assistant_panel.thread_store().downgrade());
|
||||
let text_thread_store = Some(assistant_panel.text_thread_store().downgrade());
|
||||
let context_store = assistant_panel.inline_assist_context_store().clone();
|
||||
let prompt_store = agent_panel.prompt_store().as_ref().cloned();
|
||||
let thread_store = Some(agent_panel.thread_store().downgrade());
|
||||
let text_thread_store = Some(agent_panel.text_thread_store().downgrade());
|
||||
let context_store = agent_panel.inline_assist_context_store().clone();
|
||||
|
||||
let handle_assist =
|
||||
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
|
||||
@@ -458,11 +458,11 @@ impl InlineAssistant {
|
||||
)
|
||||
});
|
||||
|
||||
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
|
||||
let editor_margins = Arc::new(Mutex::new(EditorMargins::default()));
|
||||
let prompt_editor = cx.new(|cx| {
|
||||
PromptEditor::new_buffer(
|
||||
assist_id,
|
||||
gutter_dimensions.clone(),
|
||||
editor_margins,
|
||||
self.prompt_history.clone(),
|
||||
prompt_buffer.clone(),
|
||||
codegen.clone(),
|
||||
@@ -577,11 +577,11 @@ impl InlineAssistant {
|
||||
)
|
||||
});
|
||||
|
||||
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
|
||||
let editor_margins = Arc::new(Mutex::new(EditorMargins::default()));
|
||||
let prompt_editor = cx.new(|cx| {
|
||||
PromptEditor::new_buffer(
|
||||
assist_id,
|
||||
gutter_dimensions.clone(),
|
||||
editor_margins,
|
||||
self.prompt_history.clone(),
|
||||
prompt_buffer.clone(),
|
||||
codegen.clone(),
|
||||
@@ -650,6 +650,7 @@ impl InlineAssistant {
|
||||
height: Some(prompt_editor_height),
|
||||
render: build_assist_editor_renderer(prompt_editor),
|
||||
priority: 0,
|
||||
render_in_minimap: false,
|
||||
},
|
||||
BlockProperties {
|
||||
style: BlockStyle::Sticky,
|
||||
@@ -664,6 +665,7 @@ impl InlineAssistant {
|
||||
.into_any_element()
|
||||
}),
|
||||
priority: 0,
|
||||
render_in_minimap: false,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -1405,11 +1407,11 @@ impl InlineAssistant {
|
||||
|
||||
enum DeletedLines {}
|
||||
let mut editor = Editor::for_multibuffer(multi_buffer, None, window, cx);
|
||||
editor.disable_scrollbars_and_minimap(cx);
|
||||
editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
|
||||
editor.set_show_wrap_guides(false, cx);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.scroll_manager.set_forbid_vertical_scroll(true);
|
||||
editor.set_show_scrollbars(false, cx);
|
||||
editor.set_read_only(true);
|
||||
editor.set_show_edit_predictions(Some(false), window, cx);
|
||||
editor.highlight_rows::<DeletedLines>(
|
||||
@@ -1433,11 +1435,12 @@ impl InlineAssistant {
|
||||
.bg(cx.theme().status().deleted_background)
|
||||
.size_full()
|
||||
.h(height as f32 * cx.window.line_height())
|
||||
.pl(cx.gutter_dimensions.full_width())
|
||||
.pl(cx.margins.gutter.full_width())
|
||||
.child(deleted_lines_editor.clone())
|
||||
.into_any_element()
|
||||
}),
|
||||
priority: 0,
|
||||
render_in_minimap: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1450,7 +1453,7 @@ impl InlineAssistant {
|
||||
|
||||
fn resolve_inline_assist_target(
|
||||
workspace: &mut Workspace,
|
||||
assistant_panel: Option<Entity<AssistantPanel>>,
|
||||
agent_panel: Option<Entity<AgentPanel>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<InlineAssistTarget> {
|
||||
@@ -1470,7 +1473,7 @@ impl InlineAssistant {
|
||||
}
|
||||
}
|
||||
|
||||
let context_editor = assistant_panel
|
||||
let context_editor = agent_panel
|
||||
.and_then(|panel| panel.read(cx).active_context_editor())
|
||||
.and_then(|editor| {
|
||||
let editor = &editor.read(cx).editor().clone();
|
||||
@@ -1595,9 +1598,9 @@ fn build_assist_editor_renderer(editor: &Entity<PromptEditor<BufferCodegen>>) ->
|
||||
let editor = editor.clone();
|
||||
|
||||
Arc::new(move |cx: &mut BlockContext| {
|
||||
let gutter_dimensions = editor.read(cx).gutter_dimensions();
|
||||
let editor_margins = editor.read(cx).editor_margins();
|
||||
|
||||
*gutter_dimensions.lock() = *cx.gutter_dimensions;
|
||||
*editor_margins.lock() = *cx.margins;
|
||||
editor.clone().into_any_element()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
|
||||
use crate::agent_model_selector::{AgentModelSelector, ModelType};
|
||||
use crate::buffer_codegen::BufferCodegen;
|
||||
use crate::context::ContextCreasesAddon;
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
|
||||
@@ -11,9 +11,9 @@ use crate::{CycleNextInlineAssist, CyclePreviousInlineAssist};
|
||||
use crate::{RemoveAllContext, ToggleContextPicker};
|
||||
use client::ErrorExt;
|
||||
use collections::VecDeque;
|
||||
use editor::display_map::EditorMargins;
|
||||
use editor::{
|
||||
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
|
||||
GutterDimensions, MultiBuffer,
|
||||
ContextMenuOptions, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
|
||||
actions::{MoveDown, MoveUp},
|
||||
};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
|
||||
@@ -42,7 +42,7 @@ pub struct PromptEditor<T> {
|
||||
context_store: Entity<ContextStore>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
model_selector: Entity<AssistantModelSelector>,
|
||||
model_selector: Entity<AgentModelSelector>,
|
||||
edited_since_done: bool,
|
||||
prompt_history: VecDeque<String>,
|
||||
prompt_history_ix: Option<usize>,
|
||||
@@ -61,11 +61,13 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
let ui_font_size = ThemeSettings::get_global(cx).ui_font_size(cx);
|
||||
let mut buttons = Vec::new();
|
||||
|
||||
let left_gutter_width = match &self.mode {
|
||||
const RIGHT_PADDING: Pixels = px(9.);
|
||||
|
||||
let (left_gutter_width, right_padding) = match &self.mode {
|
||||
PromptEditorMode::Buffer {
|
||||
id: _,
|
||||
codegen,
|
||||
gutter_dimensions,
|
||||
editor_margins,
|
||||
} => {
|
||||
let codegen = codegen.read(cx);
|
||||
|
||||
@@ -73,13 +75,17 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
buttons.push(self.render_cycle_controls(&codegen, cx));
|
||||
}
|
||||
|
||||
let gutter_dimensions = gutter_dimensions.lock();
|
||||
let editor_margins = editor_margins.lock();
|
||||
let gutter = editor_margins.gutter;
|
||||
|
||||
gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)
|
||||
let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0);
|
||||
let right_padding = editor_margins.right + RIGHT_PADDING;
|
||||
|
||||
(left_gutter_width, right_padding)
|
||||
}
|
||||
PromptEditorMode::Terminal { .. } => {
|
||||
// Give the equivalent of the same left-padding that we're using on the right
|
||||
Pixels::from(40.0)
|
||||
(Pixels::from(40.0), Pixels::from(24.))
|
||||
}
|
||||
};
|
||||
|
||||
@@ -100,7 +106,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
.size_full()
|
||||
.pt_0p5()
|
||||
.pb(bottom_padding)
|
||||
.pr_6()
|
||||
.pr(right_padding)
|
||||
.child(
|
||||
h_flex()
|
||||
.items_start()
|
||||
@@ -284,12 +290,12 @@ impl<T: 'static> PromptEditor<T> {
|
||||
PromptEditorMode::Terminal { .. } => "Generate",
|
||||
};
|
||||
|
||||
let assistant_panel_keybinding =
|
||||
let agent_panel_keybinding =
|
||||
ui::text_for_action(&zed_actions::assistant::ToggleFocus, window, cx)
|
||||
.map(|keybinding| format!("{keybinding} to chat ― "))
|
||||
.unwrap_or_default();
|
||||
|
||||
format!("{action}… ({assistant_panel_keybinding}↓↑ for history)")
|
||||
format!("{action}… ({agent_panel_keybinding}↓↑ for history)")
|
||||
}
|
||||
|
||||
pub fn prompt(&self, cx: &App) -> String {
|
||||
@@ -806,7 +812,7 @@ pub enum PromptEditorMode {
|
||||
Buffer {
|
||||
id: InlineAssistId,
|
||||
codegen: Entity<BufferCodegen>,
|
||||
gutter_dimensions: Arc<Mutex<GutterDimensions>>,
|
||||
editor_margins: Arc<Mutex<EditorMargins>>,
|
||||
},
|
||||
Terminal {
|
||||
id: TerminalInlineAssistId,
|
||||
@@ -838,7 +844,7 @@ impl InlineAssistId {
|
||||
impl PromptEditor<BufferCodegen> {
|
||||
pub fn new_buffer(
|
||||
id: InlineAssistId,
|
||||
gutter_dimensions: Arc<Mutex<GutterDimensions>>,
|
||||
editor_margins: Arc<Mutex<EditorMargins>>,
|
||||
prompt_history: VecDeque<String>,
|
||||
prompt_buffer: Entity<MultiBuffer>,
|
||||
codegen: Entity<BufferCodegen>,
|
||||
@@ -855,7 +861,7 @@ impl PromptEditor<BufferCodegen> {
|
||||
let mode = PromptEditorMode::Buffer {
|
||||
id,
|
||||
codegen,
|
||||
gutter_dimensions,
|
||||
editor_margins,
|
||||
};
|
||||
|
||||
let prompt_editor = cx.new(|cx| {
|
||||
@@ -921,7 +927,7 @@ impl PromptEditor<BufferCodegen> {
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
model_selector: cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
AgentModelSelector::new(
|
||||
fs,
|
||||
model_selector_menu_handle,
|
||||
prompt_editor.focus_handle(cx),
|
||||
@@ -995,11 +1001,9 @@ impl PromptEditor<BufferCodegen> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gutter_dimensions(&self) -> &Arc<Mutex<GutterDimensions>> {
|
||||
pub fn editor_margins(&self) -> &Arc<Mutex<EditorMargins>> {
|
||||
match &self.mode {
|
||||
PromptEditorMode::Buffer {
|
||||
gutter_dimensions, ..
|
||||
} => gutter_dimensions,
|
||||
PromptEditorMode::Buffer { editor_margins, .. } => editor_margins,
|
||||
PromptEditorMode::Terminal { .. } => unreachable!(),
|
||||
}
|
||||
}
|
||||
@@ -1094,7 +1098,7 @@ impl PromptEditor<TerminalCodegen> {
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
model_selector: cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
AgentModelSelector::new(
|
||||
fs,
|
||||
model_selector_menu_handle.clone(),
|
||||
prompt_editor.focus_handle(cx),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
|
||||
use crate::agent_model_selector::{AgentModelSelector, ModelType};
|
||||
use crate::context::{AgentContextKey, ContextCreasesAddon, ContextLoadResult, load_context};
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use crate::ui::{
|
||||
@@ -65,7 +65,7 @@ pub struct MessageEditor {
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
model_selector: Entity<AssistantModelSelector>,
|
||||
model_selector: Entity<AgentModelSelector>,
|
||||
last_loaded_context: Option<ContextLoadResult>,
|
||||
load_context_task: Option<Shared<Task<()>>>,
|
||||
profile_selector: Entity<ProfileSelector>,
|
||||
@@ -189,7 +189,7 @@ impl MessageEditor {
|
||||
];
|
||||
|
||||
let model_selector = cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
AgentModelSelector::new(
|
||||
fs.clone(),
|
||||
model_selector_menu_handle,
|
||||
editor.focus_handle(cx),
|
||||
@@ -199,6 +199,10 @@ impl MessageEditor {
|
||||
)
|
||||
});
|
||||
|
||||
let profile_selector = cx.new(|cx| {
|
||||
ProfileSelector::new(thread.clone(), thread_store, editor.focus_handle(cx), cx)
|
||||
});
|
||||
|
||||
Self {
|
||||
editor: editor.clone(),
|
||||
project: thread.read(cx).project().clone(),
|
||||
@@ -215,8 +219,7 @@ impl MessageEditor {
|
||||
model_selector,
|
||||
edits_expanded: false,
|
||||
editor_is_expanded: false,
|
||||
profile_selector: cx
|
||||
.new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)),
|
||||
profile_selector,
|
||||
last_estimated_token_count: None,
|
||||
update_token_count_task: None,
|
||||
_subscriptions: subscriptions,
|
||||
@@ -1242,6 +1245,7 @@ impl MessageEditor {
|
||||
mode: None,
|
||||
messages: vec![request_message],
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
stop: vec![],
|
||||
temperature: AssistantSettings::temperature_for_model(&model.model, cx),
|
||||
};
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use assistant_settings::{
|
||||
AgentProfile, AgentProfileId, AssistantDockPosition, AssistantSettings, GroupedAgentProfiles,
|
||||
builtin_profiles,
|
||||
};
|
||||
use fs::Fs;
|
||||
use gpui::{Action, Entity, FocusHandle, Subscription, WeakEntity, prelude::*};
|
||||
use language_model::LanguageModelRegistry;
|
||||
use settings::{Settings as _, SettingsStore, update_settings_file};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use ui::{
|
||||
ContextMenu, ContextMenuEntry, DocumentationSide, PopoverMenu, PopoverMenuHandle, Tooltip,
|
||||
prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{ManageProfiles, ThreadStore, ToggleProfileSelector};
|
||||
use crate::{ManageProfiles, Thread, ThreadStore, ToggleProfileSelector};
|
||||
|
||||
pub struct ProfileSelector {
|
||||
profiles: GroupedAgentProfiles,
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
menu_handle: PopoverMenuHandle<ContextMenu>,
|
||||
focus_handle: FocusHandle,
|
||||
@@ -27,7 +24,7 @@ pub struct ProfileSelector {
|
||||
|
||||
impl ProfileSelector {
|
||||
pub fn new(
|
||||
fs: Arc<dyn Fs>,
|
||||
thread: Entity<Thread>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
focus_handle: FocusHandle,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -38,7 +35,7 @@ impl ProfileSelector {
|
||||
|
||||
Self {
|
||||
profiles: GroupedAgentProfiles::from_settings(AssistantSettings::get_global(cx)),
|
||||
fs,
|
||||
thread,
|
||||
thread_store,
|
||||
menu_handle: PopoverMenuHandle::default(),
|
||||
focus_handle,
|
||||
@@ -113,15 +110,15 @@ impl ProfileSelector {
|
||||
};
|
||||
|
||||
entry.handler({
|
||||
let fs = self.fs.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let profile_id = profile_id.clone();
|
||||
let profile = profile.clone();
|
||||
|
||||
let thread = self.thread.clone();
|
||||
|
||||
move |_window, cx| {
|
||||
update_settings_file::<AssistantSettings>(fs.clone(), cx, {
|
||||
let profile_id = profile_id.clone();
|
||||
move |settings, _cx| {
|
||||
settings.set_profile(profile_id.clone());
|
||||
}
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_configured_profile(Some(profile.clone()), cx);
|
||||
});
|
||||
|
||||
thread_store
|
||||
@@ -137,17 +134,28 @@ impl ProfileSelector {
|
||||
impl Render for ProfileSelector {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let settings = AssistantSettings::get_global(cx);
|
||||
let profile_id = &settings.default_profile;
|
||||
let profile = settings.profiles.get(profile_id);
|
||||
let profile = self
|
||||
.thread
|
||||
.read_with(cx, |thread, _cx| thread.configured_profile())
|
||||
.or_else(|| {
|
||||
let profile_id = &settings.default_profile;
|
||||
let profile = settings.profiles.get(profile_id);
|
||||
profile.cloned()
|
||||
});
|
||||
|
||||
let selected_profile = profile
|
||||
.map(|profile| profile.name.clone())
|
||||
.unwrap_or_else(|| "Unknown".into());
|
||||
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
let supports_tools = model_registry
|
||||
.default_model()
|
||||
.map_or(false, |default| default.model.supports_tools());
|
||||
let configured_model = self
|
||||
.thread
|
||||
.read_with(cx, |thread, _cx| thread.configured_model())
|
||||
.or_else(|| {
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
model_registry.default_model()
|
||||
});
|
||||
let supports_tools =
|
||||
configured_model.map_or(false, |default| default.model.supports_tools());
|
||||
|
||||
if supports_tools {
|
||||
let this = cx.entity().clone();
|
||||
|
||||
@@ -293,6 +293,7 @@ impl TerminalInlineAssistant {
|
||||
mode: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature,
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_settings::{AssistantSettings, CompletionMode};
|
||||
use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings, CompletionMode};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::HashMap;
|
||||
@@ -359,6 +359,7 @@ pub struct Thread {
|
||||
>,
|
||||
remaining_turns: u32,
|
||||
configured_model: Option<ConfiguredModel>,
|
||||
configured_profile: Option<AgentProfile>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -379,6 +380,9 @@ impl Thread {
|
||||
) -> Self {
|
||||
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
|
||||
let configured_model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
let assistant_settings = AssistantSettings::get_global(cx);
|
||||
let profile_id = &assistant_settings.default_profile;
|
||||
let configured_profile = assistant_settings.profiles.get(profile_id).cloned();
|
||||
|
||||
Self {
|
||||
id: ThreadId::new(),
|
||||
@@ -421,6 +425,7 @@ impl Thread {
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
configured_profile,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -468,6 +473,13 @@ impl Thread {
|
||||
.completion_mode
|
||||
.unwrap_or_else(|| AssistantSettings::get_global(cx).preferred_completion_mode);
|
||||
|
||||
let configured_profile = serialized.profile.and_then(|profile| {
|
||||
AssistantSettings::get_global(cx)
|
||||
.profiles
|
||||
.get(&profile)
|
||||
.cloned()
|
||||
});
|
||||
|
||||
Self {
|
||||
id,
|
||||
updated_at: serialized.updated_at,
|
||||
@@ -541,6 +553,7 @@ impl Thread {
|
||||
request_callback: None,
|
||||
remaining_turns: u32::MAX,
|
||||
configured_model,
|
||||
configured_profile,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,6 +609,19 @@ impl Thread {
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn configured_profile(&self) -> Option<AgentProfile> {
|
||||
self.configured_profile.clone()
|
||||
}
|
||||
|
||||
pub fn set_configured_profile(
|
||||
&mut self,
|
||||
profile: Option<AgentProfile>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.configured_profile = profile;
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
|
||||
|
||||
pub fn summary_or_default(&self) -> SharedString {
|
||||
@@ -1100,6 +1126,10 @@ impl Thread {
|
||||
provider: model.provider.id().0.to_string(),
|
||||
model: model.model.id().0.to_string(),
|
||||
}),
|
||||
profile: this
|
||||
.configured_profile
|
||||
.as_ref()
|
||||
.map(|profile| AgentProfileId(profile.name.clone().into())),
|
||||
completion_mode: Some(this.completion_mode),
|
||||
})
|
||||
})
|
||||
@@ -1153,6 +1183,7 @@ impl Thread {
|
||||
mode: None,
|
||||
messages: vec![],
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: AssistantSettings::temperature_for_model(&model, cx),
|
||||
};
|
||||
@@ -1197,6 +1228,7 @@ impl Thread {
|
||||
}));
|
||||
}
|
||||
|
||||
let mut message_ix_to_cache = None;
|
||||
for message in &self.messages {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: message.role,
|
||||
@@ -1233,19 +1265,57 @@ impl Thread {
|
||||
};
|
||||
}
|
||||
|
||||
self.tool_use
|
||||
.attach_tool_uses(message.id, &mut request_message);
|
||||
let mut cache_message = true;
|
||||
let mut tool_results_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
for (tool_use, tool_result) in self.tool_use.tool_results(message.id) {
|
||||
if let Some(tool_result) = tool_result {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
tool_results_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
output: None,
|
||||
}));
|
||||
} else {
|
||||
cache_message = false;
|
||||
log::debug!(
|
||||
"skipped tool use {:?} because it is still pending",
|
||||
tool_use
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if cache_message {
|
||||
message_ix_to_cache = Some(request.messages.len());
|
||||
}
|
||||
request.messages.push(request_message);
|
||||
|
||||
if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
|
||||
if !tool_results_message.content.is_empty() {
|
||||
if cache_message {
|
||||
message_ix_to_cache = Some(request.messages.len());
|
||||
}
|
||||
request.messages.push(tool_results_message);
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
||||
if let Some(last) = request.messages.last_mut() {
|
||||
last.cache = true;
|
||||
if let Some(message_ix_to_cache) = message_ix_to_cache {
|
||||
request.messages[message_ix_to_cache].cache = true;
|
||||
}
|
||||
|
||||
self.attached_tracked_files_state(&mut request.messages, cx);
|
||||
@@ -1272,6 +1342,7 @@ impl Thread {
|
||||
mode: None,
|
||||
messages: vec![],
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: AssistantSettings::temperature_for_model(model, cx),
|
||||
};
|
||||
@@ -1888,8 +1959,7 @@ impl Thread {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
) -> Vec<PendingToolUse> {
|
||||
self.auto_capture_telemetry(cx);
|
||||
let request = self.to_completion_request(model.clone(), cx);
|
||||
let messages = Arc::new(request.messages);
|
||||
let request = Arc::new(self.to_completion_request(model.clone(), cx));
|
||||
let pending_tool_uses = self
|
||||
.tool_use
|
||||
.pending_tool_uses()
|
||||
@@ -1907,7 +1977,7 @@ impl Thread {
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
tool_use.input.clone(),
|
||||
messages.clone(),
|
||||
request.clone(),
|
||||
tool,
|
||||
);
|
||||
cx.emit(ThreadEvent::ToolConfirmationNeeded);
|
||||
@@ -1916,7 +1986,7 @@ impl Thread {
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
tool_use.input.clone(),
|
||||
&messages,
|
||||
request.clone(),
|
||||
tool,
|
||||
model.clone(),
|
||||
window,
|
||||
@@ -2011,21 +2081,14 @@ impl Thread {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: impl Into<SharedString>,
|
||||
input: serde_json::Value,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
request: Arc<LanguageModelRequest>,
|
||||
tool: Arc<dyn Tool>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
let task = self.spawn_tool_use(
|
||||
tool_use_id.clone(),
|
||||
messages,
|
||||
input,
|
||||
tool,
|
||||
model,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
let task =
|
||||
self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx);
|
||||
self.tool_use
|
||||
.run_pending_tool(tool_use_id, ui_text.into(), task);
|
||||
}
|
||||
@@ -2033,7 +2096,7 @@ impl Thread {
|
||||
fn spawn_tool_use(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
request: Arc<LanguageModelRequest>,
|
||||
input: serde_json::Value,
|
||||
tool: Arc<dyn Tool>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
@@ -2047,7 +2110,7 @@ impl Thread {
|
||||
} else {
|
||||
tool.run(
|
||||
input,
|
||||
messages,
|
||||
request,
|
||||
self.project.clone(),
|
||||
self.action_log.clone(),
|
||||
model,
|
||||
|
||||
@@ -19,10 +19,10 @@ use util::ResultExt;
|
||||
|
||||
use crate::history_store::{HistoryEntry, HistoryStore};
|
||||
use crate::thread_store::SerializedThreadMetadata;
|
||||
use crate::{AssistantPanel, RemoveSelectedThread};
|
||||
use crate::{AgentPanel, RemoveSelectedThread};
|
||||
|
||||
pub struct ThreadHistory {
|
||||
assistant_panel: WeakEntity<AssistantPanel>,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
scroll_handle: UniformListScrollHandle,
|
||||
selected_index: usize,
|
||||
@@ -69,7 +69,7 @@ impl HistoryListItem {
|
||||
|
||||
impl ThreadHistory {
|
||||
pub(crate) fn new(
|
||||
assistant_panel: WeakEntity<AssistantPanel>,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -96,7 +96,7 @@ impl ThreadHistory {
|
||||
let scrollbar_state = ScrollbarState::new(scroll_handle.clone());
|
||||
|
||||
let mut this = Self {
|
||||
assistant_panel,
|
||||
agent_panel,
|
||||
history_store,
|
||||
scroll_handle,
|
||||
selected_index: 0,
|
||||
@@ -380,14 +380,12 @@ impl ThreadHistory {
|
||||
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if let Some(entry) = self.get_match(self.selected_index) {
|
||||
let task_result = match entry {
|
||||
HistoryEntry::Thread(thread) => self.assistant_panel.update(cx, move |this, cx| {
|
||||
HistoryEntry::Thread(thread) => self.agent_panel.update(cx, move |this, cx| {
|
||||
this.open_thread_by_id(&thread.id, window, cx)
|
||||
}),
|
||||
HistoryEntry::Context(context) => {
|
||||
self.assistant_panel.update(cx, move |this, cx| {
|
||||
this.open_saved_prompt_editor(context.path.clone(), window, cx)
|
||||
})
|
||||
}
|
||||
HistoryEntry::Context(context) => self.agent_panel.update(cx, move |this, cx| {
|
||||
this.open_saved_prompt_editor(context.path.clone(), window, cx)
|
||||
}),
|
||||
};
|
||||
|
||||
if let Some(task) = task_result.log_err() {
|
||||
@@ -407,10 +405,10 @@ impl ThreadHistory {
|
||||
if let Some(entry) = self.get_match(self.selected_index) {
|
||||
let task_result = match entry {
|
||||
HistoryEntry::Thread(thread) => self
|
||||
.assistant_panel
|
||||
.agent_panel
|
||||
.update(cx, |this, cx| this.delete_thread(&thread.id, cx)),
|
||||
HistoryEntry::Context(context) => self
|
||||
.assistant_panel
|
||||
.agent_panel
|
||||
.update(cx, |this, cx| this.delete_context(context.path.clone(), cx)),
|
||||
};
|
||||
|
||||
@@ -506,7 +504,7 @@ impl ThreadHistory {
|
||||
match entry {
|
||||
HistoryEntry::Thread(thread) => PastThread::new(
|
||||
thread.clone(),
|
||||
self.assistant_panel.clone(),
|
||||
self.agent_panel.clone(),
|
||||
is_active,
|
||||
highlight_positions,
|
||||
format,
|
||||
@@ -514,7 +512,7 @@ impl ThreadHistory {
|
||||
.into_any_element(),
|
||||
HistoryEntry::Context(context) => PastContext::new(
|
||||
context.clone(),
|
||||
self.assistant_panel.clone(),
|
||||
self.agent_panel.clone(),
|
||||
is_active,
|
||||
highlight_positions,
|
||||
format,
|
||||
@@ -605,7 +603,7 @@ impl Render for ThreadHistory {
|
||||
#[derive(IntoElement)]
|
||||
pub struct PastThread {
|
||||
thread: SerializedThreadMetadata,
|
||||
assistant_panel: WeakEntity<AssistantPanel>,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
@@ -614,14 +612,14 @@ pub struct PastThread {
|
||||
impl PastThread {
|
||||
pub fn new(
|
||||
thread: SerializedThreadMetadata,
|
||||
assistant_panel: WeakEntity<AssistantPanel>,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
) -> Self {
|
||||
Self {
|
||||
thread,
|
||||
assistant_panel,
|
||||
agent_panel,
|
||||
selected,
|
||||
highlight_positions,
|
||||
timestamp_format,
|
||||
@@ -634,7 +632,7 @@ impl RenderOnce for PastThread {
|
||||
let summary = self.thread.summary;
|
||||
|
||||
let thread_timestamp = self.timestamp_format.format_timestamp(
|
||||
&self.assistant_panel,
|
||||
&self.agent_panel,
|
||||
self.thread.updated_at.timestamp(),
|
||||
cx,
|
||||
);
|
||||
@@ -667,10 +665,10 @@ impl RenderOnce for PastThread {
|
||||
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
|
||||
})
|
||||
.on_click({
|
||||
let assistant_panel = self.assistant_panel.clone();
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let id = self.thread.id.clone();
|
||||
move |_event, _window, cx| {
|
||||
assistant_panel
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.delete_thread(&id, cx).detach_and_log_err(cx);
|
||||
})
|
||||
@@ -680,10 +678,10 @@ impl RenderOnce for PastThread {
|
||||
),
|
||||
)
|
||||
.on_click({
|
||||
let assistant_panel = self.assistant_panel.clone();
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let id = self.thread.id.clone();
|
||||
move |_event, window, cx| {
|
||||
assistant_panel
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.open_thread_by_id(&id, window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -697,7 +695,7 @@ impl RenderOnce for PastThread {
|
||||
#[derive(IntoElement)]
|
||||
pub struct PastContext {
|
||||
context: SavedContextMetadata,
|
||||
assistant_panel: WeakEntity<AssistantPanel>,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
@@ -706,14 +704,14 @@ pub struct PastContext {
|
||||
impl PastContext {
|
||||
pub fn new(
|
||||
context: SavedContextMetadata,
|
||||
assistant_panel: WeakEntity<AssistantPanel>,
|
||||
agent_panel: WeakEntity<AgentPanel>,
|
||||
selected: bool,
|
||||
highlight_positions: Vec<usize>,
|
||||
timestamp_format: EntryTimeFormat,
|
||||
) -> Self {
|
||||
Self {
|
||||
context,
|
||||
assistant_panel,
|
||||
agent_panel,
|
||||
selected,
|
||||
highlight_positions,
|
||||
timestamp_format,
|
||||
@@ -725,7 +723,7 @@ impl RenderOnce for PastContext {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let summary = self.context.title;
|
||||
let context_timestamp = self.timestamp_format.format_timestamp(
|
||||
&self.assistant_panel,
|
||||
&self.agent_panel,
|
||||
self.context.mtime.timestamp(),
|
||||
cx,
|
||||
);
|
||||
@@ -760,10 +758,10 @@ impl RenderOnce for PastContext {
|
||||
Tooltip::for_action("Delete", &RemoveSelectedThread, window, cx)
|
||||
})
|
||||
.on_click({
|
||||
let assistant_panel = self.assistant_panel.clone();
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let path = self.context.path.clone();
|
||||
move |_event, _window, cx| {
|
||||
assistant_panel
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.delete_context(path.clone(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -774,10 +772,10 @@ impl RenderOnce for PastContext {
|
||||
),
|
||||
)
|
||||
.on_click({
|
||||
let assistant_panel = self.assistant_panel.clone();
|
||||
let agent_panel = self.agent_panel.clone();
|
||||
let path = self.context.path.clone();
|
||||
move |_event, window, cx| {
|
||||
assistant_panel
|
||||
agent_panel
|
||||
.update(cx, |this, cx| {
|
||||
this.open_saved_prompt_editor(path.clone(), window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -797,12 +795,12 @@ pub enum EntryTimeFormat {
|
||||
impl EntryTimeFormat {
|
||||
fn format_timestamp(
|
||||
&self,
|
||||
assistant_panel: &WeakEntity<AssistantPanel>,
|
||||
agent_panel: &WeakEntity<AgentPanel>,
|
||||
timestamp: i64,
|
||||
cx: &App,
|
||||
) -> String {
|
||||
let timestamp = OffsetDateTime::from_unix_timestamp(timestamp).unwrap();
|
||||
let timezone = assistant_panel
|
||||
let timezone = agent_panel
|
||||
.read_with(cx, |this, _cx| this.local_timezone())
|
||||
.unwrap_or(UtcOffset::UTC);
|
||||
|
||||
|
||||
@@ -657,6 +657,8 @@ pub struct SerializedThread {
|
||||
pub model: Option<SerializedLanguageModel>,
|
||||
#[serde(default)]
|
||||
pub completion_mode: Option<CompletionMode>,
|
||||
#[serde(default)]
|
||||
pub profile: Option<AgentProfileId>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
@@ -802,6 +804,7 @@ impl LegacySerializedThread {
|
||||
exceeded_window_error: None,
|
||||
model: None,
|
||||
completion_mode: None,
|
||||
profile: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ use futures::FutureExt as _;
|
||||
use futures::future::Shared;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
|
||||
ConfiguredModel, LanguageModel, LanguageModelRequest, LanguageModelToolResult,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, Role,
|
||||
};
|
||||
use project::Project;
|
||||
use ui::{IconName, Window};
|
||||
@@ -354,7 +354,7 @@ impl ToolUseState {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: impl Into<Arc<str>>,
|
||||
input: serde_json::Value,
|
||||
messages: Arc<Vec<LanguageModelRequestMessage>>,
|
||||
request: Arc<LanguageModelRequest>,
|
||||
tool: Arc<dyn Tool>,
|
||||
) {
|
||||
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
|
||||
@@ -363,7 +363,7 @@ impl ToolUseState {
|
||||
let confirmation = Confirmation {
|
||||
tool_use_id,
|
||||
input,
|
||||
messages,
|
||||
request,
|
||||
tool,
|
||||
ui_text,
|
||||
};
|
||||
@@ -449,72 +449,20 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attach_tool_uses(
|
||||
&self,
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
|
||||
for tool_use in tool_uses {
|
||||
if self.tool_results.contains_key(&tool_use.id) {
|
||||
// Do not send tool uses until they are completed
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolUse(tool_use.clone()));
|
||||
} else {
|
||||
log::debug!(
|
||||
"skipped tool use {:?} because it is still pending",
|
||||
tool_use
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.contains_key(&assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_results_message(
|
||||
pub fn tool_results(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Option<LanguageModelRequestMessage> {
|
||||
let tool_uses = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)?;
|
||||
|
||||
if tool_uses.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
for tool_use in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
output: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Some(request_message)
|
||||
) -> impl Iterator<Item = (&LanguageModelToolUse, Option<&LanguageModelToolResult>)> {
|
||||
self.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -535,7 +483,7 @@ pub struct Confirmation {
|
||||
pub tool_use_id: LanguageModelToolUseId,
|
||||
pub input: serde_json::Value,
|
||||
pub ui_text: Arc<str>,
|
||||
pub messages: Arc<Vec<LanguageModelRequestMessage>>,
|
||||
pub request: Arc<LanguageModelRequest>,
|
||||
pub tool: Arc<dyn Tool>,
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ use gpui::{
|
||||
use ui::{TintColor, Vector, VectorName, prelude::*};
|
||||
use workspace::{ModalView, Workspace};
|
||||
|
||||
use crate::assistant_panel::AssistantPanel;
|
||||
use crate::agent_panel::AgentPanel;
|
||||
|
||||
macro_rules! agent_onboarding_event {
|
||||
($name:expr) => {
|
||||
@@ -31,7 +31,7 @@ impl AgentOnboardingModal {
|
||||
|
||||
fn open_panel(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.workspace.update(cx, |workspace, cx| {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
});
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
|
||||
@@ -578,6 +578,7 @@ pub enum ToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
Tool { name: String },
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
[package]
|
||||
name = "assistant"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/assistant.rs"
|
||||
doctest = false
|
||||
|
||||
[features]
|
||||
test-support = [
|
||||
"editor/test-support",
|
||||
"language/test-support",
|
||||
"project/test-support",
|
||||
"text/test-support",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_context_editor.workspace = true
|
||||
assistant_settings.workspace = true
|
||||
assistant_slash_command.workspace = true
|
||||
assistant_slash_commands.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
async-watch.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
indexed_docs.workspace = true
|
||||
indoc.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
language_model_selector.workspace = true
|
||||
log.workspace = true
|
||||
lsp.workspace = true
|
||||
menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
search.workspace = true
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
streaming_diff.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
terminal.workspace = true
|
||||
terminal_view.workspace = true
|
||||
text.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
ctor.workspace = true
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
languages = { workspace = true, features = ["test-support"] }
|
||||
log.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
terminal_view = { workspace = true, features = ["test-support"] }
|
||||
text = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-md.workspace = true
|
||||
unindent.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-GPL
|
||||
@@ -1,199 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::HashMap;
|
||||
use gpui::{AnyView, App, EventEmitter, FocusHandle, Focusable, Subscription, canvas};
|
||||
use language_model::{LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry};
|
||||
use ui::{ElevationIndex, prelude::*};
|
||||
use workspace::Item;
|
||||
|
||||
pub struct ConfigurationView {
|
||||
focus_handle: FocusHandle,
|
||||
configuration_views: HashMap<LanguageModelProviderId, AnyView>,
|
||||
_registry_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let focus_handle = cx.focus_handle();
|
||||
|
||||
let registry_subscription = cx.subscribe_in(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
window,
|
||||
|this, _, event: &language_model::Event, window, cx| match event {
|
||||
language_model::Event::AddedProvider(provider_id) => {
|
||||
let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
|
||||
if let Some(provider) = provider {
|
||||
this.add_configuration_view(&provider, window, cx);
|
||||
}
|
||||
}
|
||||
language_model::Event::RemovedProvider(provider_id) => {
|
||||
this.remove_configuration_view(provider_id);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
);
|
||||
|
||||
let mut this = Self {
|
||||
focus_handle,
|
||||
configuration_views: HashMap::default(),
|
||||
_registry_subscription: registry_subscription,
|
||||
};
|
||||
this.build_configuration_views(window, cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn build_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
for provider in providers {
|
||||
self.add_configuration_view(&provider, window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_configuration_view(&mut self, provider_id: &LanguageModelProviderId) {
|
||||
self.configuration_views.remove(provider_id);
|
||||
}
|
||||
|
||||
fn add_configuration_view(
|
||||
&mut self,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let configuration_view = provider.configuration_view(window, cx);
|
||||
self.configuration_views
|
||||
.insert(provider.id(), configuration_view);
|
||||
}
|
||||
|
||||
fn render_provider_view(
|
||||
&mut self,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Div {
|
||||
let provider_id = provider.id().0.clone();
|
||||
let provider_name = provider.name().0.clone();
|
||||
let configuration_view = self.configuration_views.get(&provider.id()).cloned();
|
||||
|
||||
let open_new_context = cx.listener({
|
||||
let provider = provider.clone();
|
||||
move |_, _, _window, cx| {
|
||||
cx.emit(ConfigurationViewEvent::NewProviderContextEditor(
|
||||
provider.clone(),
|
||||
))
|
||||
}
|
||||
});
|
||||
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_between()
|
||||
.child(Headline::new(provider_name.clone()).size(HeadlineSize::Small))
|
||||
.when(provider.is_authenticated(cx), move |this| {
|
||||
this.child(
|
||||
h_flex().justify_end().child(
|
||||
Button::new(
|
||||
SharedString::from(format!("new-context-{provider_id}")),
|
||||
"Open New Chat",
|
||||
)
|
||||
.icon_position(IconPosition::Start)
|
||||
.icon(IconName::Plus)
|
||||
.style(ButtonStyle::Filled)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.on_click(open_new_context),
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.p(DynamicSpacing::Base08.rems(cx))
|
||||
.bg(cx.theme().colors().surface_background)
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.rounded_sm()
|
||||
.when(configuration_view.is_none(), |this| {
|
||||
this.child(div().child(Label::new(format!(
|
||||
"No configuration view for {}",
|
||||
provider_name
|
||||
))))
|
||||
})
|
||||
.when_some(configuration_view, |this, configuration_view| {
|
||||
this.child(configuration_view)
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
let provider_views = providers
|
||||
.into_iter()
|
||||
.map(|provider| self.render_provider_view(&provider, cx))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut element = v_flex()
|
||||
.id("assistant-configuration-view")
|
||||
.track_focus(&self.focus_handle(cx))
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.size_full()
|
||||
.overflow_y_scroll()
|
||||
.child(
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.gap_1()
|
||||
.child(Headline::new("Configure your Assistant").size(HeadlineSize::Medium))
|
||||
.child(
|
||||
Label::new(
|
||||
"At least one LLM provider must be configured to use the Assistant.",
|
||||
)
|
||||
.color(Color::Muted),
|
||||
),
|
||||
)
|
||||
.child(
|
||||
v_flex()
|
||||
.p(DynamicSpacing::Base16.rems(cx))
|
||||
.mt_1()
|
||||
.gap_6()
|
||||
.flex_1()
|
||||
.children(provider_views),
|
||||
)
|
||||
.into_any();
|
||||
|
||||
// We use a canvas here to get scrolling to work in the ConfigurationView. It's a workaround
|
||||
// because we couldn't the element to take up the size of the parent.
|
||||
canvas(
|
||||
move |bounds, window, cx| {
|
||||
element.prepaint_as_root(bounds.origin, bounds.size.into(), window, cx);
|
||||
element
|
||||
},
|
||||
|_, mut element, window, cx| {
|
||||
element.paint(window, cx);
|
||||
},
|
||||
)
|
||||
.flex_1()
|
||||
.w_full()
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ConfigurationViewEvent {
|
||||
NewProviderContextEditor(Arc<dyn LanguageModelProvider>),
|
||||
}
|
||||
|
||||
impl EventEmitter<ConfigurationViewEvent> for ConfigurationView {}
|
||||
|
||||
impl Focusable for ConfigurationView {
|
||||
fn focus_handle(&self, _: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Item for ConfigurationView {
|
||||
type Event = ConfigurationViewEvent;
|
||||
|
||||
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
|
||||
"Configuration".into()
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -2585,6 +2585,7 @@ impl AssistantContext {
|
||||
mode: None,
|
||||
messages: Vec::new(),
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: model
|
||||
.and_then(|model| AssistantSettings::temperature_for_model(model, cx)),
|
||||
|
||||
@@ -137,7 +137,7 @@ pub enum ThoughtProcessStatus {
|
||||
Completed,
|
||||
}
|
||||
|
||||
pub trait AssistantPanelDelegate {
|
||||
pub trait AgentPanelDelegate {
|
||||
fn active_context_editor(
|
||||
&self,
|
||||
workspace: &mut Workspace,
|
||||
@@ -171,7 +171,7 @@ pub trait AssistantPanelDelegate {
|
||||
);
|
||||
}
|
||||
|
||||
impl dyn AssistantPanelDelegate {
|
||||
impl dyn AgentPanelDelegate {
|
||||
/// Returns the global [`AssistantPanelDelegate`], if it exists.
|
||||
pub fn try_global(cx: &App) -> Option<Arc<Self>> {
|
||||
cx.try_global::<GlobalAssistantPanelDelegate>()
|
||||
@@ -184,7 +184,7 @@ impl dyn AssistantPanelDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
struct GlobalAssistantPanelDelegate(Arc<dyn AssistantPanelDelegate>);
|
||||
struct GlobalAssistantPanelDelegate(Arc<dyn AgentPanelDelegate>);
|
||||
|
||||
impl Global for GlobalAssistantPanelDelegate {}
|
||||
|
||||
@@ -242,9 +242,9 @@ impl ContextEditor {
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor =
|
||||
Editor::for_buffer(context.read(cx).buffer().clone(), None, window, cx);
|
||||
editor.disable_scrollbars_and_minimap(cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
|
||||
editor.set_show_line_numbers(false, cx);
|
||||
editor.set_show_scrollbars(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_runnables(false, cx);
|
||||
@@ -367,10 +367,16 @@ impl ContextEditor {
|
||||
}
|
||||
|
||||
fn assist(&mut self, _: &Assist, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.sending_disabled(cx) {
|
||||
return;
|
||||
}
|
||||
self.send_to_model(RequestType::Chat, window, cx);
|
||||
}
|
||||
|
||||
fn edit(&mut self, _: &Edit, window: &mut Window, cx: &mut Context<Self>) {
|
||||
if self.sending_disabled(cx) {
|
||||
return;
|
||||
}
|
||||
self.send_to_model(RequestType::SuggestEdits, window, cx);
|
||||
}
|
||||
|
||||
@@ -942,7 +948,7 @@ impl ContextEditor {
|
||||
let patch_range = range.clone();
|
||||
move |cx: &mut BlockContext| {
|
||||
let max_width = cx.max_width;
|
||||
let gutter_width = cx.gutter_dimensions.full_width();
|
||||
let gutter_width = cx.margins.gutter.full_width();
|
||||
let block_id = cx.block_id;
|
||||
let selected = cx.selected;
|
||||
let window = &mut cx.window;
|
||||
@@ -1488,7 +1494,7 @@ impl ContextEditor {
|
||||
|
||||
h_flex()
|
||||
.id(("message_header", message_id.as_u64()))
|
||||
.pl(cx.gutter_dimensions.full_width())
|
||||
.pl(cx.margins.gutter.full_width())
|
||||
.h_11()
|
||||
.w_full()
|
||||
.relative()
|
||||
@@ -1583,6 +1589,7 @@ impl ContextEditor {
|
||||
),
|
||||
priority: usize::MAX,
|
||||
render: render_block(MessageMetadata::from(message)),
|
||||
render_in_minimap: false,
|
||||
};
|
||||
let mut new_blocks = vec![];
|
||||
let mut block_index_to_message = vec![];
|
||||
@@ -1665,11 +1672,11 @@ impl ContextEditor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let Some(assistant_panel_delegate) = <dyn AssistantPanelDelegate>::try_global(cx) else {
|
||||
let Some(agent_panel_delegate) = <dyn AgentPanelDelegate>::try_global(cx) else {
|
||||
return;
|
||||
};
|
||||
let Some(context_editor_view) =
|
||||
assistant_panel_delegate.active_context_editor(workspace, window, cx)
|
||||
agent_panel_delegate.active_context_editor(workspace, window, cx)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
@@ -1695,9 +1702,9 @@ impl ContextEditor {
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let result = maybe!({
|
||||
let assistant_panel_delegate = <dyn AssistantPanelDelegate>::try_global(cx)?;
|
||||
let agent_panel_delegate = <dyn AgentPanelDelegate>::try_global(cx)?;
|
||||
let context_editor_view =
|
||||
assistant_panel_delegate.active_context_editor(workspace, window, cx)?;
|
||||
agent_panel_delegate.active_context_editor(workspace, window, cx)?;
|
||||
Self::get_selection_or_code_block(&context_editor_view, cx)
|
||||
});
|
||||
let Some((text, is_code_block)) = result else {
|
||||
@@ -1730,11 +1737,11 @@ impl ContextEditor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let Some(assistant_panel_delegate) = <dyn AssistantPanelDelegate>::try_global(cx) else {
|
||||
let Some(agent_panel_delegate) = <dyn AgentPanelDelegate>::try_global(cx) else {
|
||||
return;
|
||||
};
|
||||
let Some(context_editor_view) =
|
||||
assistant_panel_delegate.active_context_editor(workspace, window, cx)
|
||||
agent_panel_delegate.active_context_editor(workspace, window, cx)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
@@ -1820,7 +1827,7 @@ impl ContextEditor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Workspace>,
|
||||
) {
|
||||
let Some(assistant_panel_delegate) = <dyn AssistantPanelDelegate>::try_global(cx) else {
|
||||
let Some(agent_panel_delegate) = <dyn AgentPanelDelegate>::try_global(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -1851,7 +1858,7 @@ impl ContextEditor {
|
||||
return;
|
||||
}
|
||||
|
||||
assistant_panel_delegate.quote_selection(workspace, selections, buffer, window, cx);
|
||||
agent_panel_delegate.quote_selection(workspace, selections, buffer, window, cx);
|
||||
}
|
||||
|
||||
pub fn quote_ranges(
|
||||
@@ -2157,12 +2164,12 @@ impl ContextEditor {
|
||||
let image_size = size_for_image(
|
||||
&image,
|
||||
size(
|
||||
cx.max_width - cx.gutter_dimensions.full_width(),
|
||||
cx.max_width - cx.margins.gutter.full_width(),
|
||||
MAX_HEIGHT_IN_LINES as f32 * cx.line_height,
|
||||
),
|
||||
);
|
||||
h_flex()
|
||||
.pl(cx.gutter_dimensions.full_width())
|
||||
.pl(cx.margins.gutter.full_width())
|
||||
.child(
|
||||
img(image.clone())
|
||||
.object_fit(gpui::ObjectFit::ScaleDown)
|
||||
@@ -2172,6 +2179,7 @@ impl ContextEditor {
|
||||
.into_any_element()
|
||||
}),
|
||||
priority: 0,
|
||||
render_in_minimap: false,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
@@ -2436,17 +2444,8 @@ impl ContextEditor {
|
||||
None => (ButtonStyle::Filled, None),
|
||||
};
|
||||
|
||||
let model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
let has_configuration_error = configuration_error(cx).is_some();
|
||||
let needs_to_accept_terms = self.show_accept_terms
|
||||
&& model
|
||||
.as_ref()
|
||||
.map_or(false, |model| model.provider.must_accept_terms(cx));
|
||||
let disabled = has_configuration_error || needs_to_accept_terms;
|
||||
|
||||
ButtonLike::new("send_button")
|
||||
.disabled(disabled)
|
||||
.disabled(self.sending_disabled(cx))
|
||||
.style(style)
|
||||
.when_some(tooltip, |button, tooltip| {
|
||||
button.tooltip(move |_, _| tooltip.clone())
|
||||
@@ -2468,6 +2467,20 @@ impl ContextEditor {
|
||||
})
|
||||
}
|
||||
|
||||
/// Whether or not we should allow messages to be sent.
|
||||
/// Will return false if the selected provided has a configuration error or
|
||||
/// if the user has not accepted the terms of service for this provider.
|
||||
fn sending_disabled(&self, cx: &mut Context<'_, ContextEditor>) -> bool {
|
||||
let model = LanguageModelRegistry::read_global(cx).default_model();
|
||||
|
||||
let has_configuration_error = configuration_error(cx).is_some();
|
||||
let needs_to_accept_terms = self.show_accept_terms
|
||||
&& model
|
||||
.as_ref()
|
||||
.map_or(false, |model| model.provider.must_accept_terms(cx));
|
||||
has_configuration_error || needs_to_accept_terms
|
||||
}
|
||||
|
||||
fn render_edit_button(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let focus_handle = self.focus_handle(cx).clone();
|
||||
|
||||
@@ -2495,19 +2508,8 @@ impl ContextEditor {
|
||||
None => (ButtonStyle::Filled, None),
|
||||
};
|
||||
|
||||
let provider = LanguageModelRegistry::read_global(cx)
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
|
||||
let has_configuration_error = configuration_error(cx).is_some();
|
||||
let needs_to_accept_terms = self.show_accept_terms
|
||||
&& provider
|
||||
.as_ref()
|
||||
.map_or(false, |provider| provider.must_accept_terms(cx));
|
||||
let disabled = has_configuration_error || needs_to_accept_terms;
|
||||
|
||||
ButtonLike::new("edit_button")
|
||||
.disabled(disabled)
|
||||
.disabled(self.sending_disabled(cx))
|
||||
.style(style)
|
||||
.when_some(tooltip, |button, tooltip| {
|
||||
button.tooltip(move |_, _| tooltip.clone())
|
||||
@@ -3359,10 +3361,10 @@ impl FollowableItem for ContextEditor {
|
||||
let editor_state = state.editor?;
|
||||
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let assistant_panel_delegate = <dyn AssistantPanelDelegate>::try_global(cx)?;
|
||||
let agent_panel_delegate = <dyn AgentPanelDelegate>::try_global(cx)?;
|
||||
|
||||
let context_editor_task = workspace.update(cx, |workspace, cx| {
|
||||
assistant_panel_delegate.open_remote_context(workspace, context_id, window, cx)
|
||||
agent_panel_delegate.open_remote_context(workspace, context_id, window, cx)
|
||||
});
|
||||
|
||||
Some(window.spawn(cx, async move |cx| {
|
||||
|
||||
@@ -8,7 +8,7 @@ use ui::{Avatar, ListItem, ListItemSpacing, prelude::*};
|
||||
use workspace::{Item, Workspace};
|
||||
|
||||
use crate::{
|
||||
AssistantPanelDelegate, ContextStore, DEFAULT_TAB_TITLE, RemoteContextMetadata,
|
||||
AgentPanelDelegate, ContextStore, DEFAULT_TAB_TITLE, RemoteContextMetadata,
|
||||
SavedContextMetadata,
|
||||
};
|
||||
|
||||
@@ -70,19 +70,19 @@ impl ContextHistory {
|
||||
) {
|
||||
let SavedContextPickerEvent::Confirmed(context) = event;
|
||||
|
||||
let Some(assistant_panel_delegate) = <dyn AssistantPanelDelegate>::try_global(cx) else {
|
||||
let Some(agent_panel_delegate) = <dyn AgentPanelDelegate>::try_global(cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.workspace
|
||||
.update(cx, |workspace, cx| match context {
|
||||
ContextMetadata::Remote(metadata) => {
|
||||
assistant_panel_delegate
|
||||
agent_panel_delegate
|
||||
.open_remote_context(workspace, metadata.id.clone(), window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
ContextMetadata::Saved(metadata) => {
|
||||
assistant_panel_delegate
|
||||
agent_panel_delegate
|
||||
.open_saved_context(workspace, metadata.path.clone(), window, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task, WeakEntity};
|
||||
use icons::IconName;
|
||||
use language_model::LanguageModel;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelRequest;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use project::Project;
|
||||
use workspace::Workspace;
|
||||
@@ -206,7 +206,7 @@ pub trait Tool: 'static + Send + Sync {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -35,25 +35,17 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
const KEYS_TO_REMOVE: [&str; 4] = [
|
||||
const KEYS_TO_REMOVE: [&str; 5] = [
|
||||
"format",
|
||||
"additionalProperties",
|
||||
"exclusiveMinimum",
|
||||
"exclusiveMaximum",
|
||||
"optional",
|
||||
];
|
||||
for key in KEYS_TO_REMOVE {
|
||||
obj.remove(key);
|
||||
}
|
||||
|
||||
if let Some(default) = obj.get("default") {
|
||||
let is_null = default.is_null();
|
||||
// Default is not supported, so we need to remove it
|
||||
obj.remove("default");
|
||||
if is_null {
|
||||
obj.insert("nullable".to_string(), Value::Bool(true));
|
||||
}
|
||||
}
|
||||
|
||||
// If a type is not specified for an input parameter, add a default type
|
||||
if matches!(obj.get("description"), Some(Value::String(_)))
|
||||
&& !obj.contains_key("type")
|
||||
@@ -92,26 +84,6 @@ mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_transform_default_null_to_nullable() {
|
||||
let mut json = json!({
|
||||
"description": "A test field",
|
||||
"type": "string",
|
||||
"default": null
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"description": "A test field",
|
||||
"type": "string",
|
||||
"nullable": true
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transform_adds_type_when_missing() {
|
||||
let mut json = json!({
|
||||
@@ -157,7 +129,8 @@ mod tests {
|
||||
"format": "uint32",
|
||||
"exclusiveMinimum": 0,
|
||||
"exclusiveMaximum": 100,
|
||||
"additionalProperties": false
|
||||
"additionalProperties": false,
|
||||
"optional": true
|
||||
});
|
||||
|
||||
adapt_to_json_schema_subset(&mut json).unwrap();
|
||||
|
||||
@@ -3,8 +3,8 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage};
|
||||
use language_model::LanguageModel;
|
||||
use language_model::{LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -74,7 +74,7 @@ impl Tool for CopyPathTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -61,7 +61,7 @@ impl Tool for CreateDirectoryTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::{SinkExt, StreamExt, channel::mpsc};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, ProjectPath};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -59,7 +59,7 @@ impl Tool for DeletePathTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -3,7 +3,7 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -79,7 +79,7 @@ impl Tool for DiagnosticsTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -17,7 +17,7 @@ use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
|
||||
use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
MessageContent, Role,
|
||||
LanguageModelToolChoice, MessageContent, Role,
|
||||
};
|
||||
use project::{AgentLocation, Project};
|
||||
use serde::Serialize;
|
||||
@@ -83,7 +83,7 @@ impl EditAgent {
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_description: String,
|
||||
previous_messages: Vec<LanguageModelRequestMessage>,
|
||||
conversation: &LanguageModelRequest,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (
|
||||
Task<Result<EditAgentOutput>>,
|
||||
@@ -91,6 +91,7 @@ impl EditAgent {
|
||||
) {
|
||||
let this = self.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded();
|
||||
let conversation = conversation.clone();
|
||||
let output = cx.spawn(async move |cx| {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
@@ -99,7 +100,7 @@ impl EditAgent {
|
||||
edit_description,
|
||||
}
|
||||
.render(&this.templates)?;
|
||||
let new_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||
let new_chunks = this.request(conversation, prompt, cx).await?;
|
||||
|
||||
let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
|
||||
while let Some(event) = inner_events.next().await {
|
||||
@@ -194,7 +195,7 @@ impl EditAgent {
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edit_description: String,
|
||||
previous_messages: Vec<LanguageModelRequestMessage>,
|
||||
conversation: &LanguageModelRequest,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (
|
||||
Task<Result<EditAgentOutput>>,
|
||||
@@ -214,6 +215,7 @@ impl EditAgent {
|
||||
|
||||
let this = self.clone();
|
||||
let (events_tx, events_rx) = mpsc::unbounded();
|
||||
let conversation = conversation.clone();
|
||||
let output = cx.spawn(async move |cx| {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
@@ -222,7 +224,7 @@ impl EditAgent {
|
||||
edit_description,
|
||||
}
|
||||
.render(&this.templates)?;
|
||||
let edit_chunks = this.request(previous_messages, prompt, cx).await?;
|
||||
let edit_chunks = this.request(conversation, prompt, cx).await?;
|
||||
|
||||
let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
|
||||
while let Some(event) = inner_events.next().await {
|
||||
@@ -512,32 +514,67 @@ impl EditAgent {
|
||||
|
||||
async fn request(
|
||||
&self,
|
||||
mut messages: Vec<LanguageModelRequestMessage>,
|
||||
mut conversation: LanguageModelRequest,
|
||||
prompt: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
|
||||
let mut message_content = Vec::new();
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
let mut messages_iter = conversation.messages.iter_mut();
|
||||
if let Some(last_message) = messages_iter.next_back() {
|
||||
if last_message.role == Role::Assistant {
|
||||
let old_content_len = last_message.content.len();
|
||||
last_message
|
||||
.content
|
||||
.retain(|content| !matches!(content, MessageContent::ToolUse(_)));
|
||||
let new_content_len = last_message.content.len();
|
||||
|
||||
// We just removed pending tool uses from the content of the
|
||||
// last message, so it doesn't make sense to cache it anymore
|
||||
// (e.g., the message will look very different on the next
|
||||
// request). Thus, we move the flag to the message prior to it,
|
||||
// as it will still be a valid prefix of the conversation.
|
||||
if old_content_len != new_content_len && last_message.cache {
|
||||
if let Some(prev_message) = messages_iter.next_back() {
|
||||
last_message.cache = false;
|
||||
prev_message.cache = true;
|
||||
}
|
||||
}
|
||||
|
||||
if last_message.content.is_empty() {
|
||||
messages.pop();
|
||||
conversation.messages.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
message_content.push(MessageContent::Text(prompt));
|
||||
messages.push(LanguageModelRequestMessage {
|
||||
|
||||
conversation.messages.push(LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: message_content,
|
||||
content: vec![MessageContent::Text(prompt)],
|
||||
cache: false,
|
||||
});
|
||||
|
||||
// Include tools in the request so that we can take advantage of
|
||||
// caching when ToolChoice::None is supported.
|
||||
let mut tool_choice = None;
|
||||
let mut tools = Vec::new();
|
||||
if !conversation.tools.is_empty()
|
||||
&& self
|
||||
.model
|
||||
.supports_tool_choice(LanguageModelToolChoice::None)
|
||||
{
|
||||
tool_choice = Some(LanguageModelToolChoice::None);
|
||||
tools = conversation.tools.clone();
|
||||
}
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
messages,
|
||||
..Default::default()
|
||||
thread_id: conversation.thread_id,
|
||||
prompt_id: conversation.prompt_id,
|
||||
mode: conversation.mode,
|
||||
messages: conversation.messages,
|
||||
tool_choice,
|
||||
tools,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
};
|
||||
|
||||
Ok(self.model.stream_completion_text(request, cx).await?.stream)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,14 +2,16 @@ use super::*;
|
||||
use crate::{ReadFileToolInput, edit_file_tool::EditFileToolInput, grep_tool::GrepToolInput};
|
||||
use Role::*;
|
||||
use anyhow::anyhow;
|
||||
use assistant_tool::ToolRegistry;
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashMap;
|
||||
use fs::FakeFs;
|
||||
use futures::{FutureExt, future::LocalBoxFuture};
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use indoc::{formatdoc, indoc};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId,
|
||||
LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse,
|
||||
LanguageModelToolUseId,
|
||||
};
|
||||
use project::Project;
|
||||
use rand::prelude::*;
|
||||
@@ -37,7 +39,7 @@ fn eval_extract_handle_command_output() {
|
||||
conversation: vec![
|
||||
message(
|
||||
User,
|
||||
[text(indoc! {"
|
||||
[text(formatdoc! {"
|
||||
Read the `{input_file_path}` file and extract a method in
|
||||
the final stanza of `run_git_blame` to deal with command failures,
|
||||
call it `handle_command_output` and take the std::process::Output as the only parameter.
|
||||
@@ -96,7 +98,7 @@ fn eval_delete_run_git_blame() {
|
||||
conversation: vec![
|
||||
message(
|
||||
User,
|
||||
[text(indoc! {"
|
||||
[text(formatdoc! {"
|
||||
Read the `{input_file_path}` file and delete `run_git_blame`. Just that
|
||||
one function, not its usages.
|
||||
"})],
|
||||
@@ -138,6 +140,61 @@ fn eval_delete_run_git_blame() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "eval"), ignore)]
|
||||
fn eval_translate_doc_comments() {
|
||||
let input_file_path = "root/canvas.rs";
|
||||
let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs");
|
||||
let edit_description = "Translate all doc comments to Italian";
|
||||
eval(
|
||||
200,
|
||||
1.,
|
||||
EvalInput {
|
||||
conversation: vec![
|
||||
message(
|
||||
User,
|
||||
[text(formatdoc! {"
|
||||
Read the {input_file_path} file and edit it (without overwriting it),
|
||||
translating all the doc comments to italian.
|
||||
"})],
|
||||
),
|
||||
message(
|
||||
Assistant,
|
||||
[tool_use(
|
||||
"tool_1",
|
||||
"read_file",
|
||||
ReadFileToolInput {
|
||||
path: input_file_path.into(),
|
||||
start_line: None,
|
||||
end_line: None,
|
||||
},
|
||||
)],
|
||||
),
|
||||
message(
|
||||
User,
|
||||
[tool_result("tool_1", "read_file", input_file_content)],
|
||||
),
|
||||
message(
|
||||
Assistant,
|
||||
[tool_use(
|
||||
"tool_2",
|
||||
"edit_file",
|
||||
EditFileToolInput {
|
||||
display_description: edit_description.into(),
|
||||
path: input_file_path.into(),
|
||||
create_or_overwrite: false,
|
||||
},
|
||||
)],
|
||||
),
|
||||
],
|
||||
input_path: input_file_path.into(),
|
||||
input_content: Some(input_file_content.into()),
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::judge_diff("Doc comments were translated to Italian"),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "eval"), ignore)]
|
||||
fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
|
||||
@@ -152,7 +209,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
|
||||
conversation: vec![
|
||||
message(
|
||||
User,
|
||||
[text(indoc! {"
|
||||
[text(formatdoc! {"
|
||||
Read the `{input_file_path}` file and change `compile_parser_to_wasm` to use `wasi-sdk` instead of emscripten.
|
||||
Use `ureq` to download the SDK for the current platform and architecture.
|
||||
Extract the archive into a sibling of `lib` inside the `tree-sitter` directory in the cache_dir.
|
||||
@@ -160,7 +217,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
|
||||
that's inside of the archive.
|
||||
Don't re-download the SDK if that executable already exists.
|
||||
|
||||
Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{language_name}
|
||||
Use these clang flags: -fPIC -shared -Os -Wl,--export=tree_sitter_{{language_name}}
|
||||
|
||||
Here are the available wasi-sdk assets:
|
||||
- wasi-sdk-25.0-x86_64-macos.tar.gz
|
||||
@@ -261,11 +318,10 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() {
|
||||
fn eval_disable_cursor_blinking() {
|
||||
let input_file_path = "root/editor.rs";
|
||||
let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs");
|
||||
let output_file_content = include_str!("evals/fixtures/disable_cursor_blinking/after.rs");
|
||||
let edit_description = "Comment out the call to `BlinkManager::enable`";
|
||||
eval(
|
||||
200,
|
||||
0.6, // TODO: make this eval better
|
||||
0.95,
|
||||
EvalInput {
|
||||
conversation: vec![
|
||||
message(User, [text("Let's research how to cursor blinking works.")]),
|
||||
@@ -324,7 +380,11 @@ fn eval_disable_cursor_blinking() {
|
||||
input_path: input_file_path.into(),
|
||||
input_content: Some(input_file_content.into()),
|
||||
edit_description: edit_description.into(),
|
||||
assertion: EvalAssertion::assert_eq(output_file_content),
|
||||
assertion: EvalAssertion::judge_diff(indoc! {"
|
||||
- Calls to BlinkManager in `observe_window_activation` were commented out
|
||||
- The call to `blink_manager.enable` above the call to show_cursor_names was commented out
|
||||
- All the edits have valid indentation
|
||||
"}),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -1031,7 +1091,8 @@ impl EvalAssertion {
|
||||
|
||||
fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
||||
let mut evaluated_count = 0;
|
||||
report_progress(evaluated_count, iterations);
|
||||
let mut failed_count = 0;
|
||||
report_progress(evaluated_count, failed_count, iterations);
|
||||
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
@@ -1048,7 +1109,6 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
||||
}
|
||||
drop(tx);
|
||||
|
||||
let mut failed_count = 0;
|
||||
let mut failed_evals = HashMap::default();
|
||||
let mut errored_evals = HashMap::default();
|
||||
let mut eval_outputs = Vec::new();
|
||||
@@ -1073,7 +1133,7 @@ fn eval(iterations: usize, expected_pass_ratio: f32, mut eval: EvalInput) {
|
||||
}
|
||||
|
||||
evaluated_count += 1;
|
||||
report_progress(evaluated_count, iterations);
|
||||
report_progress(evaluated_count, failed_count, iterations);
|
||||
}
|
||||
|
||||
let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
|
||||
@@ -1144,8 +1204,19 @@ impl Display for EvalOutput {
|
||||
}
|
||||
}
|
||||
|
||||
fn report_progress(evaluated_count: usize, iterations: usize) {
|
||||
print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
|
||||
fn report_progress(evaluated_count: usize, failed_count: usize, iterations: usize) {
|
||||
let passed_count = evaluated_count - failed_count;
|
||||
let passed_ratio = if evaluated_count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
passed_count as f64 / evaluated_count as f64
|
||||
};
|
||||
print!(
|
||||
"\r\x1b[KEvaluated {}/{} ({:.2}%)",
|
||||
evaluated_count,
|
||||
iterations,
|
||||
passed_ratio * 100.0
|
||||
);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
@@ -1158,25 +1229,30 @@ struct EditAgentTest {
|
||||
impl EditAgentTest {
|
||||
async fn new(cx: &mut TestAppContext) -> Self {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(language::init);
|
||||
cx.update(gpui_tokio::init);
|
||||
cx.update(client::init_settings);
|
||||
|
||||
let fs = FakeFs::new(cx.executor().clone());
|
||||
cx.update(|cx| {
|
||||
settings::init(cx);
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap());
|
||||
cx.set_http_client(http_client);
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
|
||||
settings::init(cx);
|
||||
Project::init_settings(cx);
|
||||
language::init(cx);
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
crate::init(client.http_client(), cx);
|
||||
});
|
||||
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let (agent_model, judge_model) = cx
|
||||
.update(|cx| {
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let agent_model =
|
||||
Self::load_model("anthropic", "claude-3-7-sonnet-latest", cx).await;
|
||||
@@ -1225,12 +1301,32 @@ impl EditAgentTest {
|
||||
.update(cx, |project, cx| project.open_buffer(path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let conversation = LanguageModelRequest {
|
||||
messages: eval.conversation,
|
||||
tools: cx.update(|cx| {
|
||||
ToolRegistry::default_global(cx)
|
||||
.tools()
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
let input_schema = tool
|
||||
.input_schema(self.agent.model.tool_input_format())
|
||||
.ok()?;
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
let edit_output = if let Some(input_content) = eval.input_content.as_deref() {
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(input_content, cx));
|
||||
let (edit_output, _) = self.agent.edit(
|
||||
buffer.clone(),
|
||||
eval.edit_description,
|
||||
eval.conversation,
|
||||
&conversation,
|
||||
&mut cx.to_async(),
|
||||
);
|
||||
edit_output.await?
|
||||
@@ -1238,7 +1334,7 @@ impl EditAgentTest {
|
||||
let (edit_output, _) = self.agent.overwrite(
|
||||
buffer.clone(),
|
||||
eval.edit_description,
|
||||
eval.conversation,
|
||||
&conversation,
|
||||
&mut cx.to_async(),
|
||||
);
|
||||
edit_output.await?
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,339 @@
|
||||
// font-kit/src/canvas.rs
|
||||
//
|
||||
// Copyright © 2018 The Pathfinder Project Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! An in-memory bitmap surface for glyph rasterization.
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use pathfinder_geometry::rect::RectI;
|
||||
use pathfinder_geometry::vector::Vector2I;
|
||||
use std::cmp;
|
||||
use std::fmt;
|
||||
|
||||
use crate::utils;
|
||||
|
||||
lazy_static! {
|
||||
static ref BITMAP_1BPP_TO_8BPP_LUT: [[u8; 8]; 256] = {
|
||||
let mut lut = [[0; 8]; 256];
|
||||
for byte in 0..0x100 {
|
||||
let mut value = [0; 8];
|
||||
for bit in 0..8 {
|
||||
if (byte & (0x80 >> bit)) != 0 {
|
||||
value[bit] = 0xff;
|
||||
}
|
||||
}
|
||||
lut[byte] = value
|
||||
}
|
||||
lut
|
||||
};
|
||||
}
|
||||
|
||||
/// An in-memory bitmap surface for glyph rasterization.
|
||||
pub struct Canvas {
|
||||
/// The raw pixel data.
|
||||
pub pixels: Vec<u8>,
|
||||
/// The size of the buffer, in pixels.
|
||||
pub size: Vector2I,
|
||||
/// The number of *bytes* between successive rows.
|
||||
pub stride: usize,
|
||||
/// The image format of the canvas.
|
||||
pub format: Format,
|
||||
}
|
||||
|
||||
impl Canvas {
|
||||
/// Creates a new blank canvas with the given pixel size and format.
|
||||
///
|
||||
/// Stride is automatically calculated from width.
|
||||
///
|
||||
/// The canvas is initialized with transparent black (all values 0).
|
||||
#[inline]
|
||||
pub fn new(size: Vector2I, format: Format) -> Canvas {
|
||||
Canvas::with_stride(
|
||||
size,
|
||||
size.x() as usize * format.bytes_per_pixel() as usize,
|
||||
format,
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a new blank canvas with the given pixel size, stride (number of bytes between
|
||||
/// successive rows), and format.
|
||||
///
|
||||
/// The canvas is initialized with transparent black (all values 0).
|
||||
pub fn with_stride(size: Vector2I, stride: usize, format: Format) -> Canvas {
|
||||
Canvas {
|
||||
pixels: vec![0; stride * size.y() as usize],
|
||||
size,
|
||||
stride,
|
||||
format,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn blit_from_canvas(&mut self, src: &Canvas) {
|
||||
self.blit_from(
|
||||
Vector2I::default(),
|
||||
&src.pixels,
|
||||
src.size,
|
||||
src.stride,
|
||||
src.format,
|
||||
)
|
||||
}
|
||||
|
||||
/// Blits to a rectangle with origin at `dst_point` and size according to `src_size`.
|
||||
/// If the target area overlaps the boundaries of the canvas, only the drawable region is blitted.
|
||||
/// `dst_point` and `src_size` are specified in pixels. `src_stride` is specified in bytes.
|
||||
/// `src_stride` must be equal or larger than the actual data length.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn blit_from(
|
||||
&mut self,
|
||||
dst_point: Vector2I,
|
||||
src_bytes: &[u8],
|
||||
src_size: Vector2I,
|
||||
src_stride: usize,
|
||||
src_format: Format,
|
||||
) {
|
||||
assert_eq!(
|
||||
src_stride * src_size.y() as usize,
|
||||
src_bytes.len(),
|
||||
"Number of pixels in src_bytes does not match stride and size."
|
||||
);
|
||||
assert!(
|
||||
src_stride >= src_size.x() as usize * src_format.bytes_per_pixel() as usize,
|
||||
"src_stride must be >= than src_size.x()"
|
||||
);
|
||||
|
||||
let dst_rect = RectI::new(dst_point, src_size);
|
||||
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
|
||||
let dst_rect = match dst_rect {
|
||||
Some(dst_rect) => dst_rect,
|
||||
None => return,
|
||||
};
|
||||
|
||||
match (self.format, src_format) {
|
||||
(Format::A8, Format::A8)
|
||||
| (Format::Rgb24, Format::Rgb24)
|
||||
| (Format::Rgba32, Format::Rgba32) => {
|
||||
self.blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format)
|
||||
}
|
||||
(Format::A8, Format::Rgb24) => {
|
||||
self.blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format)
|
||||
}
|
||||
(Format::Rgb24, Format::A8) => {
|
||||
self.blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format)
|
||||
}
|
||||
(Format::Rgb24, Format::Rgba32) => self
|
||||
.blit_from_with::<BlitRgba32ToRgb24>(dst_rect, src_bytes, src_stride, src_format),
|
||||
(Format::Rgba32, Format::Rgb24) => self
|
||||
.blit_from_with::<BlitRgb24ToRgba32>(dst_rect, src_bytes, src_stride, src_format),
|
||||
(Format::Rgba32, Format::A8) | (Format::A8, Format::Rgba32) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn blit_from_bitmap_1bpp(
|
||||
&mut self,
|
||||
dst_point: Vector2I,
|
||||
src_bytes: &[u8],
|
||||
src_size: Vector2I,
|
||||
src_stride: usize,
|
||||
) {
|
||||
if self.format != Format::A8 {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
let dst_rect = RectI::new(dst_point, src_size);
|
||||
let dst_rect = dst_rect.intersection(RectI::new(Vector2I::default(), self.size));
|
||||
let dst_rect = match dst_rect {
|
||||
Some(dst_rect) => dst_rect,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let size = dst_rect.size();
|
||||
|
||||
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
|
||||
let dest_row_stride = size.x() as usize * dest_bytes_per_pixel;
|
||||
let src_row_stride = utils::div_round_up(size.x() as usize, 8);
|
||||
|
||||
for y in 0..size.y() {
|
||||
let (dest_row_start, src_row_start) = (
|
||||
(y + dst_rect.origin_y()) as usize * self.stride
|
||||
+ dst_rect.origin_x() as usize * dest_bytes_per_pixel,
|
||||
y as usize * src_stride,
|
||||
);
|
||||
let dest_row_end = dest_row_start + dest_row_stride;
|
||||
let src_row_end = src_row_start + src_row_stride;
|
||||
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
|
||||
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
|
||||
for x in 0..src_row_stride {
|
||||
let pattern = &BITMAP_1BPP_TO_8BPP_LUT[src_row_pixels[x] as usize];
|
||||
let dest_start = x * 8;
|
||||
let dest_end = cmp::min(dest_start + 8, dest_row_stride);
|
||||
let src = &pattern[0..(dest_end - dest_start)];
|
||||
dest_row_pixels[dest_start..dest_end].clone_from_slice(src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Blits to area `rect` using the data given in the buffer `src_bytes`.
|
||||
/// `src_stride` must be specified in bytes.
|
||||
/// The dimensions of `rect` must be in pixels.
|
||||
fn blit_from_with<B: Blit>(
|
||||
&mut self,
|
||||
rect: RectI,
|
||||
src_bytes: &[u8],
|
||||
src_stride: usize,
|
||||
src_format: Format,
|
||||
) {
|
||||
let src_bytes_per_pixel = src_format.bytes_per_pixel() as usize;
|
||||
let dest_bytes_per_pixel = self.format.bytes_per_pixel() as usize;
|
||||
|
||||
for y in 0..rect.height() {
|
||||
let (dest_row_start, src_row_start) = (
|
||||
(y + rect.origin_y()) as usize * self.stride
|
||||
+ rect.origin_x() as usize * dest_bytes_per_pixel,
|
||||
y as usize * src_stride,
|
||||
);
|
||||
let dest_row_end = dest_row_start + rect.width() as usize * dest_bytes_per_pixel;
|
||||
let src_row_end = src_row_start + rect.width() as usize * src_bytes_per_pixel;
|
||||
let dest_row_pixels = &mut self.pixels[dest_row_start..dest_row_end];
|
||||
let src_row_pixels = &src_bytes[src_row_start..src_row_end];
|
||||
B::blit(dest_row_pixels, src_row_pixels)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Canvas {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.debug_struct("Canvas")
|
||||
.field("pixels", &self.pixels.len()) // Do not dump a vector content.
|
||||
.field("size", &self.size)
|
||||
.field("stride", &self.stride)
|
||||
.field("format", &self.format)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// The image format for the canvas.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum Format {
|
||||
/// Premultiplied R8G8B8A8, little-endian.
|
||||
Rgba32,
|
||||
/// R8G8B8, little-endian.
|
||||
Rgb24,
|
||||
/// A8.
|
||||
A8,
|
||||
}
|
||||
|
||||
impl Format {
|
||||
/// Returns the number of bits per pixel that this image format corresponds to.
|
||||
#[inline]
|
||||
pub fn bits_per_pixel(self) -> u8 {
|
||||
match self {
|
||||
Format::Rgba32 => 32,
|
||||
Format::Rgb24 => 24,
|
||||
Format::A8 => 8,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of color channels per pixel that this image format corresponds to.
|
||||
#[inline]
|
||||
pub fn components_per_pixel(self) -> u8 {
|
||||
match self {
|
||||
Format::Rgba32 => 4,
|
||||
Format::Rgb24 => 3,
|
||||
Format::A8 => 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of bits per color channel that this image format contains.
|
||||
#[inline]
|
||||
pub fn bits_per_component(self) -> u8 {
|
||||
self.bits_per_pixel() / self.components_per_pixel()
|
||||
}
|
||||
|
||||
/// Returns the number of bytes per pixel that this image format corresponds to.
|
||||
#[inline]
|
||||
pub fn bytes_per_pixel(self) -> u8 {
|
||||
self.bits_per_pixel() / 8
|
||||
}
|
||||
}
|
||||
|
||||
/// The antialiasing strategy that should be used when rasterizing glyphs.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum RasterizationOptions {
|
||||
/// "Black-and-white" rendering. Each pixel is either entirely on or off.
|
||||
Bilevel,
|
||||
/// Grayscale antialiasing. Only one channel is used.
|
||||
GrayscaleAa,
|
||||
/// Subpixel RGB antialiasing, for LCD screens.
|
||||
SubpixelAa,
|
||||
}
|
||||
|
||||
trait Blit {
|
||||
fn blit(dest: &mut [u8], src: &[u8]);
|
||||
}
|
||||
|
||||
struct BlitMemcpy;
|
||||
|
||||
impl Blit for BlitMemcpy {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
dest.clone_from_slice(src)
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitRgb24ToA8;
|
||||
|
||||
impl Blit for BlitRgb24ToA8 {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
// TODO(pcwalton): SIMD.
|
||||
for (dest, src) in dest.iter_mut().zip(src.chunks(3)) {
|
||||
*dest = src[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitA8ToRgb24;
|
||||
|
||||
impl Blit for BlitA8ToRgb24 {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
for (dest, src) in dest.chunks_mut(3).zip(src.iter()) {
|
||||
dest[0] = *src;
|
||||
dest[1] = *src;
|
||||
dest[2] = *src;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitRgba32ToRgb24;
|
||||
|
||||
impl Blit for BlitRgba32ToRgb24 {
|
||||
#[inline]
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
// TODO(pcwalton): SIMD.
|
||||
for (dest, src) in dest.chunks_mut(3).zip(src.chunks(4)) {
|
||||
dest.copy_from_slice(&src[0..3])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BlitRgb24ToRgba32;
|
||||
|
||||
impl Blit for BlitRgb24ToRgba32 {
|
||||
fn blit(dest: &mut [u8], src: &[u8]) {
|
||||
for (dest, src) in dest.chunks_mut(4).zip(src.chunks(3)) {
|
||||
dest[0] = src[0];
|
||||
dest[1] = src[1];
|
||||
dest[2] = src[2];
|
||||
dest[3] = 255;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,18 +8,18 @@ use assistant_tool::{
|
||||
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus,
|
||||
};
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{Editor, EditorElement, EditorMode, EditorStyle, MultiBuffer, PathKey};
|
||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
|
||||
use futures::StreamExt;
|
||||
use gpui::{
|
||||
Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, EntityId, Task,
|
||||
TextStyle, WeakEntity, pulsating_between,
|
||||
TextStyleRefinement, WeakEntity, pulsating_between,
|
||||
};
|
||||
use indoc::formatdoc;
|
||||
use language::{
|
||||
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
|
||||
language_settings::SoftWrap,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -146,7 +146,7 @@ impl Tool for EditFileTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
@@ -177,7 +177,6 @@ impl Tool for EditFileTool {
|
||||
});
|
||||
|
||||
let card_clone = card.clone();
|
||||
let messages = messages.to_vec();
|
||||
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
|
||||
|
||||
@@ -209,14 +208,14 @@ impl Tool for EditFileTool {
|
||||
edit_agent.overwrite(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
messages,
|
||||
&request,
|
||||
cx,
|
||||
)
|
||||
} else {
|
||||
edit_agent.edit(
|
||||
buffer.clone(),
|
||||
input.display_description.clone(),
|
||||
messages,
|
||||
&request,
|
||||
cx,
|
||||
)
|
||||
};
|
||||
@@ -360,9 +359,9 @@ impl EditFileToolCard {
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.disable_scrollbars_and_minimap(cx);
|
||||
editor.set_soft_wrap_mode(SoftWrap::None, cx);
|
||||
editor.scroll_manager.set_forbid_vertical_scroll(true);
|
||||
editor.set_show_scrollbars(false, cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_read_only(true);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
@@ -574,33 +573,16 @@ impl ToolCard for EditFileToolCard {
|
||||
.map(|style| style.text.line_height_in_pixels(window.rem_size()))
|
||||
.unwrap_or_default();
|
||||
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let element = EditorElement::new(
|
||||
&cx.entity(),
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
horizontal_padding: rems(0.25).to_pixels(window.rem_size()),
|
||||
local_player: cx.theme().players().local(),
|
||||
text: TextStyle {
|
||||
color: cx.theme().colors().editor_foreground,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_size: TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(settings.agent_font_size(cx))
|
||||
.into(),
|
||||
font_weight: settings.buffer_font.weight,
|
||||
line_height: relative(settings.buffer_line_height.value()),
|
||||
..Default::default()
|
||||
},
|
||||
scrollbar_width: EditorElement::SCROLLBAR_WIDTH,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
status: cx.theme().status().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
editor.set_text_style_refinement(TextStyleRefinement {
|
||||
font_size: Some(
|
||||
TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
|
||||
.into(),
|
||||
),
|
||||
..TextStyleRefinement::default()
|
||||
});
|
||||
let element = editor.render(window, cx);
|
||||
(element.into_any_element(), line_height)
|
||||
});
|
||||
|
||||
@@ -864,7 +846,15 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
Arc::new(EditFileTool)
|
||||
.run(input, &[], project.clone(), action_log, model, None, cx)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log,
|
||||
model,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -9,7 +9,7 @@ use futures::AsyncReadExt as _;
|
||||
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
|
||||
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -142,7 +142,7 @@ impl Tool for FetchTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -7,7 +7,7 @@ use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
};
|
||||
use language;
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -73,7 +73,7 @@ impl Tool for FindPathTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -4,7 +4,7 @@ use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::StreamExt;
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language::{OffsetRangeExt, ParseStatus, Point};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::{
|
||||
Project,
|
||||
search::{SearchQuery, SearchResult},
|
||||
@@ -96,7 +96,7 @@ impl Tool for GrepTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
@@ -746,7 +746,8 @@ mod tests {
|
||||
let tool = Arc::new(GrepTool);
|
||||
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let task = cx.update(|cx| tool.run(input, &[], project, action_log, model, None, cx));
|
||||
let task =
|
||||
cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx));
|
||||
|
||||
match task.output.await {
|
||||
Ok(result) => {
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -73,7 +73,7 @@ impl Tool for ListDirectoryTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -86,7 +86,7 @@ impl Tool for MovePathTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -5,7 +5,7 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use chrono::{Local, Utc};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -56,7 +56,7 @@ impl Tool for NowTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -49,7 +49,7 @@ impl Tool for OpenTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -7,7 +7,7 @@ use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use itertools::Itertools;
|
||||
use language::{Anchor, Point};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::{AgentLocation, Project};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -83,7 +83,7 @@ impl Tool for ReadFileTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
@@ -231,7 +231,15 @@ mod test {
|
||||
"path": "root/nonexistent_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, model, None, cx)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log,
|
||||
model,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -262,7 +270,15 @@ mod test {
|
||||
"path": "root/small_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, model, None, cx)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log,
|
||||
model,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -295,7 +311,7 @@ mod test {
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
&[],
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
@@ -325,7 +341,15 @@ mod test {
|
||||
"offset": 1
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, model, None, cx)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log,
|
||||
model,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -372,7 +396,15 @@ mod test {
|
||||
"end_line": 4
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, model, None, cx)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log,
|
||||
model,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -406,7 +438,7 @@ mod test {
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
&[],
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
@@ -429,7 +461,7 @@ mod test {
|
||||
Arc::new(ReadFileTool)
|
||||
.run(
|
||||
input,
|
||||
&[],
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
@@ -450,7 +482,15 @@ mod test {
|
||||
"end_line": 2
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, model, None, cx)
|
||||
.run(
|
||||
input,
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log,
|
||||
model,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
You are an expert text editor and your task is to produce a series of edits to a file given a description of the changes you need to make.
|
||||
|
||||
You MUST respond with a series of edits to that one file in the following format:
|
||||
You MUST respond with a series of edits to a file, using the following format:
|
||||
|
||||
```
|
||||
<edits>
|
||||
@@ -51,3 +49,5 @@ Rules for editing:
|
||||
<edit_description>
|
||||
{{edit_description}}
|
||||
</edit_description>
|
||||
|
||||
Tool calls have been disabled. You MUST start your response with <edits>.
|
||||
|
||||
@@ -4,7 +4,7 @@ use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, WeakEntity, Window};
|
||||
use language::LineEnding;
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
|
||||
use project::{Project, terminals::TerminalKind};
|
||||
use schemars::JsonSchema;
|
||||
@@ -107,7 +107,7 @@ impl Tool for TerminalTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
@@ -656,7 +656,7 @@ mod tests {
|
||||
TerminalTool::run(
|
||||
Arc::new(TerminalTool::new(cx)),
|
||||
serde_json::to_value(input).unwrap(),
|
||||
&[],
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model,
|
||||
@@ -691,7 +691,7 @@ mod tests {
|
||||
let headless_result = TerminalTool::run(
|
||||
Arc::new(TerminalTool::new(cx)),
|
||||
serde_json::to_value(input).unwrap(),
|
||||
&[],
|
||||
Arc::default(),
|
||||
project.clone(),
|
||||
action_log.clone(),
|
||||
model.clone(),
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -47,7 +47,7 @@ impl Tool for ThinkingTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -8,7 +8,7 @@ use futures::{Future, FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -53,7 +53,7 @@ impl Tool for WebSearchTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
|
||||
@@ -335,9 +335,13 @@ impl AutoUpdater {
|
||||
self.status.clone()
|
||||
}
|
||||
|
||||
pub fn dismiss_error(&mut self, cx: &mut Context<Self>) {
|
||||
pub fn dismiss_error(&mut self, cx: &mut Context<Self>) -> bool {
|
||||
if self.status == AutoUpdateStatus::Idle {
|
||||
return false;
|
||||
}
|
||||
self.status = AutoUpdateStatus::Idle;
|
||||
cx.notify();
|
||||
true
|
||||
}
|
||||
|
||||
// If you are packaging Zed and need to override the place it downloads SSH remotes from,
|
||||
|
||||
@@ -7,9 +7,10 @@ use anyhow::{Error, Result, anyhow};
|
||||
use aws_sdk_bedrockruntime as bedrock;
|
||||
pub use aws_sdk_bedrockruntime as bedrock_client;
|
||||
pub use aws_sdk_bedrockruntime::types::{
|
||||
AutoToolChoice as BedrockAutoToolChoice, ContentBlock as BedrockInnerContent,
|
||||
Tool as BedrockTool, ToolChoice as BedrockToolChoice, ToolConfiguration as BedrockToolConfig,
|
||||
ToolInputSchema as BedrockToolInputSchema, ToolSpecification as BedrockToolSpec,
|
||||
AnyToolChoice as BedrockAnyToolChoice, AutoToolChoice as BedrockAutoToolChoice,
|
||||
ContentBlock as BedrockInnerContent, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
|
||||
ToolConfiguration as BedrockToolConfig, ToolInputSchema as BedrockToolInputSchema,
|
||||
ToolSpecification as BedrockToolSpec,
|
||||
};
|
||||
pub use aws_smithy_types::Blob as BedrockBlob;
|
||||
use aws_smithy_types::{Document, Number as AwsNumber};
|
||||
|
||||
@@ -11,7 +11,7 @@ use postage::{sink::Sink, watch};
|
||||
use rpc::proto::{RequestMessage, UsersResponse};
|
||||
use std::sync::{Arc, Weak};
|
||||
use text::ReplicaId;
|
||||
use util::TryFutureExt as _;
|
||||
use util::{TryFutureExt as _, maybe};
|
||||
|
||||
pub type UserId = u64;
|
||||
|
||||
@@ -101,6 +101,7 @@ pub struct UserStore {
|
||||
participant_indices: HashMap<u64, ParticipantIndex>,
|
||||
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
|
||||
current_plan: Option<proto::Plan>,
|
||||
subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
|
||||
trial_started_at: Option<DateTime<Utc>>,
|
||||
model_request_usage_amount: Option<u32>,
|
||||
model_request_usage_limit: Option<proto::UsageLimit>,
|
||||
@@ -166,6 +167,7 @@ impl UserStore {
|
||||
by_github_login: Default::default(),
|
||||
current_user: current_user_rx,
|
||||
current_plan: None,
|
||||
subscription_period: None,
|
||||
trial_started_at: None,
|
||||
model_request_usage_amount: None,
|
||||
model_request_usage_limit: None,
|
||||
@@ -333,6 +335,13 @@ impl UserStore {
|
||||
) -> Result<()> {
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.current_plan = Some(message.payload.plan());
|
||||
this.subscription_period = maybe!({
|
||||
let period = message.payload.subscription_period?;
|
||||
let started_at = DateTime::from_timestamp(period.started_at as i64, 0)?;
|
||||
let ended_at = DateTime::from_timestamp(period.ended_at as i64, 0)?;
|
||||
|
||||
Some((started_at, ended_at))
|
||||
});
|
||||
this.trial_started_at = message
|
||||
.payload
|
||||
.trial_started_at
|
||||
@@ -713,6 +722,10 @@ impl UserStore {
|
||||
self.current_plan
|
||||
}
|
||||
|
||||
pub fn subscription_period(&self) -> Option<(DateTime<Utc>, DateTime<Utc>)> {
|
||||
self.subscription_period
|
||||
}
|
||||
|
||||
pub fn trial_started_at(&self) -> Option<DateTime<Utc>> {
|
||||
self.trial_started_at
|
||||
}
|
||||
|
||||
@@ -76,7 +76,6 @@ workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
assistant = { workspace = true, features = ["test-support"] }
|
||||
assistant_context_editor.workspace = true
|
||||
assistant_settings.workspace = true
|
||||
assistant_slash_command.workspace = true
|
||||
|
||||
@@ -71,6 +71,7 @@ struct GetBillingPreferencesParams {
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct BillingPreferencesResponse {
|
||||
trial_started_at: Option<String>,
|
||||
max_monthly_llm_usage_spending_in_cents: i32,
|
||||
model_request_overages_enabled: bool,
|
||||
model_request_overages_spend_limit_in_cents: i32,
|
||||
@@ -86,9 +87,17 @@ async fn get_billing_preferences(
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
|
||||
let preferences = app.db.get_billing_preferences(user.id).await?;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
trial_started_at: billing_customer
|
||||
.and_then(|billing_customer| billing_customer.trial_started_at)
|
||||
.map(|trial_started_at| {
|
||||
trial_started_at
|
||||
.and_utc()
|
||||
.to_rfc3339_opts(SecondsFormat::Millis, true)
|
||||
}),
|
||||
max_monthly_llm_usage_spending_in_cents: preferences
|
||||
.as_ref()
|
||||
.map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
|
||||
@@ -127,6 +136,8 @@ async fn update_billing_preferences(
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("user not found"))?;
|
||||
|
||||
let billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?;
|
||||
|
||||
let max_monthly_llm_usage_spending_in_cents =
|
||||
body.max_monthly_llm_usage_spending_in_cents.max(0);
|
||||
let model_request_overages_spend_limit_in_cents =
|
||||
@@ -182,6 +193,13 @@ async fn update_billing_preferences(
|
||||
rpc_server.refresh_llm_tokens_for_user(user.id).await;
|
||||
|
||||
Ok(Json(BillingPreferencesResponse {
|
||||
trial_started_at: billing_customer
|
||||
.and_then(|billing_customer| billing_customer.trial_started_at)
|
||||
.map(|trial_started_at| {
|
||||
trial_started_at
|
||||
.and_utc()
|
||||
.to_rfc3339_opts(SecondsFormat::Millis, true)
|
||||
}),
|
||||
max_monthly_llm_usage_spending_in_cents: billing_preferences
|
||||
.max_monthly_llm_usage_spending_in_cents,
|
||||
model_request_overages_enabled: billing_preferences.model_request_overages_enabled,
|
||||
@@ -301,13 +319,6 @@ async fn create_billing_subscription(
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::error!("failed to retrieve LLM database");
|
||||
Err(Error::http(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"not supported".into(),
|
||||
))?
|
||||
};
|
||||
|
||||
if app.db.has_active_billing_subscription(user.id).await? {
|
||||
return Err(Error::http(
|
||||
@@ -399,16 +410,10 @@ async fn create_billing_subscription(
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
let default_model = llm_db.model(
|
||||
zed_llm_client::LanguageModelProvider::Anthropic,
|
||||
"claude-3-7-sonnet",
|
||||
)?;
|
||||
let stripe_model = stripe_billing
|
||||
.register_model_for_token_based_usage(default_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
|
||||
.await?
|
||||
return Err(Error::http(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"No product selected".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1381,81 +1386,6 @@ async fn find_or_create_billing_customer(
|
||||
Ok(Some(billing_customer))
|
||||
}
|
||||
|
||||
const SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_token_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
let Some(stripe_billing) = app.stripe_billing.clone() else {
|
||||
log::warn!("failed to retrieve Stripe billing object");
|
||||
return;
|
||||
};
|
||||
let Some(llm_db) = app.llm_db.clone() else {
|
||||
log::warn!("failed to retrieve LLM database");
|
||||
return;
|
||||
};
|
||||
|
||||
let executor = app.executor.clone();
|
||||
executor.spawn_detached({
|
||||
let executor = executor.clone();
|
||||
async move {
|
||||
loop {
|
||||
sync_token_usage_with_stripe(&app, &llm_db, &stripe_billing)
|
||||
.await
|
||||
.context("failed to sync LLM usage to Stripe")
|
||||
.trace_err();
|
||||
executor
|
||||
.sleep(SYNC_LLM_TOKEN_USAGE_WITH_STRIPE_INTERVAL)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn sync_token_usage_with_stripe(
|
||||
app: &Arc<AppState>,
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let events = llm_db.get_billing_events().await?;
|
||||
let user_ids = events
|
||||
.iter()
|
||||
.map(|(event, _)| event.user_id)
|
||||
.collect::<HashSet<UserId>>();
|
||||
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;
|
||||
|
||||
for (event, model) in events {
|
||||
let Some((stripe_db_customer, stripe_db_subscription)) =
|
||||
stripe_subscriptions.get(&event.user_id)
|
||||
else {
|
||||
tracing::warn!(
|
||||
user_id = event.user_id.0,
|
||||
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
|
||||
.stripe_subscription_id
|
||||
.parse()
|
||||
.context("failed to parse stripe subscription id from db")?;
|
||||
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
|
||||
.stripe_customer_id
|
||||
.parse()
|
||||
.context("failed to parse stripe customer id from db")?;
|
||||
|
||||
let stripe_model = stripe_billing
|
||||
.register_model_for_token_based_usage(&model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
|
||||
.await?;
|
||||
stripe_billing
|
||||
.bill_model_token_usage(&stripe_customer_id, &stripe_model, &event)
|
||||
.await?;
|
||||
llm_db.consume_billing_event(event.id).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const SYNC_LLM_REQUEST_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);
|
||||
|
||||
pub fn sync_llm_request_usage_with_stripe_periodically(app: Arc<AppState>) {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use super::*;
|
||||
|
||||
pub mod billing_events;
|
||||
pub mod providers;
|
||||
pub mod subscription_usage_meters;
|
||||
pub mod subscription_usages;
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use super::*;
|
||||
use crate::Result;
|
||||
use anyhow::Context as _;
|
||||
|
||||
impl LlmDatabase {
|
||||
pub async fn get_billing_events(&self) -> Result<Vec<(billing_event::Model, model::Model)>> {
|
||||
self.transaction(|tx| async move {
|
||||
let events_with_models = billing_event::Entity::find()
|
||||
.find_also_related(model::Entity)
|
||||
.all(&*tx)
|
||||
.await?;
|
||||
events_with_models
|
||||
.into_iter()
|
||||
.map(|(event, model)| {
|
||||
let model =
|
||||
model.context("could not find model associated with billing event")?;
|
||||
Ok((event, model))
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn consume_billing_event(&self, id: BillingEventId) -> Result<()> {
|
||||
self.transaction(|tx| async move {
|
||||
billing_event::Entity::delete_by_id(id).exec(&*tx).await?;
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
pub mod billing_event;
|
||||
pub mod model;
|
||||
pub mod monthly_usage;
|
||||
pub mod provider;
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
use crate::{
|
||||
db::UserId,
|
||||
llm::db::{BillingEventId, ModelId},
|
||||
};
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "billing_events")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key)]
|
||||
pub id: BillingEventId,
|
||||
pub idempotency_key: Uuid,
|
||||
pub user_id: UserId,
|
||||
pub model_id: ModelId,
|
||||
pub input_tokens: i64,
|
||||
pub input_cache_creation_tokens: i64,
|
||||
pub input_cache_read_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::model::Entity",
|
||||
from = "Column::ModelId",
|
||||
to = "super::model::Column::Id"
|
||||
)]
|
||||
Model,
|
||||
}
|
||||
|
||||
impl Related<super::model::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::Model.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -31,8 +31,6 @@ pub enum Relation {
|
||||
Provider,
|
||||
#[sea_orm(has_many = "super::usage::Entity")]
|
||||
Usages,
|
||||
#[sea_orm(has_many = "super::billing_event::Entity")]
|
||||
BillingEvents,
|
||||
}
|
||||
|
||||
impl Related<super::provider::Entity> for Entity {
|
||||
@@ -47,10 +45,4 @@ impl Related<super::usage::Entity> for Entity {
|
||||
}
|
||||
}
|
||||
|
||||
impl Related<super::billing_event::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::BillingEvents.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
|
||||
@@ -8,9 +8,7 @@ use axum::{
|
||||
};
|
||||
|
||||
use collab::api::CloudflareIpCountryHeader;
|
||||
use collab::api::billing::{
|
||||
sync_llm_request_usage_with_stripe_periodically, sync_llm_token_usage_with_stripe_periodically,
|
||||
};
|
||||
use collab::api::billing::sync_llm_request_usage_with_stripe_periodically;
|
||||
use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::user_backfiller::spawn_user_backfiller;
|
||||
@@ -155,7 +153,6 @@ async fn main() -> Result<()> {
|
||||
if let Some(mut llm_db) = llm_db {
|
||||
llm_db.initialize().await?;
|
||||
sync_llm_request_usage_with_stripe_periodically(state.clone());
|
||||
sync_llm_token_usage_with_stripe_periodically(state.clone());
|
||||
}
|
||||
|
||||
app = app
|
||||
|
||||
@@ -2709,7 +2709,7 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
|
||||
let billing_preferences = db.get_billing_preferences(user_id).await?;
|
||||
|
||||
let usage = if let Some(llm_db) = session.app_state.llm_db.clone() {
|
||||
let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
|
||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||
|
||||
let subscription_period = crate::db::billing_subscription::Model::current_period(
|
||||
@@ -2717,15 +2717,17 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
session.is_staff(),
|
||||
);
|
||||
|
||||
if let Some((period_start_at, period_end_at)) = subscription_period {
|
||||
let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
|
||||
llm_db
|
||||
.get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
|
||||
.await?
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
(subscription_period, usage)
|
||||
} else {
|
||||
None
|
||||
(None, None)
|
||||
};
|
||||
|
||||
session
|
||||
@@ -2743,6 +2745,12 @@ async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
|
||||
billing_preferences
|
||||
.map(|preferences| preferences.model_request_overages_enabled)
|
||||
},
|
||||
subscription_period: subscription_period.map(|(started_at, ended_at)| {
|
||||
proto::SubscriptionPeriod {
|
||||
started_at: started_at.timestamp() as u64,
|
||||
ended_at: ended_at.timestamp() as u64,
|
||||
}
|
||||
}),
|
||||
usage: usage.map(|usage| {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => zed_llm_client::Plan::Free,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::llm::{self, AGENT_EXTENDED_TRIAL_FEATURE_FLAG};
|
||||
use crate::{Cents, Result};
|
||||
use crate::Result;
|
||||
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use chrono::{Datelike, Utc};
|
||||
use chrono::Utc;
|
||||
use collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use stripe::PriceId;
|
||||
@@ -22,18 +22,6 @@ struct StripeBillingState {
|
||||
prices_by_lookup_key: HashMap<String, stripe::Price>,
|
||||
}
|
||||
|
||||
pub struct StripeModelTokenPrices {
|
||||
input_tokens_price: StripeBillingPrice,
|
||||
input_cache_creation_tokens_price: StripeBillingPrice,
|
||||
input_cache_read_tokens_price: StripeBillingPrice,
|
||||
output_tokens_price: StripeBillingPrice,
|
||||
}
|
||||
|
||||
struct StripeBillingPrice {
|
||||
id: stripe::PriceId,
|
||||
meter_event_name: String,
|
||||
}
|
||||
|
||||
impl StripeBilling {
|
||||
pub fn new(client: Arc<stripe::Client>) -> Self {
|
||||
Self {
|
||||
@@ -109,142 +97,6 @@ impl StripeBilling {
|
||||
.ok_or_else(|| crate::Error::Internal(anyhow!("no price found for {lookup_key:?}")))
|
||||
}
|
||||
|
||||
pub async fn register_model_for_token_based_usage(
|
||||
&self,
|
||||
model: &llm::db::model::Model,
|
||||
) -> Result<StripeModelTokenPrices> {
|
||||
let input_tokens_price = self
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/input_tokens", model.id),
|
||||
&format!("{} (Input Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_creation_tokens_price = self
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/input_cache_creation_tokens", model.id),
|
||||
&format!("{} (Input Cache Creation Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_creation_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let input_cache_read_tokens_price = self
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/input_cache_read_tokens", model.id),
|
||||
&format!("{} (Input Cache Read Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_cache_read_input_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
let output_tokens_price = self
|
||||
.get_or_insert_token_price(
|
||||
&format!("model_{}/output_tokens", model.id),
|
||||
&format!("{} (Output Tokens)", model.name),
|
||||
Cents::new(model.price_per_million_output_tokens as u32),
|
||||
)
|
||||
.await?;
|
||||
Ok(StripeModelTokenPrices {
|
||||
input_tokens_price,
|
||||
input_cache_creation_tokens_price,
|
||||
input_cache_read_tokens_price,
|
||||
output_tokens_price,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_or_insert_token_price(
|
||||
&self,
|
||||
meter_event_name: &str,
|
||||
price_description: &str,
|
||||
price_per_million_tokens: Cents,
|
||||
) -> Result<StripeBillingPrice> {
|
||||
// Fast code path when the meter and the price already exist.
|
||||
{
|
||||
let state = self.state.read().await;
|
||||
if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
return Ok(StripeBillingPrice {
|
||||
id: price_id.clone(),
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut state = self.state.write().await;
|
||||
let meter = if let Some(meter) = state.meters_by_event_name.get(meter_event_name) {
|
||||
meter.clone()
|
||||
} else {
|
||||
let meter = StripeMeter::create(
|
||||
&self.client,
|
||||
StripeCreateMeterParams {
|
||||
default_aggregation: DefaultAggregation { formula: "sum" },
|
||||
display_name: price_description.to_string(),
|
||||
event_name: meter_event_name,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.meters_by_event_name
|
||||
.insert(meter_event_name.to_string(), meter.clone());
|
||||
meter
|
||||
};
|
||||
|
||||
let price_id = if let Some(price_id) = state.price_ids_by_meter_id.get(&meter.id) {
|
||||
price_id.clone()
|
||||
} else {
|
||||
let price = stripe::Price::create(
|
||||
&self.client,
|
||||
stripe::CreatePrice {
|
||||
active: Some(true),
|
||||
billing_scheme: Some(stripe::PriceBillingScheme::PerUnit),
|
||||
currency: stripe::Currency::USD,
|
||||
currency_options: None,
|
||||
custom_unit_amount: None,
|
||||
expand: &[],
|
||||
lookup_key: None,
|
||||
metadata: None,
|
||||
nickname: None,
|
||||
product: None,
|
||||
product_data: Some(stripe::CreatePriceProductData {
|
||||
id: None,
|
||||
active: Some(true),
|
||||
metadata: None,
|
||||
name: price_description.to_string(),
|
||||
statement_descriptor: None,
|
||||
tax_code: None,
|
||||
unit_label: None,
|
||||
}),
|
||||
recurring: Some(stripe::CreatePriceRecurring {
|
||||
aggregate_usage: None,
|
||||
interval: stripe::CreatePriceRecurringInterval::Month,
|
||||
interval_count: None,
|
||||
trial_period_days: None,
|
||||
usage_type: Some(stripe::CreatePriceRecurringUsageType::Metered),
|
||||
meter: Some(meter.id.clone()),
|
||||
}),
|
||||
tax_behavior: None,
|
||||
tiers: None,
|
||||
tiers_mode: None,
|
||||
transfer_lookup_key: None,
|
||||
transform_quantity: None,
|
||||
unit_amount: None,
|
||||
unit_amount_decimal: Some(&format!(
|
||||
"{:.12}",
|
||||
price_per_million_tokens.0 as f64 / 1_000_000f64
|
||||
)),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
state
|
||||
.price_ids_by_meter_id
|
||||
.insert(meter.id, price.id.clone());
|
||||
price.id
|
||||
};
|
||||
|
||||
Ok(StripeBillingPrice {
|
||||
id: price_id,
|
||||
meter_event_name: meter_event_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_price(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
@@ -283,142 +135,6 @@ impl StripeBilling {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn subscribe_to_model(
|
||||
&self,
|
||||
subscription_id: &stripe::SubscriptionId,
|
||||
model: &StripeModelTokenPrices,
|
||||
) -> Result<()> {
|
||||
let subscription =
|
||||
stripe::Subscription::retrieve(&self.client, &subscription_id, &[]).await?;
|
||||
|
||||
let mut items = Vec::new();
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_creation_tokens_price.id)
|
||||
{
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_creation_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.input_cache_read_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.input_cache_read_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !subscription_contains_price(&subscription, &model.output_tokens_price.id) {
|
||||
items.push(stripe::UpdateSubscriptionItems {
|
||||
price: Some(model.output_tokens_price.id.to_string()),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if !items.is_empty() {
|
||||
items.extend(subscription.items.data.iter().map(|item| {
|
||||
stripe::UpdateSubscriptionItems {
|
||||
id: Some(item.id.to_string()),
|
||||
..Default::default()
|
||||
}
|
||||
}));
|
||||
|
||||
stripe::Subscription::update(
|
||||
&self.client,
|
||||
subscription_id,
|
||||
stripe::UpdateSubscription {
|
||||
items: Some(items),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bill_model_token_usage(
|
||||
&self,
|
||||
customer_id: &stripe::CustomerId,
|
||||
model: &StripeModelTokenPrices,
|
||||
event: &llm::db::billing_event::Model,
|
||||
) -> Result<()> {
|
||||
let timestamp = Utc::now().timestamp();
|
||||
|
||||
if event.input_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_creation_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_creation_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_creation_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_creation_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.input_cache_read_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("input_cache_read_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.input_cache_read_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.input_cache_read_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if event.output_tokens > 0 {
|
||||
StripeMeterEvent::create(
|
||||
&self.client,
|
||||
StripeCreateMeterEventParams {
|
||||
identifier: &format!("output_tokens/{}", event.idempotency_key),
|
||||
event_name: &model.output_tokens_price.meter_event_name,
|
||||
payload: StripeCreateMeterEventPayload {
|
||||
value: event.output_tokens as u64,
|
||||
stripe_customer_id: customer_id,
|
||||
},
|
||||
timestamp: Some(timestamp),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bill_model_request_usage(
|
||||
&self,
|
||||
customer_id: &stripe::CustomerId,
|
||||
@@ -445,47 +161,6 @@ impl StripeBilling {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn checkout(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
github_login: &str,
|
||||
model: &StripeModelTokenPrices,
|
||||
success_url: &str,
|
||||
) -> Result<String> {
|
||||
let first_of_next_month = Utc::now()
|
||||
.checked_add_months(chrono::Months::new(1))
|
||||
.unwrap()
|
||||
.with_day(1)
|
||||
.unwrap();
|
||||
|
||||
let mut params = stripe::CreateCheckoutSession::new();
|
||||
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
|
||||
params.customer = Some(customer_id);
|
||||
params.client_reference_id = Some(github_login);
|
||||
params.subscription_data = Some(stripe::CreateCheckoutSessionSubscriptionData {
|
||||
billing_cycle_anchor: Some(first_of_next_month.timestamp()),
|
||||
..Default::default()
|
||||
});
|
||||
params.line_items = Some(
|
||||
[
|
||||
&model.input_tokens_price.id,
|
||||
&model.input_cache_creation_tokens_price.id,
|
||||
&model.input_cache_read_tokens_price.id,
|
||||
&model.output_tokens_price.id,
|
||||
]
|
||||
.into_iter()
|
||||
.map(|price_id| stripe::CreateCheckoutSessionLineItems {
|
||||
price: Some(price_id.to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
params.success_url = Some(success_url);
|
||||
|
||||
let session = stripe::CheckoutSession::create(&self.client, params).await?;
|
||||
Ok(session.url.context("no checkout session URL")?)
|
||||
}
|
||||
|
||||
pub async fn checkout_with_zed_pro(
|
||||
&self,
|
||||
customer_id: stripe::CustomerId,
|
||||
@@ -587,18 +262,6 @@ impl StripeBilling {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct DefaultAggregation {
|
||||
formula: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StripeCreateMeterParams<'a> {
|
||||
default_aggregation: DefaultAggregation,
|
||||
display_name: String,
|
||||
event_name: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
struct StripeMeter {
|
||||
id: String,
|
||||
@@ -606,13 +269,6 @@ struct StripeMeter {
|
||||
}
|
||||
|
||||
impl StripeMeter {
|
||||
pub fn create(
|
||||
client: &stripe::Client,
|
||||
params: StripeCreateMeterParams,
|
||||
) -> stripe::Response<Self> {
|
||||
client.post_form("/billing/meters", params)
|
||||
}
|
||||
|
||||
pub fn list(client: &stripe::Client) -> stripe::Response<stripe::List<Self>> {
|
||||
#[derive(Serialize)]
|
||||
struct Params {
|
||||
|
||||
@@ -22,6 +22,7 @@ use ui::{
|
||||
Avatar, Button, Icon, IconButton, IconName, Label, Tab, Tooltip, h_flex, prelude::*, v_flex,
|
||||
};
|
||||
use util::{ResultExt, TryFutureExt};
|
||||
use workspace::SuppressNotification;
|
||||
use workspace::notifications::{
|
||||
Notification as WorkspaceNotification, NotificationId, SuppressEvent,
|
||||
};
|
||||
@@ -823,11 +824,19 @@ impl Render for NotificationToast {
|
||||
.child(Label::new(self.text.clone()))
|
||||
.child(
|
||||
IconButton::new("close", IconName::Close)
|
||||
.tooltip(|window, cx| Tooltip::for_action("Close", &menu::Cancel, window, cx))
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
|
||||
)
|
||||
.child(
|
||||
IconButton::new("suppress", IconName::XCircle)
|
||||
.tooltip(Tooltip::text("Do not show until restart"))
|
||||
IconButton::new("suppress", IconName::SquareMinus)
|
||||
.tooltip(|window, cx| {
|
||||
Tooltip::for_action(
|
||||
"Do not show until restart",
|
||||
&SuppressNotification,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.emit(SuppressEvent))),
|
||||
)
|
||||
.on_click(cx.listener(|this, _, window, cx| {
|
||||
|
||||
@@ -182,11 +182,11 @@ pub enum Tool {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Any,
|
||||
Tool { name: String },
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
|
||||
@@ -36,7 +36,6 @@ gpui.workspace = true
|
||||
http_client.workspace = true
|
||||
language.workspace = true
|
||||
log.workspace = true
|
||||
lsp-types.workspace = true
|
||||
node_runtime.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
|
||||
@@ -78,6 +78,11 @@ impl From<DebugAdapterName> for SharedString {
|
||||
name.0
|
||||
}
|
||||
}
|
||||
impl From<SharedString> for DebugAdapterName {
|
||||
fn from(name: SharedString) -> Self {
|
||||
DebugAdapterName(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for DebugAdapterName {
|
||||
fn from(str: &'a str) -> DebugAdapterName {
|
||||
@@ -402,10 +407,6 @@ pub async fn fetch_latest_adapter_version_from_github(
|
||||
})
|
||||
}
|
||||
|
||||
pub trait InlineValueProvider {
|
||||
fn provide(&self, variables: Vec<(String, lsp_types::Range)>) -> Vec<lsp_types::InlineValue>;
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
pub trait DebugAdapter: 'static + Send + Sync {
|
||||
fn name(&self) -> DebugAdapterName;
|
||||
@@ -417,10 +418,6 @@ pub trait DebugAdapter: 'static + Send + Sync {
|
||||
user_installed_path: Option<PathBuf>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<DebugAdapterBinary>;
|
||||
|
||||
fn inline_value_provider(&self) -> Option<Box<dyn InlineValueProvider>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod adapters;
|
||||
pub mod client;
|
||||
pub mod debugger_settings;
|
||||
pub mod inline_value;
|
||||
pub mod proto_conversions;
|
||||
mod registry;
|
||||
pub mod transport;
|
||||
|
||||
277
crates/dap/src/inline_value.rs
Normal file
277
crates/dap/src/inline_value.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum VariableLookupKind {
|
||||
Variable,
|
||||
Expression,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum VariableScope {
|
||||
Local,
|
||||
Global,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InlineValueLocation {
|
||||
pub variable_name: String,
|
||||
pub scope: VariableScope,
|
||||
pub lookup: VariableLookupKind,
|
||||
pub row: usize,
|
||||
pub column: usize,
|
||||
}
|
||||
|
||||
/// A trait for providing inline values for debugging purposes.
|
||||
///
|
||||
/// Implementors of this trait are responsible for analyzing a given node in the
|
||||
/// source code and extracting variable information, including their names,
|
||||
/// scopes, and positions. This information is used to display inline values
|
||||
/// during debugging sessions. Implementors must also handle variable scoping
|
||||
/// themselves by traversing the syntax tree upwards to determine whether a
|
||||
/// variable is local or global.
|
||||
pub trait InlineValueProvider {
|
||||
/// Provides a list of inline value locations based on the given node and source code.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `node`: The root node of the active debug line. Implementors should traverse
|
||||
/// upwards from this node to gather variable information and determine their scope.
|
||||
/// - `source`: The source code as a string slice, used to extract variable names.
|
||||
/// - `max_row`: The maximum row to consider when collecting variables. Variables
|
||||
/// declared beyond this row should be ignored.
|
||||
///
|
||||
/// # Returns
|
||||
/// A vector of `InlineValueLocation` instances, each representing a variable's
|
||||
/// name, scope, and the position of the inline value should be shown.
|
||||
fn provide(
|
||||
&self,
|
||||
node: language::Node,
|
||||
source: &str,
|
||||
max_row: usize,
|
||||
) -> Vec<InlineValueLocation>;
|
||||
}
|
||||
|
||||
pub struct RustInlineValueProvider;
|
||||
|
||||
impl InlineValueProvider for RustInlineValueProvider {
|
||||
fn provide(
|
||||
&self,
|
||||
mut node: language::Node,
|
||||
source: &str,
|
||||
max_row: usize,
|
||||
) -> Vec<InlineValueLocation> {
|
||||
let mut variables = Vec::new();
|
||||
let mut variable_names = HashSet::new();
|
||||
let mut scope = VariableScope::Local;
|
||||
|
||||
loop {
|
||||
let mut variable_names_in_scope = HashMap::new();
|
||||
for child in node.named_children(&mut node.walk()) {
|
||||
if child.start_position().row >= max_row {
|
||||
break;
|
||||
}
|
||||
|
||||
if scope == VariableScope::Local && child.kind() == "let_declaration" {
|
||||
if let Some(identifier) = child.child_by_field_name("pattern") {
|
||||
let variable_name = source[identifier.byte_range()].to_string();
|
||||
|
||||
if variable_names.contains(&variable_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(index) = variable_names_in_scope.get(&variable_name) {
|
||||
variables.remove(*index);
|
||||
}
|
||||
|
||||
variable_names_in_scope.insert(variable_name.clone(), variables.len());
|
||||
variables.push(InlineValueLocation {
|
||||
variable_name,
|
||||
scope: VariableScope::Local,
|
||||
lookup: VariableLookupKind::Variable,
|
||||
row: identifier.end_position().row,
|
||||
column: identifier.end_position().column,
|
||||
});
|
||||
}
|
||||
} else if child.kind() == "static_item" {
|
||||
if let Some(name) = child.child_by_field_name("name") {
|
||||
let variable_name = source[name.byte_range()].to_string();
|
||||
variables.push(InlineValueLocation {
|
||||
variable_name,
|
||||
scope: scope.clone(),
|
||||
lookup: VariableLookupKind::Expression,
|
||||
row: name.end_position().row,
|
||||
column: name.end_position().column,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
variable_names.extend(variable_names_in_scope.keys().cloned());
|
||||
|
||||
if matches!(node.kind(), "function_item" | "closure_expression") {
|
||||
scope = VariableScope::Global;
|
||||
}
|
||||
|
||||
if let Some(parent) = node.parent() {
|
||||
node = parent;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
variables
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PythonInlineValueProvider;
|
||||
|
||||
impl InlineValueProvider for PythonInlineValueProvider {
|
||||
fn provide(
|
||||
&self,
|
||||
mut node: language::Node,
|
||||
source: &str,
|
||||
max_row: usize,
|
||||
) -> Vec<InlineValueLocation> {
|
||||
let mut variables = Vec::new();
|
||||
let mut variable_names = HashSet::new();
|
||||
let mut scope = VariableScope::Local;
|
||||
|
||||
loop {
|
||||
let mut variable_names_in_scope = HashMap::new();
|
||||
for child in node.named_children(&mut node.walk()) {
|
||||
if child.start_position().row >= max_row {
|
||||
break;
|
||||
}
|
||||
|
||||
if scope == VariableScope::Local {
|
||||
match child.kind() {
|
||||
"expression_statement" => {
|
||||
if let Some(expr) = child.child(0) {
|
||||
if expr.kind() == "assignment" {
|
||||
if let Some(param) = expr.child(0) {
|
||||
let param_identifier = if param.kind() == "identifier" {
|
||||
Some(param)
|
||||
} else if param.kind() == "typed_parameter" {
|
||||
param.child(0)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(identifier) = param_identifier {
|
||||
if identifier.kind() == "identifier" {
|
||||
let variable_name =
|
||||
source[identifier.byte_range()].to_string();
|
||||
|
||||
if variable_names.contains(&variable_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(index) =
|
||||
variable_names_in_scope.get(&variable_name)
|
||||
{
|
||||
variables.remove(*index);
|
||||
}
|
||||
|
||||
variable_names_in_scope
|
||||
.insert(variable_name.clone(), variables.len());
|
||||
variables.push(InlineValueLocation {
|
||||
variable_name,
|
||||
scope: VariableScope::Local,
|
||||
lookup: VariableLookupKind::Variable,
|
||||
row: identifier.end_position().row,
|
||||
column: identifier.end_position().column,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"function_definition" => {
|
||||
if let Some(params) = child.child_by_field_name("parameters") {
|
||||
for param in params.named_children(&mut params.walk()) {
|
||||
let param_identifier = if param.kind() == "identifier" {
|
||||
Some(param)
|
||||
} else if param.kind() == "typed_parameter" {
|
||||
param.child(0)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(identifier) = param_identifier {
|
||||
if identifier.kind() == "identifier" {
|
||||
let variable_name =
|
||||
source[identifier.byte_range()].to_string();
|
||||
|
||||
if variable_names.contains(&variable_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(index) =
|
||||
variable_names_in_scope.get(&variable_name)
|
||||
{
|
||||
variables.remove(*index);
|
||||
}
|
||||
|
||||
variable_names_in_scope
|
||||
.insert(variable_name.clone(), variables.len());
|
||||
variables.push(InlineValueLocation {
|
||||
variable_name,
|
||||
scope: VariableScope::Local,
|
||||
lookup: VariableLookupKind::Variable,
|
||||
row: identifier.end_position().row,
|
||||
column: identifier.end_position().column,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"for_statement" => {
|
||||
if let Some(target) = child.child_by_field_name("left") {
|
||||
if target.kind() == "identifier" {
|
||||
let variable_name = source[target.byte_range()].to_string();
|
||||
|
||||
if variable_names.contains(&variable_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(index) = variable_names_in_scope.get(&variable_name)
|
||||
{
|
||||
variables.remove(*index);
|
||||
}
|
||||
|
||||
variable_names_in_scope
|
||||
.insert(variable_name.clone(), variables.len());
|
||||
variables.push(InlineValueLocation {
|
||||
variable_name,
|
||||
scope: VariableScope::Local,
|
||||
lookup: VariableLookupKind::Variable,
|
||||
row: target.end_position().row,
|
||||
column: target.end_position().column,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
variable_names.extend(variable_names_in_scope.keys().cloned());
|
||||
|
||||
if matches!(node.kind(), "function_definition" | "module")
|
||||
&& node.range().end_point.row < max_row
|
||||
{
|
||||
scope = VariableScope::Global;
|
||||
}
|
||||
|
||||
if let Some(parent) = node.parent() {
|
||||
node = parent;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
variables
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,10 @@ use gpui::{App, Global, SharedString};
|
||||
use parking_lot::RwLock;
|
||||
use task::{DebugRequest, DebugScenario, SpawnInTerminal, TaskTemplate};
|
||||
|
||||
use crate::adapters::{DebugAdapter, DebugAdapterName};
|
||||
use crate::{
|
||||
adapters::{DebugAdapter, DebugAdapterName},
|
||||
inline_value::InlineValueProvider,
|
||||
};
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
/// Given a user build configuration, locator creates a fill-in debug target ([DebugRequest]) on behalf of the user.
|
||||
@@ -13,7 +16,12 @@ use std::{collections::BTreeMap, sync::Arc};
|
||||
pub trait DapLocator: Send + Sync {
|
||||
fn name(&self) -> SharedString;
|
||||
/// Determines whether this locator can generate debug target for given task.
|
||||
fn create_scenario(&self, build_config: &TaskTemplate, adapter: &str) -> Option<DebugScenario>;
|
||||
fn create_scenario(
|
||||
&self,
|
||||
build_config: &TaskTemplate,
|
||||
resolved_label: &str,
|
||||
adapter: DebugAdapterName,
|
||||
) -> Option<DebugScenario>;
|
||||
|
||||
async fn run(&self, build_config: SpawnInTerminal) -> Result<DebugRequest>;
|
||||
}
|
||||
@@ -22,6 +30,7 @@ pub trait DapLocator: Send + Sync {
|
||||
struct DapRegistryState {
|
||||
adapters: BTreeMap<DebugAdapterName, Arc<dyn DebugAdapter>>,
|
||||
locators: FxHashMap<SharedString, Arc<dyn DapLocator>>,
|
||||
inline_value_providers: FxHashMap<String, Arc<dyn InlineValueProvider>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
@@ -58,6 +67,22 @@ impl DapRegistry {
|
||||
);
|
||||
}
|
||||
|
||||
pub fn add_inline_value_provider(
|
||||
&self,
|
||||
language: String,
|
||||
provider: Arc<dyn InlineValueProvider>,
|
||||
) {
|
||||
let _previous_value = self
|
||||
.0
|
||||
.write()
|
||||
.inline_value_providers
|
||||
.insert(language, provider);
|
||||
debug_assert!(
|
||||
_previous_value.is_none(),
|
||||
"Attempted to insert a new inline value provider when one is already registered"
|
||||
);
|
||||
}
|
||||
|
||||
pub fn locators(&self) -> FxHashMap<SharedString, Arc<dyn DapLocator>> {
|
||||
self.0.read().locators.clone()
|
||||
}
|
||||
@@ -66,6 +91,10 @@ impl DapRegistry {
|
||||
self.0.read().adapters.get(name).cloned()
|
||||
}
|
||||
|
||||
pub fn inline_value_provider(&self, language: &str) -> Option<Arc<dyn InlineValueProvider>> {
|
||||
self.0.read().inline_value_providers.get(language).cloned()
|
||||
}
|
||||
|
||||
pub fn enumerate_adapters(&self) -> Vec<DebugAdapterName> {
|
||||
self.0.read().adapters.keys().cloned().collect()
|
||||
}
|
||||
|
||||
@@ -580,21 +580,31 @@ impl TcpTransport {
|
||||
.unwrap_or(2000u64)
|
||||
});
|
||||
|
||||
let (rx, tx) = select! {
|
||||
let (mut process, (rx, tx)) = select! {
|
||||
_ = cx.background_executor().timer(Duration::from_millis(timeout)).fuse() => {
|
||||
return Err(anyhow!(format!("Connection to TCP DAP timeout {}:{}", host, port)))
|
||||
},
|
||||
result = cx.spawn(async move |cx| {
|
||||
loop {
|
||||
match TcpStream::connect(address).await {
|
||||
Ok(stream) => return stream.split(),
|
||||
Ok(stream) => return Ok((process, stream.split())),
|
||||
Err(_) => {
|
||||
if let Ok(Some(_)) = process.try_status() {
|
||||
let output = process.output().await?;
|
||||
let output = if output.stderr.is_empty() {
|
||||
String::from_utf8_lossy(&output.stdout).to_string()
|
||||
} else {
|
||||
String::from_utf8_lossy(&output.stderr).to_string()
|
||||
};
|
||||
return Err(anyhow!("{}\nerror: process exited before debugger attached.", output));
|
||||
}
|
||||
cx.background_executor().timer(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}).fuse() => result
|
||||
}).fuse() => result?
|
||||
};
|
||||
|
||||
log::info!(
|
||||
"Debug adapter has connected to TCP server {}:{}",
|
||||
host,
|
||||
|
||||
@@ -27,7 +27,6 @@ dap.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
lsp-types.workspace = true
|
||||
paths.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::{collections::HashMap, path::PathBuf, sync::OnceLock};
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use dap::adapters::{DebugTaskDefinition, InlineValueProvider, latest_github_release};
|
||||
use dap::adapters::{DebugTaskDefinition, latest_github_release};
|
||||
use futures::StreamExt;
|
||||
use gpui::AsyncApp;
|
||||
use task::DebugRequest;
|
||||
@@ -159,25 +159,4 @@ impl DebugAdapter for CodeLldbDebugAdapter {
|
||||
connection: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn inline_value_provider(&self) -> Option<Box<dyn InlineValueProvider>> {
|
||||
Some(Box::new(CodeLldbInlineValueProvider))
|
||||
}
|
||||
}
|
||||
|
||||
struct CodeLldbInlineValueProvider;
|
||||
|
||||
impl InlineValueProvider for CodeLldbInlineValueProvider {
|
||||
fn provide(&self, variables: Vec<(String, lsp_types::Range)>) -> Vec<lsp_types::InlineValue> {
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|(variable, range)| {
|
||||
lsp_types::InlineValue::VariableLookup(lsp_types::InlineValueVariableLookup {
|
||||
range,
|
||||
variable_name: Some(variable),
|
||||
case_sensitive_lookup: true,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ mod go;
|
||||
mod javascript;
|
||||
mod php;
|
||||
mod python;
|
||||
mod ruby;
|
||||
|
||||
use std::{net::Ipv4Addr, sync::Arc};
|
||||
|
||||
@@ -16,6 +17,7 @@ use dap::{
|
||||
self, AdapterVersion, DapDelegate, DebugAdapter, DebugAdapterBinary, DebugAdapterName,
|
||||
GithubRepo,
|
||||
},
|
||||
inline_value::{PythonInlineValueProvider, RustInlineValueProvider},
|
||||
};
|
||||
use gdb::GdbDebugAdapter;
|
||||
use go::GoDebugAdapter;
|
||||
@@ -23,6 +25,7 @@ use gpui::{App, BorrowAppContext};
|
||||
use javascript::JsDebugAdapter;
|
||||
use php::PhpDebugAdapter;
|
||||
use python::PythonDebugAdapter;
|
||||
use ruby::RubyDebugAdapter;
|
||||
use serde_json::{Value, json};
|
||||
use task::TcpArgumentsTemplate;
|
||||
|
||||
@@ -32,8 +35,13 @@ pub fn init(cx: &mut App) {
|
||||
registry.add_adapter(Arc::from(PythonDebugAdapter::default()));
|
||||
registry.add_adapter(Arc::from(PhpDebugAdapter::default()));
|
||||
registry.add_adapter(Arc::from(JsDebugAdapter::default()));
|
||||
registry.add_adapter(Arc::from(RubyDebugAdapter));
|
||||
registry.add_adapter(Arc::from(GoDebugAdapter));
|
||||
registry.add_adapter(Arc::from(GdbDebugAdapter));
|
||||
|
||||
registry.add_inline_value_provider("Rust".to_string(), Arc::from(RustInlineValueProvider));
|
||||
registry
|
||||
.add_inline_value_provider("Python".to_string(), Arc::from(PythonInlineValueProvider));
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
use crate::*;
|
||||
use dap::{
|
||||
DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition,
|
||||
adapters::InlineValueProvider,
|
||||
};
|
||||
use dap::{DebugRequest, StartDebuggingRequestArguments, adapters::DebugTaskDefinition};
|
||||
use gpui::AsyncApp;
|
||||
use std::{collections::HashMap, ffi::OsStr, path::PathBuf, sync::OnceLock};
|
||||
use util::ResultExt;
|
||||
@@ -182,34 +179,4 @@ impl DebugAdapter for PythonDebugAdapter {
|
||||
self.get_installed_binary(delegate, &config, user_installed_path, cx)
|
||||
.await
|
||||
}
|
||||
|
||||
fn inline_value_provider(&self) -> Option<Box<dyn InlineValueProvider>> {
|
||||
Some(Box::new(PythonInlineValueProvider))
|
||||
}
|
||||
}
|
||||
|
||||
struct PythonInlineValueProvider;
|
||||
|
||||
impl InlineValueProvider for PythonInlineValueProvider {
|
||||
fn provide(&self, variables: Vec<(String, lsp_types::Range)>) -> Vec<lsp_types::InlineValue> {
|
||||
variables
|
||||
.into_iter()
|
||||
.map(|(variable, range)| {
|
||||
if variable.contains(".") || variable.contains("[") {
|
||||
lsp_types::InlineValue::EvaluatableExpression(
|
||||
lsp_types::InlineValueEvaluatableExpression {
|
||||
range,
|
||||
expression: Some(variable),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
lsp_types::InlineValue::VariableLookup(lsp_types::InlineValueVariableLookup {
|
||||
range,
|
||||
variable_name: Some(variable),
|
||||
case_sensitive_lookup: true,
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
102
crates/dap_adapters/src/ruby.rs
Normal file
102
crates/dap_adapters/src/ruby.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use dap::{
|
||||
DebugRequest, StartDebuggingRequestArguments,
|
||||
adapters::{
|
||||
self, DapDelegate, DebugAdapter, DebugAdapterBinary, DebugAdapterName, DebugTaskDefinition,
|
||||
},
|
||||
};
|
||||
use gpui::AsyncApp;
|
||||
use std::path::PathBuf;
|
||||
use util::command::new_smol_command;
|
||||
|
||||
use crate::ToDap;
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct RubyDebugAdapter;
|
||||
|
||||
impl RubyDebugAdapter {
|
||||
const ADAPTER_NAME: &'static str = "Ruby";
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl DebugAdapter for RubyDebugAdapter {
|
||||
fn name(&self) -> DebugAdapterName {
|
||||
DebugAdapterName(Self::ADAPTER_NAME.into())
|
||||
}
|
||||
|
||||
async fn get_binary(
|
||||
&self,
|
||||
delegate: &dyn DapDelegate,
|
||||
definition: &DebugTaskDefinition,
|
||||
_user_installed_path: Option<PathBuf>,
|
||||
_cx: &mut AsyncApp,
|
||||
) -> Result<DebugAdapterBinary> {
|
||||
let adapter_path = paths::debug_adapters_dir().join(self.name().as_ref());
|
||||
let mut rdbg_path = adapter_path.join("rdbg");
|
||||
if !delegate.fs().is_file(&rdbg_path).await {
|
||||
match delegate.which("rdbg".as_ref()) {
|
||||
Some(path) => rdbg_path = path,
|
||||
None => {
|
||||
delegate.output_to_console(
|
||||
"rdbg not found on path, trying `gem install debug`".to_string(),
|
||||
);
|
||||
let output = new_smol_command("gem")
|
||||
.arg("install")
|
||||
.arg("--no-document")
|
||||
.arg("--bindir")
|
||||
.arg(adapter_path)
|
||||
.arg("debug")
|
||||
.output()
|
||||
.await?;
|
||||
if !output.status.success() {
|
||||
return Err(anyhow!(
|
||||
"Failed to install rdbg:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr).to_string()
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tcp_connection = definition.tcp_connection.clone().unwrap_or_default();
|
||||
let (host, port, timeout) = crate::configure_tcp_connection(tcp_connection).await?;
|
||||
|
||||
let DebugRequest::Launch(mut launch) = definition.request.clone() else {
|
||||
anyhow::bail!("rdbg does not yet support attaching");
|
||||
};
|
||||
|
||||
let mut arguments = vec![
|
||||
"--open".to_string(),
|
||||
format!("--port={}", port),
|
||||
format!("--host={}", host),
|
||||
];
|
||||
if launch.args.is_empty() {
|
||||
let program = launch.program.clone();
|
||||
let mut split = program.split(" ");
|
||||
launch.program = split.next().unwrap().to_string();
|
||||
launch.args = split.map(|s| s.to_string()).collect();
|
||||
}
|
||||
if delegate.which(launch.program.as_ref()).is_some() {
|
||||
arguments.push("--command".to_string())
|
||||
}
|
||||
arguments.push(launch.program);
|
||||
arguments.extend(launch.args);
|
||||
|
||||
Ok(DebugAdapterBinary {
|
||||
command: rdbg_path.to_string_lossy().to_string(),
|
||||
arguments,
|
||||
connection: Some(adapters::TcpArguments {
|
||||
host,
|
||||
port,
|
||||
timeout,
|
||||
}),
|
||||
cwd: launch.cwd,
|
||||
envs: launch.env.into_iter().collect(),
|
||||
request_args: StartDebuggingRequestArguments {
|
||||
configuration: serde_json::Value::Object(Default::default()),
|
||||
request: definition.request.to_dap(),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ doctest = false
|
||||
[features]
|
||||
test-support = [
|
||||
"dap/test-support",
|
||||
"dap_adapters/test-support",
|
||||
"editor/test-support",
|
||||
"gpui/test-support",
|
||||
"project/test-support",
|
||||
@@ -31,6 +32,7 @@ client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
dap.workspace = true
|
||||
dap_adapters = { workspace = true, optional = true }
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
@@ -63,6 +65,7 @@ unindent = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
dap = { workspace = true, features = ["test-support"] }
|
||||
dap_adapters = { workspace = true, features = ["test-support"] }
|
||||
debugger_tools = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
|
||||
@@ -212,7 +212,6 @@ impl DebugPanel {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let dap_store = self.project.read(cx).dap_store();
|
||||
let workspace = self.workspace.clone();
|
||||
let session = dap_store.update(cx, |dap_store, cx| {
|
||||
dap_store.new_session(
|
||||
scenario.label.clone(),
|
||||
@@ -251,14 +250,14 @@ impl DebugPanel {
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
if let Err(error) = task.await {
|
||||
log::error!("{:?}", error);
|
||||
workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.show_error(&error, cx);
|
||||
})
|
||||
.ok();
|
||||
session
|
||||
.update(cx, |session, cx| session.shutdown(cx))?
|
||||
.update(cx, |session, cx| {
|
||||
session
|
||||
.console_output(cx)
|
||||
.unbounded_send(format!("error: {}", error))
|
||||
.ok();
|
||||
session.shutdown(cx)
|
||||
})?
|
||||
.await;
|
||||
}
|
||||
anyhow::Ok(())
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cmp::Reverse,
|
||||
ops::Not,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
usize,
|
||||
};
|
||||
|
||||
use collections::{HashMap, HashSet};
|
||||
use dap::{
|
||||
DapRegistry, DebugRequest,
|
||||
adapters::{DebugAdapterName, DebugTaskDefinition},
|
||||
@@ -192,25 +191,22 @@ impl NewSessionModal {
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<ui::DropdownMenu> {
|
||||
let workspace = self.workspace.clone();
|
||||
let language_registry = self
|
||||
.workspace
|
||||
.update(cx, |this, _| this.app_state().languages.clone())
|
||||
.ok()?;
|
||||
let weak = cx.weak_entity();
|
||||
let label = self
|
||||
.debugger
|
||||
.as_ref()
|
||||
.map(|d| d.0.clone())
|
||||
.unwrap_or_else(|| SELECT_DEBUGGER_LABEL.clone());
|
||||
let active_buffer_language_name =
|
||||
self.task_contexts
|
||||
.active_item_context
|
||||
.as_ref()
|
||||
.and_then(|item| {
|
||||
item.1
|
||||
.as_ref()
|
||||
.and_then(|location| location.buffer.read(cx).language()?.name().into())
|
||||
});
|
||||
let active_buffer_language = self
|
||||
.task_contexts
|
||||
.active_item_context
|
||||
.as_ref()
|
||||
.and_then(|item| {
|
||||
item.1
|
||||
.as_ref()
|
||||
.and_then(|location| location.buffer.read(cx).language())
|
||||
})
|
||||
.cloned();
|
||||
DropdownMenu::new(
|
||||
"dap-adapter-picker",
|
||||
label,
|
||||
@@ -229,42 +225,19 @@ impl NewSessionModal {
|
||||
}
|
||||
};
|
||||
|
||||
let available_languages = language_registry.language_names();
|
||||
let mut debugger_to_languages = HashMap::default();
|
||||
for language in available_languages {
|
||||
let Some(language) =
|
||||
language_registry.available_language_for_name(language.as_str())
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
language.config().debuggers.iter().for_each(|adapter| {
|
||||
debugger_to_languages
|
||||
.entry(adapter.clone())
|
||||
.or_insert_with(HashSet::default)
|
||||
.insert(language.name());
|
||||
});
|
||||
}
|
||||
let mut available_adapters = workspace
|
||||
.update(cx, |_, cx| DapRegistry::global(cx).enumerate_adapters())
|
||||
.ok()
|
||||
.unwrap_or_default();
|
||||
|
||||
available_adapters.sort_by_key(|name| {
|
||||
let languages_for_debugger = debugger_to_languages.get(name.as_ref());
|
||||
let languages_count =
|
||||
languages_for_debugger.map_or(0, |languages| languages.len());
|
||||
let contains_language_of_active_buffer = languages_for_debugger
|
||||
.zip(active_buffer_language_name.as_ref())
|
||||
.map_or(false, |(languages, active_buffer_language)| {
|
||||
languages.contains(active_buffer_language)
|
||||
});
|
||||
|
||||
(
|
||||
Reverse(contains_language_of_active_buffer),
|
||||
Reverse(languages_count),
|
||||
)
|
||||
});
|
||||
if let Some(language) = active_buffer_language {
|
||||
available_adapters.sort_by_key(|adapter| {
|
||||
language
|
||||
.config()
|
||||
.debuggers
|
||||
.get_index_of(adapter.0.as_ref())
|
||||
.unwrap_or(usize::MAX)
|
||||
});
|
||||
}
|
||||
|
||||
for adapter in available_adapters.into_iter() {
|
||||
menu = menu.entry(adapter.0.clone(), None, setter_for_name(adapter.clone()));
|
||||
|
||||
@@ -731,19 +731,30 @@ impl RunningState {
|
||||
(task, None)
|
||||
}
|
||||
};
|
||||
let Some(task) = task.resolve_task("debug-build-task", &task_context) else {
|
||||
anyhow::bail!("Could not resolve task variables within a debug scenario");
|
||||
};
|
||||
|
||||
let locator_name = if let Some(locator_name) = locator_name {
|
||||
debug_assert!(request.is_none());
|
||||
Some(locator_name)
|
||||
} else if request.is_none() {
|
||||
dap_store
|
||||
.update(cx, |this, cx| {
|
||||
this.debug_scenario_for_build_task(task.clone(), adapter.clone(), cx)
|
||||
.and_then(|scenario| match scenario.build {
|
||||
this.debug_scenario_for_build_task(
|
||||
task.original_task().clone(),
|
||||
adapter.clone().into(),
|
||||
task.display_label().to_owned().into(),
|
||||
cx,
|
||||
)
|
||||
.and_then(|scenario| {
|
||||
match scenario.build {
|
||||
Some(BuildTaskDefinition::Template {
|
||||
locator_name, ..
|
||||
}) => locator_name,
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
@@ -751,10 +762,6 @@ impl RunningState {
|
||||
None
|
||||
};
|
||||
|
||||
let Some(task) = task.resolve_task("debug-build-task", &task_context) else {
|
||||
anyhow::bail!("Could not resolve task variables within a debug scenario");
|
||||
};
|
||||
|
||||
let builder = ShellBuilder::new(is_local, &task.resolved.shell);
|
||||
let command_label = builder.command_label(&task.resolved.command_label);
|
||||
let (command, args) =
|
||||
|
||||
@@ -152,7 +152,7 @@ impl Console {
|
||||
session
|
||||
.evaluate(
|
||||
expression,
|
||||
Some(dap::EvaluateArgumentsContext::Variables),
|
||||
Some(dap::EvaluateArgumentsContext::Repl),
|
||||
self.stack_frame_list.read(cx).selected_stack_frame_id(),
|
||||
None,
|
||||
cx,
|
||||
|
||||
@@ -675,6 +675,7 @@ impl VariableList {
|
||||
div()
|
||||
.id(var_ref as usize)
|
||||
.group("variable_list_entry")
|
||||
.pl_2()
|
||||
.border_1()
|
||||
.border_r_2()
|
||||
.border_color(border_color)
|
||||
@@ -692,8 +693,8 @@ impl VariableList {
|
||||
ListItem::new(SharedString::from(format!("scope-{}", var_ref)))
|
||||
.selectable(false)
|
||||
.disabled(self.disabled)
|
||||
.indent_level(state.depth + 1)
|
||||
.indent_step_size(px(20.))
|
||||
.indent_level(state.depth)
|
||||
.indent_step_size(px(10.))
|
||||
.always_show_disclosure_icon(true)
|
||||
.toggle(state.is_expanded)
|
||||
.on_toggle({
|
||||
@@ -772,6 +773,7 @@ impl VariableList {
|
||||
div()
|
||||
.id(variable.item_id())
|
||||
.group("variable_list_entry")
|
||||
.pl_2()
|
||||
.border_1()
|
||||
.border_r_2()
|
||||
.border_color(border_color)
|
||||
@@ -791,8 +793,8 @@ impl VariableList {
|
||||
)))
|
||||
.disabled(self.disabled)
|
||||
.selectable(false)
|
||||
.indent_level(state.depth + 1_usize)
|
||||
.indent_step_size(px(20.))
|
||||
.indent_level(state.depth)
|
||||
.indent_step_size(px(10.))
|
||||
.always_show_disclosure_icon(true)
|
||||
.when(var_ref > 0, |list_item| {
|
||||
list_item.toggle(state.is_expanded).on_toggle(cx.listener({
|
||||
|
||||
@@ -21,6 +21,8 @@ mod dap_logger;
|
||||
#[cfg(test)]
|
||||
mod debugger_panel;
|
||||
#[cfg(test)]
|
||||
mod inline_values;
|
||||
#[cfg(test)]
|
||||
mod module_list;
|
||||
#[cfg(test)]
|
||||
mod persistence;
|
||||
@@ -45,6 +47,7 @@ pub fn init_test(cx: &mut gpui::TestAppContext) {
|
||||
Project::init_settings(cx);
|
||||
editor::init(cx);
|
||||
crate::init(cx);
|
||||
dap_adapters::init(cx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user