Compare commits
27 Commits
agent-tool
...
agent-brin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dd21ad8f7 | ||
|
|
187f851613 | ||
|
|
a77db45865 | ||
|
|
6bb6be826d | ||
|
|
7d9a55d101 | ||
|
|
57d8397f53 | ||
|
|
17ecf94f6f | ||
|
|
d492939bed | ||
|
|
720dfee803 | ||
|
|
a98c648201 | ||
|
|
c147daae4a | ||
|
|
d3911e34de | ||
|
|
87f85f1863 | ||
|
|
1a4dab97db | ||
|
|
cd365b0cf5 | ||
|
|
58604fba86 | ||
|
|
b0609272c0 | ||
|
|
a17807d8b1 | ||
|
|
f81e65ae7c | ||
|
|
952fe34aaa | ||
|
|
f527df6fa1 | ||
|
|
b54bbebc03 | ||
|
|
8bb7a1f9e7 | ||
|
|
e70d8d4dfd | ||
|
|
ea5ce2a1a4 | ||
|
|
fd8eeb537d | ||
|
|
92f21ee39d |
2
.github/workflows/eval.yml
vendored
2
.github/workflows/eval.yml
vendored
@@ -2,7 +2,7 @@ name: Run Agent Eval
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 * * * *"
|
||||
- cron: "0 0 * * *"
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
|
||||
78
Cargo.lock
generated
78
Cargo.lock
generated
@@ -61,7 +61,6 @@ dependencies = [
|
||||
"buffer_diff",
|
||||
"chrono",
|
||||
"client",
|
||||
"clock",
|
||||
"collections",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
@@ -96,12 +95,13 @@ dependencies = [
|
||||
"paths",
|
||||
"picker",
|
||||
"project",
|
||||
"prompt_library",
|
||||
"prompt_store",
|
||||
"proto",
|
||||
"rand 0.8.5",
|
||||
"ref-cast",
|
||||
"release_channel",
|
||||
"rope",
|
||||
"rules_library",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -498,11 +498,11 @@ dependencies = [
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"project",
|
||||
"prompt_library",
|
||||
"prompt_store",
|
||||
"proto",
|
||||
"rand 0.8.5",
|
||||
"rope",
|
||||
"rules_library",
|
||||
"schemars",
|
||||
"search",
|
||||
"serde",
|
||||
@@ -11083,32 +11083,6 @@ dependencies = [
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prompt_library"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"editor",
|
||||
"gpui",
|
||||
"language",
|
||||
"language_model",
|
||||
"log",
|
||||
"menu",
|
||||
"picker",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"rope",
|
||||
"serde",
|
||||
"settings",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prompt_store"
|
||||
version = "0.1.0"
|
||||
@@ -11742,6 +11716,26 @@ dependencies = [
|
||||
"thiserror 2.0.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast"
|
||||
version = "1.0.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf"
|
||||
dependencies = [
|
||||
"ref-cast-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ref-cast-impl"
|
||||
version = "1.0.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.100",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "refineable"
|
||||
version = "0.1.0"
|
||||
@@ -12271,6 +12265,32 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rules_library"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"collections",
|
||||
"editor",
|
||||
"gpui",
|
||||
"language",
|
||||
"language_model",
|
||||
"log",
|
||||
"menu",
|
||||
"picker",
|
||||
"prompt_store",
|
||||
"release_channel",
|
||||
"rope",
|
||||
"serde",
|
||||
"settings",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "runtimelib"
|
||||
version = "0.25.0"
|
||||
|
||||
@@ -109,7 +109,7 @@ members = [
|
||||
"crates/project",
|
||||
"crates/project_panel",
|
||||
"crates/project_symbols",
|
||||
"crates/prompt_library",
|
||||
"crates/rules_library",
|
||||
"crates/prompt_store",
|
||||
"crates/proto",
|
||||
"crates/recent_projects",
|
||||
@@ -319,7 +319,7 @@ prettier = { path = "crates/prettier" }
|
||||
project = { path = "crates/project" }
|
||||
project_panel = { path = "crates/project_panel" }
|
||||
project_symbols = { path = "crates/project_symbols" }
|
||||
prompt_library = { path = "crates/prompt_library" }
|
||||
rules_library = { path = "crates/rules_library" }
|
||||
prompt_store = { path = "crates/prompt_store" }
|
||||
proto = { path = "crates/proto" }
|
||||
recent_projects = { path = "crates/recent_projects" }
|
||||
@@ -500,6 +500,7 @@ prost-types = "0.9"
|
||||
pulldown-cmark = { version = "0.12.0", default-features = false }
|
||||
quote = "1.0.9"
|
||||
rand = "0.8.5"
|
||||
ref-cast = "1.0.24"
|
||||
rayon = "1.8"
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
|
||||
@@ -212,7 +212,7 @@
|
||||
"ctrl-shift-g": "search::SelectPreviousMatch",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-k h": "assistant::DeployHistory",
|
||||
"ctrl-k l": "assistant::OpenPromptLibrary",
|
||||
"ctrl-k l": "assistant::OpenRulesLibrary",
|
||||
"new": "assistant::NewChat",
|
||||
"ctrl-t": "assistant::NewChat",
|
||||
"ctrl-n": "assistant::NewChat"
|
||||
@@ -241,7 +241,7 @@
|
||||
"ctrl-alt-n": "agent::NewTextThread",
|
||||
"ctrl-shift-h": "agent::OpenHistory",
|
||||
"ctrl-alt-c": "agent::OpenConfiguration",
|
||||
"ctrl-alt-p": "assistant::OpenPromptLibrary",
|
||||
"ctrl-alt-p": "assistant::OpenRulesLibrary",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -308,9 +308,9 @@
|
||||
{
|
||||
"context": "PromptLibrary",
|
||||
"bindings": {
|
||||
"new": "prompt_library::NewPrompt",
|
||||
"ctrl-n": "prompt_library::NewPrompt",
|
||||
"ctrl-shift-s": "prompt_library::ToggleDefaultPrompt"
|
||||
"new": "rules_library::NewRule",
|
||||
"ctrl-n": "rules_library::NewRule",
|
||||
"ctrl-shift-s": "rules_library::ToggleDefaultRule"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -675,7 +675,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
"shift-enter": "editor::ExpandExcerpts",
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
"context": "PromptLibrary",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-n": "prompt_library::NewPrompt",
|
||||
"cmd-shift-s": "prompt_library::ToggleDefaultPrompt",
|
||||
"cmd-n": "rules_library::NewRule",
|
||||
"cmd-shift-s": "rules_library::ToggleDefaultRule",
|
||||
"cmd-w": "workspace::CloseWindow"
|
||||
}
|
||||
},
|
||||
@@ -257,7 +257,7 @@
|
||||
"cmd-shift-g": "search::SelectPreviousMatch",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-k h": "assistant::DeployHistory",
|
||||
"cmd-k l": "assistant::OpenPromptLibrary",
|
||||
"cmd-k l": "assistant::OpenRulesLibrary",
|
||||
"cmd-t": "assistant::NewChat",
|
||||
"cmd-n": "assistant::NewChat"
|
||||
}
|
||||
@@ -286,7 +286,7 @@
|
||||
"cmd-alt-n": "agent::NewTextThread",
|
||||
"cmd-shift-h": "agent::OpenHistory",
|
||||
"cmd-alt-c": "agent::OpenConfiguration",
|
||||
"cmd-alt-p": "assistant::OpenPromptLibrary",
|
||||
"cmd-alt-p": "assistant::OpenRulesLibrary",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -738,7 +738,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "Editor && mode == full",
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
|
||||
@@ -1,2 +1,7 @@
|
||||
allow-private-module-inception = true
|
||||
avoid-breaking-exported-api = false
|
||||
ignore-interior-mutability = [
|
||||
# Suppresses clippy::mutable_key_type, which is a false positive as the Eq
|
||||
# and Hash impls do not use fields with interior mutability.
|
||||
"agent::context::AgentContextKey"
|
||||
]
|
||||
|
||||
@@ -28,7 +28,6 @@ async-watch.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
@@ -62,9 +61,10 @@ parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
picker.workspace = true
|
||||
project.workspace = true
|
||||
prompt_library.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
ref-cast.workspace = true
|
||||
release_channel.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::context::{AssistantContext, ContextId, RULES_ICON, format_context_as_string};
|
||||
use crate::context::{AgentContext, RULES_ICON};
|
||||
use crate::context_picker::MentionLink;
|
||||
use crate::thread::{
|
||||
LastRestoreCheckpoint, MessageId, MessageSegment, Thread, ThreadError, ThreadEvent,
|
||||
@@ -25,8 +25,8 @@ use gpui::{
|
||||
};
|
||||
use language::{Buffer, LanguageRegistry};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, RequestUsage, Role,
|
||||
StopReason,
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
|
||||
RequestUsage, Role, StopReason,
|
||||
};
|
||||
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
|
||||
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
|
||||
@@ -43,16 +43,14 @@ use ui::{
|
||||
Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, TextSize, Tooltip, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use util::markdown::MarkdownString;
|
||||
use workspace::{OpenOptions, Workspace};
|
||||
use zed_actions::assistant::OpenPromptLibrary;
|
||||
|
||||
use crate::context_store::ContextStore;
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
|
||||
pub struct ActiveThread {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
thread: Entity<Thread>,
|
||||
context_store: Entity<ContextStore>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
save_thread_task: Option<Task<()>>,
|
||||
messages: Vec<MessageId>,
|
||||
@@ -695,7 +693,7 @@ fn open_markdown_link(
|
||||
}),
|
||||
Some(MentionLink::Fetch(url)) => cx.open_url(&url),
|
||||
Some(MentionLink::Rules(prompt_id)) => window.dispatch_action(
|
||||
Box::new(OpenPromptLibrary {
|
||||
Box::new(OpenRulesLibrary {
|
||||
prompt_to_select: Some(prompt_id.0),
|
||||
}),
|
||||
cx,
|
||||
@@ -716,7 +714,6 @@ impl ActiveThread {
|
||||
thread: Entity<Thread>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
context_store: Entity<ContextStore>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -739,7 +736,6 @@ impl ActiveThread {
|
||||
language_registry,
|
||||
thread_store,
|
||||
thread: thread.clone(),
|
||||
context_store,
|
||||
workspace,
|
||||
save_thread_task: None,
|
||||
messages: Vec::new(),
|
||||
@@ -769,7 +765,7 @@ impl ActiveThread {
|
||||
this.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
tool_use.status.text(),
|
||||
cx,
|
||||
);
|
||||
@@ -779,10 +775,6 @@ impl ActiveThread {
|
||||
this
|
||||
}
|
||||
|
||||
pub fn context_store(&self) -> &Entity<ContextStore> {
|
||||
&self.context_store
|
||||
}
|
||||
|
||||
pub fn thread(&self) -> &Entity<Thread> {
|
||||
&self.thread
|
||||
}
|
||||
@@ -870,7 +862,7 @@ impl ActiveThread {
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_label: impl Into<SharedString>,
|
||||
tool_input: &serde_json::Value,
|
||||
tool_input: &str,
|
||||
tool_output: SharedString,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
@@ -893,11 +885,10 @@ impl ActiveThread {
|
||||
this.replace(tool_label, cx);
|
||||
});
|
||||
rendered.input.update(cx, |this, cx| {
|
||||
let input = format!(
|
||||
"```json\n{}\n```",
|
||||
serde_json::to_string_pretty(tool_input).unwrap_or_default()
|
||||
this.replace(
|
||||
MarkdownString::code_block("json", tool_input).to_string(),
|
||||
cx,
|
||||
);
|
||||
this.replace(input, cx);
|
||||
});
|
||||
rendered.output.update(cx, |this, cx| {
|
||||
this.replace(tool_output, cx);
|
||||
@@ -988,7 +979,7 @@ impl ActiveThread {
|
||||
self.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
"".into(),
|
||||
cx,
|
||||
);
|
||||
@@ -1002,7 +993,7 @@ impl ActiveThread {
|
||||
self.render_tool_use_markdown(
|
||||
tool_use_id.clone(),
|
||||
ui_text.clone(),
|
||||
input,
|
||||
&serde_json::to_string_pretty(&input).unwrap_or_default(),
|
||||
"".into(),
|
||||
cx,
|
||||
);
|
||||
@@ -1014,7 +1005,7 @@ impl ActiveThread {
|
||||
self.render_tool_use_markdown(
|
||||
tool_use.id.clone(),
|
||||
tool_use.ui_text.clone(),
|
||||
&tool_use.input,
|
||||
&serde_json::to_string_pretty(&tool_use.input).unwrap_or_default(),
|
||||
self.thread
|
||||
.read(cx)
|
||||
.output_for_tool(&tool_use.id)
|
||||
@@ -1026,6 +1017,23 @@ impl ActiveThread {
|
||||
}
|
||||
ThreadEvent::CheckpointChanged => cx.notify(),
|
||||
ThreadEvent::ReceivedTextChunk => {}
|
||||
ThreadEvent::InvalidToolInput {
|
||||
tool_use_id,
|
||||
ui_text,
|
||||
invalid_input_json,
|
||||
} => {
|
||||
self.render_tool_use_markdown(
|
||||
tool_use_id.clone(),
|
||||
ui_text,
|
||||
invalid_input_json,
|
||||
self.thread
|
||||
.read(cx)
|
||||
.output_for_tool(tool_use_id)
|
||||
.map(|output| output.clone().into())
|
||||
.unwrap_or("".into()),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1256,26 +1264,36 @@ impl ActiveThread {
|
||||
}
|
||||
|
||||
let token_count = if let Some(task) = cx.update(|cx| {
|
||||
let context = thread.read(cx).context_for_message(message_id);
|
||||
let new_context = thread.read(cx).filter_new_context(context);
|
||||
let context_text =
|
||||
format_context_as_string(new_context, cx).unwrap_or(String::new());
|
||||
let Some(message) = thread.read(cx).message(message_id) else {
|
||||
log::error!("Message that was being edited no longer exists");
|
||||
return None;
|
||||
};
|
||||
let message_text = editor.read(cx).text(cx);
|
||||
|
||||
let content = context_text + &message_text;
|
||||
|
||||
if content.is_empty() {
|
||||
if message_text.is_empty() && message.loaded_context.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
message
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
|
||||
if !message_text.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message_text));
|
||||
}
|
||||
|
||||
let request = language_model::LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![content.into()],
|
||||
cache: false,
|
||||
}],
|
||||
messages: vec![request_message],
|
||||
tools: vec![],
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
@@ -1470,13 +1488,21 @@ impl ActiveThread {
|
||||
return Empty.into_any();
|
||||
};
|
||||
|
||||
let context_store = self.context_store.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let thread = self.thread.read(cx);
|
||||
let prompt_store = self.thread_store.read(cx).prompt_store().as_ref();
|
||||
|
||||
// Get all the data we need from thread before we start using it in closures
|
||||
let checkpoint = thread.checkpoint_for_message(message_id);
|
||||
let context = thread.context_for_message(message_id).collect::<Vec<_>>();
|
||||
let added_context = if let Some(workspace) = workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().read(cx);
|
||||
thread
|
||||
.context_for_message(message_id)
|
||||
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
return Empty.into_any();
|
||||
};
|
||||
|
||||
let tool_uses = thread.tool_uses_for_message(message_id, cx);
|
||||
let has_tool_uses = !tool_uses.is_empty();
|
||||
@@ -1485,41 +1511,13 @@ impl ActiveThread {
|
||||
let is_first_message = ix == 0;
|
||||
let is_last_message = ix == self.messages.len() - 1;
|
||||
|
||||
let show_feedback = (!is_generating && is_last_message && message.role != Role::User)
|
||||
|| self.messages.get(ix + 1).map_or(false, |next_id| {
|
||||
self.thread
|
||||
.read(cx)
|
||||
.message(*next_id)
|
||||
.map_or(false, |next_message| {
|
||||
next_message.role == Role::User
|
||||
&& thread.tool_uses_for_message(*next_id, cx).is_empty()
|
||||
&& thread.tool_results_for_message(*next_id).is_empty()
|
||||
})
|
||||
});
|
||||
let show_feedback = thread.is_turn_end(ix);
|
||||
|
||||
let needs_confirmation = tool_uses.iter().any(|tool_use| tool_use.needs_confirmation);
|
||||
|
||||
let generating_label = (is_generating && is_last_message)
|
||||
.then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
|
||||
|
||||
// Don't render user messages that are just there for returning tool results.
|
||||
if message.role == Role::User && thread.message_has_tool_results(message_id) {
|
||||
if let Some(generating_label) = generating_label {
|
||||
return h_flex()
|
||||
.w_full()
|
||||
.h_10()
|
||||
.py_1p5()
|
||||
.pl_4()
|
||||
.pb_3()
|
||||
.child(generating_label)
|
||||
.into_any_element();
|
||||
}
|
||||
|
||||
return Empty.into_any();
|
||||
}
|
||||
|
||||
let allow_editing_message = message.role == Role::User;
|
||||
|
||||
let edit_message_editor = self
|
||||
.editing_message
|
||||
.as_ref()
|
||||
@@ -1652,90 +1650,78 @@ impl ActiveThread {
|
||||
};
|
||||
|
||||
let message_is_empty = message.should_display_content();
|
||||
let has_content = !message_is_empty || !context.is_empty();
|
||||
let has_content = !message_is_empty || !added_context.is_empty();
|
||||
|
||||
let message_content =
|
||||
has_content.then(|| {
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.when(!message_is_empty, |parent| {
|
||||
parent.child(
|
||||
if let Some(edit_message_editor) = edit_message_editor.clone() {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = TextSize::Small.rems(cx);
|
||||
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
|
||||
let message_content = has_content.then(|| {
|
||||
v_flex()
|
||||
.gap_1()
|
||||
.when(!message_is_empty, |parent| {
|
||||
parent.child(
|
||||
if let Some(edit_message_editor) = edit_message_editor.clone() {
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = TextSize::Small.rems(cx);
|
||||
let line_height = font_size.to_pixels(window.rem_size()) * 1.5;
|
||||
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
line_height: line_height.into(),
|
||||
..Default::default()
|
||||
};
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
line_height: line_height.into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
div()
|
||||
.key_context("EditMessageEditor")
|
||||
.on_action(cx.listener(Self::cancel_editing_message))
|
||||
.on_action(cx.listener(Self::confirm_editing_message))
|
||||
.min_h_6()
|
||||
.child(EditorElement::new(
|
||||
&edit_message_editor,
|
||||
EditorStyle {
|
||||
background: colors.editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
.into_any()
|
||||
} else {
|
||||
div()
|
||||
.min_h_6()
|
||||
.child(self.render_message_content(
|
||||
message_id,
|
||||
rendered_message,
|
||||
has_tool_uses,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
))
|
||||
.into_any()
|
||||
},
|
||||
)
|
||||
})
|
||||
.when(!context.is_empty(), |parent| {
|
||||
parent.child(h_flex().flex_wrap().gap_1().children(
|
||||
context.into_iter().map(|context| {
|
||||
let context_id = context.id();
|
||||
ContextPill::added(
|
||||
AddedContext::new(context, cx),
|
||||
false,
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.on_click(Rc::new(cx.listener({
|
||||
div()
|
||||
.key_context("EditMessageEditor")
|
||||
.on_action(cx.listener(Self::cancel_editing_message))
|
||||
.on_action(cx.listener(Self::confirm_editing_message))
|
||||
.min_h_6()
|
||||
.child(EditorElement::new(
|
||||
&edit_message_editor,
|
||||
EditorStyle {
|
||||
background: colors.editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
.into_any()
|
||||
} else {
|
||||
div()
|
||||
.min_h_6()
|
||||
.child(self.render_message_content(
|
||||
message_id,
|
||||
rendered_message,
|
||||
has_tool_uses,
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
))
|
||||
.into_any()
|
||||
},
|
||||
)
|
||||
})
|
||||
.when(!added_context.is_empty(), |parent| {
|
||||
parent.child(h_flex().flex_wrap().gap_1().children(
|
||||
added_context.into_iter().map(|added_context| {
|
||||
let context = added_context.context.clone();
|
||||
ContextPill::added(added_context, false, false, None).on_click(Rc::new(
|
||||
cx.listener({
|
||||
let workspace = workspace.clone();
|
||||
let context_store = context_store.clone();
|
||||
move |_, _, window, cx| {
|
||||
if let Some(workspace) = workspace.upgrade() {
|
||||
open_context(
|
||||
context_id,
|
||||
context_store.clone(),
|
||||
workspace,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
open_context(&context, workspace, window, cx);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
})))
|
||||
}),
|
||||
))
|
||||
})
|
||||
});
|
||||
}),
|
||||
))
|
||||
}),
|
||||
))
|
||||
})
|
||||
});
|
||||
|
||||
let styled_message = match message.role {
|
||||
Role::User => v_flex()
|
||||
@@ -1752,93 +1738,70 @@ impl ActiveThread {
|
||||
.pb_4()
|
||||
.child(
|
||||
v_flex()
|
||||
.id(("user-message", ix))
|
||||
.bg(editor_bg_color)
|
||||
.rounded_lg()
|
||||
.shadow_md()
|
||||
.border_1()
|
||||
.border_color(colors.border)
|
||||
.shadow_md()
|
||||
.child(div().p_2().children(message_content))
|
||||
.child(
|
||||
h_flex()
|
||||
.p_1()
|
||||
.border_t_1()
|
||||
.border_color(colors.border_variant)
|
||||
.justify_end()
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.when_some(
|
||||
edit_message_editor.clone(),
|
||||
|this, edit_message_editor| {
|
||||
let focus_handle =
|
||||
edit_message_editor.focus_handle(cx);
|
||||
this.child(
|
||||
Button::new("cancel-edit-message", "Cancel")
|
||||
.label_size(LabelSize::Small)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(
|
||||
cx.listener(Self::handle_cancel_click),
|
||||
),
|
||||
.hover(|hover| hover.border_color(colors.text_accent.opacity(0.5)))
|
||||
.cursor_pointer()
|
||||
.child(div().p_2().pt_2p5().children(message_content))
|
||||
.when_some(edit_message_editor.clone(), |this, edit_editor| {
|
||||
let focus_handle = edit_editor.focus_handle(cx);
|
||||
|
||||
this.child(
|
||||
h_flex()
|
||||
.p_1()
|
||||
.border_t_1()
|
||||
.border_color(colors.border_variant)
|
||||
.gap_1()
|
||||
.justify_end()
|
||||
.child(
|
||||
Button::new("cancel-edit-message", "Cancel")
|
||||
.label_size(LabelSize::Small)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Cancel,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.child(
|
||||
Button::new(
|
||||
"confirm-edit-message",
|
||||
"Regenerate",
|
||||
)
|
||||
.disabled(
|
||||
edit_message_editor.read(cx).is_empty(cx),
|
||||
)
|
||||
.label_size(LabelSize::Small)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(
|
||||
cx.listener(Self::handle_regenerate_click),
|
||||
),
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(Self::handle_cancel_click)),
|
||||
)
|
||||
.child(
|
||||
Button::new("confirm-edit-message", "Regenerate")
|
||||
.disabled(edit_editor.read(cx).is_empty(cx))
|
||||
.label_size(LabelSize::Small)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(
|
||||
&menu::Confirm,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
)
|
||||
.when(
|
||||
edit_message_editor.is_none() && allow_editing_message,
|
||||
|this| {
|
||||
this.child(
|
||||
Button::new("edit-message", "Edit Message")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Pencil)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(cx.listener({
|
||||
let message_segments =
|
||||
message.segments.clone();
|
||||
move |this, _, window, cx| {
|
||||
this.start_editing_message(
|
||||
message_id,
|
||||
&message_segments,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
)
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(cx.listener(Self::handle_regenerate_click)),
|
||||
),
|
||||
)
|
||||
})
|
||||
.when(edit_message_editor.is_none(), |this| {
|
||||
this.tooltip(Tooltip::text("Click To Edit"))
|
||||
})
|
||||
.on_click(cx.listener({
|
||||
let message_segments = message.segments.clone();
|
||||
move |this, _, window, cx| {
|
||||
this.start_editing_message(
|
||||
message_id,
|
||||
&message_segments,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})),
|
||||
),
|
||||
Role::Assistant => v_flex()
|
||||
.id(("message-container", ix))
|
||||
@@ -2995,11 +2958,11 @@ impl ActiveThread {
|
||||
.shape(ui::IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.icon_color(Color::Ignored)
|
||||
// TODO: Figure out a way to pass focus handle here so we can display the `OpenPromptLibrary` keybinding
|
||||
// TODO: Figure out a way to pass focus handle here so we can display the `OpenRulesLibrary` keybinding
|
||||
.tooltip(Tooltip::text("View User Rules"))
|
||||
.on_click(move |_event, window, cx| {
|
||||
window.dispatch_action(
|
||||
Box::new(OpenPromptLibrary {
|
||||
Box::new(OpenRulesLibrary {
|
||||
prompt_to_select: first_user_rules_id,
|
||||
}),
|
||||
cx,
|
||||
@@ -3207,20 +3170,14 @@ impl Render for ActiveThread {
|
||||
}
|
||||
|
||||
pub(crate) fn open_context(
|
||||
id: ContextId,
|
||||
context_store: Entity<ContextStore>,
|
||||
context: &AgentContext,
|
||||
workspace: Entity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(context) = context_store.read(cx).context_for_id(id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
match context {
|
||||
AssistantContext::File(file_context) => {
|
||||
if let Some(project_path) = file_context.context_buffer.buffer.read(cx).project_path(cx)
|
||||
{
|
||||
AgentContext::File(file_context) => {
|
||||
if let Some(project_path) = file_context.project_path(cx) {
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace
|
||||
.open_path(project_path, None, true, window, cx)
|
||||
@@ -3228,7 +3185,8 @@ pub(crate) fn open_context(
|
||||
});
|
||||
}
|
||||
}
|
||||
AssistantContext::Directory(directory_context) => {
|
||||
|
||||
AgentContext::Directory(directory_context) => {
|
||||
let entry_id = directory_context.entry_id;
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
workspace.project().update(cx, |_project, cx| {
|
||||
@@ -3236,61 +3194,51 @@ pub(crate) fn open_context(
|
||||
})
|
||||
})
|
||||
}
|
||||
AssistantContext::Symbol(symbol_context) => {
|
||||
if let Some(project_path) = symbol_context
|
||||
.context_symbol
|
||||
.buffer
|
||||
.read(cx)
|
||||
.project_path(cx)
|
||||
{
|
||||
let snapshot = symbol_context.context_symbol.buffer.read(cx).snapshot();
|
||||
let target_position = symbol_context
|
||||
.context_symbol
|
||||
.id
|
||||
.range
|
||||
.start
|
||||
.to_point(&snapshot);
|
||||
|
||||
AgentContext::Symbol(symbol_context) => {
|
||||
let buffer = symbol_context.buffer.read(cx);
|
||||
if let Some(project_path) = buffer.project_path(cx) {
|
||||
let snapshot = buffer.snapshot();
|
||||
let target_position = symbol_context.range.start.to_point(&snapshot);
|
||||
open_editor_at_position(project_path, target_position, &workspace, window, cx)
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
AssistantContext::Selection(selection_context) => {
|
||||
if let Some(project_path) = selection_context
|
||||
.context_buffer
|
||||
.buffer
|
||||
.read(cx)
|
||||
.project_path(cx)
|
||||
{
|
||||
let snapshot = selection_context.context_buffer.buffer.read(cx).snapshot();
|
||||
|
||||
AgentContext::Selection(selection_context) => {
|
||||
let buffer = selection_context.buffer.read(cx);
|
||||
if let Some(project_path) = buffer.project_path(cx) {
|
||||
let snapshot = buffer.snapshot();
|
||||
let target_position = selection_context.range.start.to_point(&snapshot);
|
||||
|
||||
open_editor_at_position(project_path, target_position, &workspace, window, cx)
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
AssistantContext::FetchedUrl(fetched_url_context) => {
|
||||
|
||||
AgentContext::FetchedUrl(fetched_url_context) => {
|
||||
cx.open_url(&fetched_url_context.url);
|
||||
}
|
||||
AssistantContext::Thread(thread_context) => {
|
||||
let thread_id = thread_context.thread.read(cx).id().clone();
|
||||
workspace.update(cx, |workspace, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel
|
||||
.open_thread(&thread_id, window, cx)
|
||||
.detach_and_log_err(cx)
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
AssistantContext::Rules(rules_context) => window.dispatch_action(
|
||||
Box::new(OpenPromptLibrary {
|
||||
|
||||
AgentContext::Thread(thread_context) => workspace.update(cx, |workspace, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
panel.update(cx, |panel, cx| {
|
||||
let thread_id = thread_context.thread.read(cx).id().clone();
|
||||
panel
|
||||
.open_thread(&thread_id, window, cx)
|
||||
.detach_and_log_err(cx)
|
||||
});
|
||||
}
|
||||
}),
|
||||
|
||||
AgentContext::Rules(rules_context) => window.dispatch_action(
|
||||
Box::new(OpenRulesLibrary {
|
||||
prompt_to_select: Some(rules_context.prompt_id.0),
|
||||
}),
|
||||
cx,
|
||||
),
|
||||
AssistantContext::Image(_) => {}
|
||||
|
||||
AgentContext::Image(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -962,11 +962,13 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let prompt_store = None;
|
||||
let thread_store = cx
|
||||
.update(|cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
prompt_store,
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -39,6 +39,7 @@ use thread::ThreadId;
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
|
||||
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
|
||||
pub use crate::context::{ContextLoadResult, LoadedContext};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::ThreadStore;
|
||||
|
||||
@@ -24,9 +24,9 @@ use language::LanguageRegistry;
|
||||
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
|
||||
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
|
||||
use proto::Plan;
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use time::UtcOffset;
|
||||
use ui::{
|
||||
@@ -36,7 +36,7 @@ use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
use workspace::dock::{DockPosition, Panel, PanelEvent};
|
||||
use zed_actions::agent::OpenConfiguration;
|
||||
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
|
||||
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
|
||||
|
||||
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
|
||||
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
|
||||
@@ -79,11 +79,11 @@ pub fn init(cx: &mut App) {
|
||||
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_prompt_library(action, window, cx)
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
})
|
||||
@@ -189,6 +189,7 @@ pub struct AssistantPanel {
|
||||
message_editor: Entity<MessageEditor>,
|
||||
_active_thread_subscriptions: Vec<Subscription>,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
configuration: Option<Entity<AssistantConfiguration>>,
|
||||
configuration_subscription: Option<Subscription>,
|
||||
local_timezone: UtcOffset,
|
||||
@@ -205,14 +206,25 @@ impl AssistantPanel {
|
||||
pub fn load(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: AsyncWindowContext,
|
||||
mut cx: AsyncWindowContext,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = match prompt_store {
|
||||
Ok(prompt_store) => prompt_store.await.ok(),
|
||||
Err(_) => None,
|
||||
};
|
||||
let tools = cx.new(|_| ToolWorkingSet::default())?;
|
||||
let thread_store = workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
ThreadStore::load(
|
||||
project,
|
||||
tools.clone(),
|
||||
prompt_store.clone(),
|
||||
prompt_builder.clone(),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
@@ -230,7 +242,16 @@ impl AssistantPanel {
|
||||
.await?;
|
||||
|
||||
workspace.update_in(cx, |workspace, window, cx| {
|
||||
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
|
||||
cx.new(|cx| {
|
||||
Self::new(
|
||||
workspace,
|
||||
thread_store,
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -239,6 +260,7 @@ impl AssistantPanel {
|
||||
workspace: &Workspace,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -262,6 +284,7 @@ impl AssistantPanel {
|
||||
fs.clone(),
|
||||
workspace.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
prompt_store.clone(),
|
||||
thread_store.downgrade(),
|
||||
thread.clone(),
|
||||
window,
|
||||
@@ -293,7 +316,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
thread_store.clone(),
|
||||
language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -322,6 +344,7 @@ impl AssistantPanel {
|
||||
message_editor_subscription,
|
||||
],
|
||||
context_store,
|
||||
prompt_store,
|
||||
configuration: None,
|
||||
configuration_subscription: None,
|
||||
local_timezone: UtcOffset::from_whole_seconds(
|
||||
@@ -355,6 +378,10 @@ impl AssistantPanel {
|
||||
self.local_timezone
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
|
||||
&self.thread_store
|
||||
}
|
||||
@@ -411,7 +438,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
self.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -430,6 +456,7 @@ impl AssistantPanel {
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -484,13 +511,13 @@ impl AssistantPanel {
|
||||
context_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn deploy_prompt_library(
|
||||
fn deploy_rules_library(
|
||||
&mut self,
|
||||
action: &OpenPromptLibrary,
|
||||
action: &OpenRulesLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_prompt_library(
|
||||
open_rules_library(
|
||||
self.language_registry.clone(),
|
||||
Box::new(PromptLibraryInlineAssist::new(self.workspace.clone())),
|
||||
Arc::new(|| {
|
||||
@@ -500,9 +527,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action.prompt_to_select.map(|uuid| PromptId::User {
|
||||
uuid: UserPromptId(uuid),
|
||||
}),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -598,7 +625,6 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
this.thread_store.clone(),
|
||||
this.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
this.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -617,6 +643,7 @@ impl AssistantPanel {
|
||||
this.fs.clone(),
|
||||
this.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
this.prompt_store.clone(),
|
||||
this.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -1120,7 +1147,7 @@ impl AssistantPanel {
|
||||
"New Text Thread",
|
||||
NewTextThread.boxed_clone(),
|
||||
)
|
||||
.action("Prompt Library", Box::new(OpenPromptLibrary::default()))
|
||||
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.separator()
|
||||
.header("MCPs")
|
||||
@@ -1833,7 +1860,7 @@ impl Render for AssistantPanel {
|
||||
this.open_configuration(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(Self::open_active_thread_as_markdown))
|
||||
.on_action(cx.listener(Self::deploy_prompt_library))
|
||||
.on_action(cx.listener(Self::deploy_rules_library))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.on_action(cx.listener(Self::go_back))
|
||||
.child(self.render_toolbar(window, cx))
|
||||
@@ -1860,13 +1887,13 @@ impl PromptLibraryInlineAssist {
|
||||
}
|
||||
}
|
||||
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
_initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
let Some(project) = self
|
||||
@@ -1876,11 +1903,14 @@ impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let prompt_store = None;
|
||||
let thread_store = None;
|
||||
assistant.assist(
|
||||
&prompt_editor,
|
||||
self.workspace.clone(),
|
||||
project,
|
||||
None,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -1959,8 +1989,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
// being updated.
|
||||
cx.defer_in(window, move |panel, window, cx| {
|
||||
if panel.has_active_thread() {
|
||||
panel.thread.update(cx, |thread, cx| {
|
||||
thread.context_store().update(cx, |store, cx| {
|
||||
panel.message_editor.update(cx, |message_editor, cx| {
|
||||
message_editor.context_store().update(cx, |store, cx| {
|
||||
let buffer = buffer.read(cx);
|
||||
let selection_ranges = selection_ranges
|
||||
.into_iter()
|
||||
@@ -1977,9 +2007,7 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (buffer, range) in selection_ranges {
|
||||
store
|
||||
.add_selection(buffer, range, cx)
|
||||
.detach_and_log_err(cx);
|
||||
store.add_selection(buffer, range, cx);
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context::ContextLoadResult;
|
||||
use crate::inline_prompt_editor::CodegenStatus;
|
||||
use crate::{context::load_context, context_store::ContextStore};
|
||||
use anyhow::Result;
|
||||
use client::telemetry::Telemetry;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
|
||||
};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
@@ -14,7 +16,9 @@ use language_model::{
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Rope;
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
@@ -39,6 +43,8 @@ pub struct BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
pub is_insertion: bool,
|
||||
@@ -50,6 +56,8 @@ impl BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -60,6 +68,8 @@ impl BufferCodegen {
|
||||
range.clone(),
|
||||
false,
|
||||
Some(context_store.clone()),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
Some(telemetry.clone()),
|
||||
builder.clone(),
|
||||
cx,
|
||||
@@ -75,6 +85,8 @@ impl BufferCodegen {
|
||||
range,
|
||||
initial_transaction_id,
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
builder,
|
||||
};
|
||||
@@ -153,6 +165,8 @@ impl BufferCodegen {
|
||||
self.range.clone(),
|
||||
false,
|
||||
Some(self.context_store.clone()),
|
||||
self.project.clone(),
|
||||
self.prompt_store.clone(),
|
||||
Some(self.telemetry.clone()),
|
||||
self.builder.clone(),
|
||||
cx,
|
||||
@@ -229,13 +243,14 @@ pub struct CodegenAlternative {
|
||||
generation: Task<()>,
|
||||
diff: Diff,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
builder: Arc<PromptBuilder>,
|
||||
active: bool,
|
||||
edits: Vec<(Range<Anchor>, String)>,
|
||||
line_operations: Vec<LineOperation>,
|
||||
request: Option<LanguageModelRequest>,
|
||||
elapsed_time: Option<f64>,
|
||||
completion: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
@@ -249,6 +264,8 @@ impl CodegenAlternative {
|
||||
range: Range<Anchor>,
|
||||
active: bool,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -290,6 +307,8 @@ impl CodegenAlternative {
|
||||
generation: Task::ready(()),
|
||||
diff: Diff::default(),
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
builder,
|
||||
@@ -297,7 +316,6 @@ impl CodegenAlternative {
|
||||
edits: Vec::new(),
|
||||
line_operations: Vec::new(),
|
||||
range,
|
||||
request: None,
|
||||
elapsed_time: None,
|
||||
completion: None,
|
||||
}
|
||||
@@ -366,16 +384,18 @@ impl CodegenAlternative {
|
||||
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.build_request(user_prompt, cx)?;
|
||||
self.request = Some(request.clone());
|
||||
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
|
||||
.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
|
||||
fn build_request(
|
||||
&self,
|
||||
user_prompt: String,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(self.range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
@@ -408,30 +428,44 @@ impl CodegenAlternative {
|
||||
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
let context_task = self.context_store.as_ref().map(|context_store| {
|
||||
if let Some(project) = self.project.upgrade() {
|
||||
let context = context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
load_context(context, &project, &self.prompt_store, cx)
|
||||
} else {
|
||||
Task::ready(ContextLoadResult::default())
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(context_store) = &self.context_store {
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
Ok(cx.spawn(async move |_cx| {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
if let Some(context_task) = context_task {
|
||||
context_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
}
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
})
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn handle_stream(
|
||||
@@ -508,7 +542,9 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks = StripInvalidSpans::new(
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
@@ -1034,6 +1070,7 @@ impl Diff {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use fs::FakeFs;
|
||||
use futures::{
|
||||
Stream,
|
||||
stream::{self},
|
||||
@@ -1076,12 +1113,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1140,12 +1181,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1207,12 +1252,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1274,12 +1323,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1329,12 +1382,16 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
false,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,8 +10,11 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
pub use completion_provider::ContextPickerCompletionProvider;
|
||||
use editor::display_map::{Crease, FoldId};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use fetch_context_picker::FetchContextPicker;
|
||||
use file_context_picker::FileContextPicker;
|
||||
use file_context_picker::render_file_context_entry;
|
||||
use gpui::{
|
||||
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
|
||||
@@ -20,10 +23,10 @@ use gpui::{
|
||||
use language::Buffer;
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use project::{Entry, ProjectPath};
|
||||
use prompt_store::UserPromptId;
|
||||
use rules_context_picker::RulesContextEntry;
|
||||
use prompt_store::{PromptStore, UserPromptId};
|
||||
use rules_context_picker::{RulesContextEntry, RulesContextPicker};
|
||||
use symbol_context_picker::SymbolContextPicker;
|
||||
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
|
||||
use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
|
||||
use ui::{
|
||||
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
|
||||
};
|
||||
@@ -32,11 +35,6 @@ use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
|
||||
use crate::AssistantPanel;
|
||||
use crate::context::RULES_ICON;
|
||||
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
|
||||
use crate::context_picker::fetch_context_picker::FetchContextPicker;
|
||||
use crate::context_picker::file_context_picker::FileContextPicker;
|
||||
use crate::context_picker::rules_context_picker::RulesContextPicker;
|
||||
use crate::context_picker::thread_context_picker::ThreadContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::ThreadId;
|
||||
use crate::thread_store::ThreadStore;
|
||||
@@ -166,6 +164,7 @@ pub(super) struct ContextPicker {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -193,6 +192,13 @@ impl ContextPicker {
|
||||
)
|
||||
.collect::<Vec<Subscription>>();
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
ContextPicker {
|
||||
mode: ContextPickerState::Default(ContextMenu::build(
|
||||
window,
|
||||
@@ -202,6 +208,7 @@ impl ContextPicker {
|
||||
workspace,
|
||||
context_store,
|
||||
thread_store,
|
||||
prompt_store,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
@@ -226,7 +233,12 @@ impl ContextPicker {
|
||||
.workspace
|
||||
.upgrade()
|
||||
.map(|workspace| {
|
||||
available_context_picker_entries(&self.thread_store, &workspace, cx)
|
||||
available_context_picker_entries(
|
||||
&self.prompt_store,
|
||||
&self.thread_store,
|
||||
&workspace,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
@@ -304,10 +316,10 @@ impl ContextPicker {
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Rules => {
|
||||
if let Some(thread_store) = self.thread_store.as_ref() {
|
||||
if let Some(prompt_store) = self.prompt_store.as_ref() {
|
||||
self.mode = ContextPickerState::Rules(cx.new(|cx| {
|
||||
RulesContextPicker::new(
|
||||
thread_store.clone(),
|
||||
prompt_store.clone(),
|
||||
context_picker.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
@@ -526,6 +538,7 @@ enum RecentEntry {
|
||||
}
|
||||
|
||||
fn available_context_picker_entries(
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
thread_store: &Option<WeakEntity<ThreadStore>>,
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
@@ -550,6 +563,9 @@ fn available_context_picker_entries(
|
||||
|
||||
if thread_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
|
||||
}
|
||||
|
||||
if prompt_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
|
||||
}
|
||||
|
||||
@@ -585,22 +601,21 @@ fn recent_context_picker_entries(
|
||||
}),
|
||||
);
|
||||
|
||||
let mut current_threads = context_store.read(cx).thread_ids();
|
||||
let current_threads = context_store.read(cx).thread_ids();
|
||||
|
||||
if let Some(active_thread) = workspace
|
||||
let active_thread_id = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|panel| panel.read(cx).active_thread(cx))
|
||||
{
|
||||
current_threads.insert(active_thread.read(cx).id().clone());
|
||||
}
|
||||
.map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
|
||||
|
||||
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
|
||||
recent.extend(
|
||||
thread_store
|
||||
.read(cx)
|
||||
.threads()
|
||||
.reverse_chronological_threads()
|
||||
.into_iter()
|
||||
.filter(|thread| !current_threads.contains(&thread.id))
|
||||
.filter(|thread| {
|
||||
Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
|
||||
})
|
||||
.take(2)
|
||||
.map(|thread| {
|
||||
RecentEntry::Thread(ThreadContextEntry {
|
||||
@@ -622,9 +637,7 @@ fn add_selections_as_context(
|
||||
let selection_ranges = selection_ranges(workspace, cx);
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in selection_ranges {
|
||||
context_store
|
||||
.add_selection(buffer, range, cx)
|
||||
.detach_and_log_err(cx);
|
||||
context_store.add_selection(buffer, range, cx);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,22 +15,21 @@ use itertools::Itertools;
|
||||
use language::{Buffer, CodeLabel, HighlightId};
|
||||
use lsp::CompletionContext;
|
||||
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
|
||||
use prompt_store::PromptId;
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Point;
|
||||
use text::{Anchor, OffsetRangeExt, ToPoint};
|
||||
use ui::prelude::*;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_picker::file_context_picker::search_files;
|
||||
use crate::context_picker::symbol_context_picker::search_symbols;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
use super::fetch_context_picker::fetch_url_content;
|
||||
use super::file_context_picker::FileMatch;
|
||||
use super::file_context_picker::{FileMatch, search_files};
|
||||
use super::rules_context_picker::{RulesContextEntry, search_rules};
|
||||
use super::symbol_context_picker::SymbolMatch;
|
||||
use super::symbol_context_picker::search_symbols;
|
||||
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
|
||||
use super::{
|
||||
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
|
||||
@@ -38,8 +37,8 @@ use super::{
|
||||
};
|
||||
|
||||
pub(crate) enum Match {
|
||||
Symbol(SymbolMatch),
|
||||
File(FileMatch),
|
||||
Symbol(SymbolMatch),
|
||||
Thread(ThreadMatch),
|
||||
Fetch(SharedString),
|
||||
Rules(RulesContextEntry),
|
||||
@@ -69,6 +68,7 @@ fn search(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
recent_entries: Vec<RecentEntry>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
workspace: Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
@@ -85,6 +85,7 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Symbol) => {
|
||||
let search_symbols_task =
|
||||
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
@@ -96,6 +97,7 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Thread) => {
|
||||
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
|
||||
let search_threads_task =
|
||||
@@ -111,6 +113,7 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Fetch) => {
|
||||
if !query.is_empty() {
|
||||
Task::ready(vec![Match::Fetch(query.into())])
|
||||
@@ -118,10 +121,11 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Rules) => {
|
||||
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
let search_rules_task =
|
||||
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
|
||||
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_rules_task
|
||||
.await
|
||||
@@ -133,6 +137,7 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
if query.is_empty() {
|
||||
let mut matches = recent_entries
|
||||
@@ -163,7 +168,7 @@ fn search(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
matches.extend(
|
||||
available_context_picker_entries(&thread_store, &workspace, cx)
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
|
||||
.into_iter()
|
||||
.map(|mode| {
|
||||
Match::Entry(EntryMatch {
|
||||
@@ -180,7 +185,8 @@ fn search(
|
||||
let search_files_task =
|
||||
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
|
||||
let entries = available_context_picker_entries(&thread_store, &workspace, cx);
|
||||
let entries =
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
|
||||
let entry_candidates = entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -307,9 +313,11 @@ impl ContextPickerCompletionProvider {
|
||||
move |_, _: &mut Window, cx: &mut App| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in &selections {
|
||||
context_store
|
||||
.add_selection(buffer.clone(), range.clone(), cx)
|
||||
.detach_and_log_err(cx)
|
||||
context_store.add_selection(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -437,7 +445,6 @@ impl ContextPickerCompletionProvider {
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
context_store: Entity<ContextStore>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
) -> Completion {
|
||||
let new_text = MentionLink::for_rules(&rules);
|
||||
let new_text_len = new_text.len();
|
||||
@@ -457,29 +464,10 @@ impl ContextPickerCompletionProvider {
|
||||
new_text_len,
|
||||
editor.clone(),
|
||||
move |cx| {
|
||||
let prompt_uuid = rules.prompt_id;
|
||||
let prompt_id = PromptId::User { uuid: prompt_uuid };
|
||||
let context_store = context_store.clone();
|
||||
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
|
||||
log::error!("Can't add user rules as prompt store is missing.");
|
||||
return;
|
||||
};
|
||||
let prompt_store = prompt_store.read(cx);
|
||||
let Some(metadata) = prompt_store.metadata(prompt_id) else {
|
||||
return;
|
||||
};
|
||||
let Some(title) = metadata.title else {
|
||||
return;
|
||||
};
|
||||
let text_task = prompt_store.load(prompt_id, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let text = text_task.await?;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(prompt_uuid, title, text, false, cx)
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
let user_prompt_id = rules.prompt_id;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(user_prompt_id, false, cx);
|
||||
});
|
||||
},
|
||||
)),
|
||||
}
|
||||
@@ -516,7 +504,7 @@ impl ContextPickerCompletionProvider {
|
||||
let url_to_fetch = url_to_fetch.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
if context_store.update(cx, |context_store, _| {
|
||||
context_store.includes_url(&url_to_fetch).is_some()
|
||||
context_store.includes_url(&url_to_fetch)
|
||||
})? {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -592,7 +580,7 @@ impl ContextPickerCompletionProvider {
|
||||
move |cx| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
let task = if is_directory {
|
||||
context_store.add_directory(project_path.clone(), false, cx)
|
||||
Task::ready(context_store.add_directory(&project_path, false, cx))
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path.clone(), false, cx)
|
||||
};
|
||||
@@ -732,11 +720,19 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
);
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
let search_task = search(
|
||||
mode,
|
||||
query,
|
||||
Arc::<AtomicBool>::default(),
|
||||
recent_entries,
|
||||
prompt_store,
|
||||
thread_store.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
@@ -768,6 +764,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
|
||||
symbol,
|
||||
excerpt_id,
|
||||
@@ -777,6 +774,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
workspace.clone(),
|
||||
cx,
|
||||
),
|
||||
|
||||
Match::Thread(ThreadMatch {
|
||||
thread, is_recent, ..
|
||||
}) => {
|
||||
@@ -791,17 +789,15 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
thread_store,
|
||||
))
|
||||
}
|
||||
Match::Rules(user_rules) => {
|
||||
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
|
||||
Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
thread_store,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Rules(user_rules) => Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
)),
|
||||
|
||||
Match::Fetch(url) => Some(Self::completion_for_fetch(
|
||||
source_range.clone(),
|
||||
url,
|
||||
@@ -810,6 +806,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
context_store.clone(),
|
||||
http_client.clone(),
|
||||
)),
|
||||
|
||||
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
|
||||
entry,
|
||||
excerpt_id,
|
||||
|
||||
@@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let added = self.context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store.read(cx).includes_url(&self.url).is_some()
|
||||
context_store.read(cx).includes_url(&self.url)
|
||||
});
|
||||
|
||||
Some(
|
||||
|
||||
@@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
if is_directory {
|
||||
context_store.add_directory(project_path, true, cx)
|
||||
Task::ready(context_store.add_directory(&project_path, true, cx))
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path, true, cx)
|
||||
context_store.add_file_from_path(project_path.clone(), true, cx)
|
||||
}
|
||||
})
|
||||
.ok()
|
||||
@@ -325,11 +325,11 @@ pub fn render_file_context_entry(
|
||||
path: path.clone(),
|
||||
};
|
||||
if is_directory {
|
||||
context_store.read(cx).includes_directory(&project_path)
|
||||
} else {
|
||||
context_store
|
||||
.read(cx)
|
||||
.will_include_file_path(&project_path, cx)
|
||||
.path_included_in_directory(&project_path, cx)
|
||||
} else {
|
||||
context_store.read(cx).file_path_included(&project_path, cx)
|
||||
}
|
||||
});
|
||||
|
||||
@@ -357,7 +357,7 @@ pub fn render_file_context_entry(
|
||||
})),
|
||||
)
|
||||
.when_some(added, |el, added| match added {
|
||||
FileInclusion::Direct(_) => el.child(
|
||||
FileInclusion::Direct => el.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_end()
|
||||
@@ -369,9 +369,8 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Added").size(LabelSize::Small)),
|
||||
),
|
||||
FileInclusion::InDirectory(directory_project_path) => {
|
||||
// TODO: Consider using worktree full_path to include worktree name.
|
||||
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
|
||||
FileInclusion::InDirectory { full_path } => {
|
||||
let directory_full_path = full_path.to_string_lossy().into_owned();
|
||||
|
||||
el.child(
|
||||
h_flex()
|
||||
@@ -385,7 +384,7 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Included").size(LabelSize::Small)),
|
||||
)
|
||||
.tooltip(Tooltip::text(format!("in {directory_path}")))
|
||||
.tooltip(Tooltip::text(format!("in {directory_full_path}")))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use prompt_store::{PromptId, UserPromptId};
|
||||
use prompt_store::{PromptId, PromptStore, UserPromptId};
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::{self, ContextStore};
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
pub struct RulesContextPicker {
|
||||
picker: Entity<Picker<RulesContextPickerDelegate>>,
|
||||
@@ -18,13 +17,13 @@ pub struct RulesContextPicker {
|
||||
|
||||
impl RulesContextPicker {
|
||||
pub fn new(
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
|
||||
let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
RulesContextPicker { picker }
|
||||
@@ -50,7 +49,7 @@ pub struct RulesContextEntry {
|
||||
}
|
||||
|
||||
pub struct RulesContextPickerDelegate {
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
matches: Vec<RulesContextEntry>,
|
||||
@@ -59,12 +58,12 @@ pub struct RulesContextPickerDelegate {
|
||||
|
||||
impl RulesContextPickerDelegate {
|
||||
pub fn new(
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
prompt_store: Entity<PromptStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
) -> Self {
|
||||
RulesContextPickerDelegate {
|
||||
thread_store,
|
||||
prompt_store,
|
||||
context_picker,
|
||||
context_store,
|
||||
matches: Vec::new(),
|
||||
@@ -103,11 +102,12 @@ impl PickerDelegate for RulesContextPickerDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
let Some(thread_store) = self.thread_store.upgrade() else {
|
||||
return Task::ready(());
|
||||
};
|
||||
|
||||
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
|
||||
let search_task = search_rules(
|
||||
query,
|
||||
Arc::new(AtomicBool::default()),
|
||||
&self.prompt_store,
|
||||
cx,
|
||||
);
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let matches = search_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
@@ -124,31 +124,11 @@ impl PickerDelegate for RulesContextPickerDelegate {
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(thread_store) = self.thread_store.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let prompt_id = entry.prompt_id;
|
||||
|
||||
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
|
||||
thread_store.load_rules(prompt_id, cx)
|
||||
});
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (metadata, text) = load_rules_task.await?;
|
||||
let Some(title) = metadata.title else {
|
||||
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.delegate
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(prompt_id, title, text, true, cx)
|
||||
})
|
||||
.ok();
|
||||
self.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(entry.prompt_id, true, cx)
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
.log_err();
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
@@ -179,11 +159,10 @@ pub fn render_thread_context_entry(
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |ctx_store| {
|
||||
ctx_store
|
||||
let added = context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store
|
||||
.read(cx)
|
||||
.includes_user_rules(&user_rules.prompt_id)
|
||||
.is_some()
|
||||
.includes_user_rules(user_rules.prompt_id)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -218,12 +197,9 @@ pub fn render_thread_context_entry(
|
||||
pub(crate) fn search_rules(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
prompt_store: &Entity<PromptStore>,
|
||||
cx: &mut App,
|
||||
) -> Task<Vec<RulesContextEntry>> {
|
||||
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
|
||||
return Task::ready(vec![]);
|
||||
};
|
||||
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_task
|
||||
|
||||
@@ -10,7 +10,6 @@ use gpui::{
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use project::{DocumentSymbol, Symbol};
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
@@ -228,18 +227,16 @@ pub(crate) fn add_symbol(
|
||||
)
|
||||
})?;
|
||||
|
||||
context_store
|
||||
.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
context_store.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -353,38 +350,13 @@ fn compute_symbol_entries(
|
||||
context_store: &ContextStore,
|
||||
cx: &App,
|
||||
) -> Vec<SymbolEntry> {
|
||||
let mut symbol_entries = Vec::with_capacity(symbols.len());
|
||||
for SymbolMatch { symbol, .. } in symbols {
|
||||
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
|
||||
let is_included = if let Some(symbols_for_path) = symbols_for_path {
|
||||
let mut is_included = false;
|
||||
for included_symbol_id in symbols_for_path {
|
||||
if included_symbol_id.name.as_ref() == symbol.name.as_str() {
|
||||
if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let included_symbol_range =
|
||||
included_symbol_id.range.to_point_utf16(&snapshot);
|
||||
|
||||
if included_symbol_range.start == symbol.range.start.0
|
||||
&& included_symbol_range.end == symbol.range.end.0
|
||||
{
|
||||
is_included = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
is_included
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
symbol_entries.push(SymbolEntry {
|
||||
symbols
|
||||
.into_iter()
|
||||
.map(|SymbolMatch { symbol, .. }| SymbolEntry {
|
||||
is_included: context_store.includes_symbol(&symbol, cx),
|
||||
symbol,
|
||||
is_included,
|
||||
})
|
||||
}
|
||||
symbol_entries
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {
|
||||
|
||||
@@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |ctx_store| {
|
||||
ctx_store.read(cx).includes_thread(&thread.id).is_some()
|
||||
ctx_store.read(cx).includes_thread(&thread.id)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -219,7 +219,7 @@ pub(crate) fn search_threads(
|
||||
) -> Task<Vec<ThreadMatch>> {
|
||||
let threads = thread_store
|
||||
.read(cx)
|
||||
.threads()
|
||||
.reverse_chronological_threads()
|
||||
.into_iter()
|
||||
.map(|thread| ThreadContextEntry {
|
||||
id: thread.id,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,9 +12,9 @@ use itertools::Itertools;
|
||||
use language::Buffer;
|
||||
use project::ProjectItem;
|
||||
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context::{ContextId, ContextKind};
|
||||
use crate::context::{AgentContext, ContextKind};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::Thread;
|
||||
@@ -32,6 +32,7 @@ pub struct ContextStrip {
|
||||
focus_handle: FocusHandle,
|
||||
suggest_context_kind: SuggestContextKind,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
focused_index: Option<usize>,
|
||||
children_bounds: Option<Vec<Bounds<Pixels>>>,
|
||||
@@ -73,12 +74,31 @@ impl ContextStrip {
|
||||
focus_handle,
|
||||
suggest_context_kind,
|
||||
workspace,
|
||||
thread_store,
|
||||
_subscriptions: subscriptions,
|
||||
focused_index: None,
|
||||
children_bounds: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().read(cx);
|
||||
let prompt_store = self
|
||||
.thread_store
|
||||
.as_ref()
|
||||
.and_then(|thread_store| thread_store.upgrade())
|
||||
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
|
||||
self.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
|
||||
match self.suggest_context_kind {
|
||||
SuggestContextKind::File => self.suggested_file(cx),
|
||||
@@ -93,22 +113,19 @@ impl ContextStrip {
|
||||
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
|
||||
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
|
||||
let active_buffer = active_buffer_entity.read(cx);
|
||||
|
||||
let project_path = active_buffer.project_path(cx)?;
|
||||
|
||||
if self
|
||||
.context_store
|
||||
.read(cx)
|
||||
.will_include_buffer(active_buffer.remote_id(), &project_path)
|
||||
.file_path_included(&project_path, cx)
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let file_name = active_buffer.file()?.file_name(cx);
|
||||
|
||||
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
|
||||
|
||||
Some(SuggestedContext::File {
|
||||
name: file_name.to_string_lossy().into_owned().into(),
|
||||
buffer: active_buffer_entity.downgrade(),
|
||||
@@ -135,7 +152,6 @@ impl ContextStrip {
|
||||
.context_store
|
||||
.read(cx)
|
||||
.includes_thread(active_thread.id())
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
@@ -272,12 +288,12 @@ impl ContextStrip {
|
||||
best.map(|(index, _, _)| index)
|
||||
}
|
||||
|
||||
fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
|
||||
fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) {
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
|
||||
crate::active_thread::open_context(context, workspace, window, cx);
|
||||
}
|
||||
|
||||
fn remove_focused_context(
|
||||
@@ -287,17 +303,17 @@ impl ContextStrip {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(index) = self.focused_index {
|
||||
let mut is_empty = false;
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let Some(context) = added_contexts.get(index) else {
|
||||
return;
|
||||
};
|
||||
|
||||
self.context_store.update(cx, |this, cx| {
|
||||
if let Some(item) = this.context().get(index) {
|
||||
this.remove_context(item.id(), cx);
|
||||
}
|
||||
|
||||
is_empty = this.context().is_empty();
|
||||
this.remove_context(&context.context, cx);
|
||||
});
|
||||
|
||||
if is_empty {
|
||||
let is_now_empty = added_contexts.len() == 1;
|
||||
if is_now_empty {
|
||||
cx.emit(ContextStripEvent::BlurredEmpty);
|
||||
} else {
|
||||
self.focused_index = Some(index.saturating_sub(1));
|
||||
@@ -306,49 +322,28 @@ impl ContextStrip {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
|
||||
fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
|
||||
// We only suggest one item after the actual context
|
||||
self.focused_index == Some(context.len())
|
||||
self.focused_index == Some(added_contexts.len())
|
||||
}
|
||||
|
||||
fn accept_suggested_context(
|
||||
&mut self,
|
||||
_: &AcceptSuggestedContext,
|
||||
window: &mut Window,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(suggested) = self.suggested_context(cx) {
|
||||
let context_store = self.context_store.read(cx);
|
||||
|
||||
if self.is_suggested_focused(context_store.context()) {
|
||||
self.add_suggested_context(&suggested, window, cx);
|
||||
if self.is_suggested_focused(&self.added_contexts(cx)) {
|
||||
self.add_suggested_context(&suggested, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_suggested_context(
|
||||
&mut self,
|
||||
suggested: &SuggestedContext,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let task = self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.accept_suggested_context(&suggested, cx)
|
||||
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
|
||||
self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_suggested_context(&suggested, cx)
|
||||
});
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
match task.await.notify_async_err(cx) {
|
||||
None => {}
|
||||
Some(()) => {
|
||||
if let Some(this) = this.upgrade() {
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
@@ -361,17 +356,10 @@ impl Focusable for ContextStrip {
|
||||
|
||||
impl Render for ContextStrip {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let context_store = self.context_store.read(cx);
|
||||
let context = context_store.context();
|
||||
let context_picker = self.context_picker.clone();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
let suggested_context = self.suggested_context(cx);
|
||||
|
||||
let added_contexts = context
|
||||
.iter()
|
||||
.map(|c| AddedContext::new(c, cx))
|
||||
.collect::<Vec<_>>();
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let dupe_names = added_contexts
|
||||
.iter()
|
||||
.map(|c| c.name.clone())
|
||||
@@ -380,6 +368,14 @@ impl Render for ContextStrip {
|
||||
.filter(|(a, b)| a == b)
|
||||
.map(|(a, _)| a)
|
||||
.collect::<HashSet<SharedString>>();
|
||||
let no_added_context = added_contexts.is_empty();
|
||||
|
||||
let suggested_context = self.suggested_context(cx).map(|suggested_context| {
|
||||
(
|
||||
suggested_context,
|
||||
self.is_suggested_focused(&added_contexts),
|
||||
)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
.flex_wrap()
|
||||
@@ -436,7 +432,7 @@ impl Render for ContextStrip {
|
||||
})
|
||||
.with_handle(self.context_picker_menu_handle.clone()),
|
||||
)
|
||||
.when(context.is_empty() && suggested_context.is_none(), {
|
||||
.when(no_added_context && suggested_context.is_none(), {
|
||||
|parent| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
@@ -466,16 +462,17 @@ impl Render for ContextStrip {
|
||||
.enumerate()
|
||||
.map(|(i, added_context)| {
|
||||
let name = added_context.name.clone();
|
||||
let id = added_context.id;
|
||||
let context = added_context.context.clone();
|
||||
ContextPill::added(
|
||||
added_context,
|
||||
dupe_names.contains(&name),
|
||||
self.focused_index == Some(i),
|
||||
Some({
|
||||
let context = context.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
Rc::new(cx.listener(move |_this, _event, _window, cx| {
|
||||
context_store.update(cx, |this, cx| {
|
||||
this.remove_context(id, cx);
|
||||
this.remove_context(&context, cx);
|
||||
});
|
||||
cx.notify();
|
||||
}))
|
||||
@@ -484,7 +481,7 @@ impl Render for ContextStrip {
|
||||
.on_click({
|
||||
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
|
||||
if event.down.click_count > 1 {
|
||||
this.open_context(id, window, cx);
|
||||
this.open_context(&context, window, cx);
|
||||
} else {
|
||||
this.focused_index = Some(i);
|
||||
}
|
||||
@@ -493,22 +490,22 @@ impl Render for ContextStrip {
|
||||
})
|
||||
}),
|
||||
)
|
||||
.when_some(suggested_context, |el, suggested| {
|
||||
.when_some(suggested_context, |el, (suggested, focused)| {
|
||||
el.child(
|
||||
ContextPill::suggested(
|
||||
suggested.name().clone(),
|
||||
suggested.icon_path(),
|
||||
suggested.kind(),
|
||||
self.is_suggested_focused(&context),
|
||||
focused,
|
||||
)
|
||||
.on_click(Rc::new(cx.listener(
|
||||
move |this, _event, window, cx| {
|
||||
this.add_suggested_context(&suggested, window, cx);
|
||||
move |this, _event, _window, cx| {
|
||||
this.add_suggested_context(&suggested, cx);
|
||||
},
|
||||
))),
|
||||
)
|
||||
})
|
||||
.when(!context.is_empty(), {
|
||||
.when(!no_added_context, {
|
||||
move |parent| {
|
||||
parent.child(
|
||||
IconButton::new("remove-all-context", IconName::Eraser)
|
||||
@@ -534,6 +531,7 @@ impl Render for ContextStrip {
|
||||
)
|
||||
}
|
||||
})
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,10 @@ impl HistoryStore {
|
||||
return history_entries;
|
||||
}
|
||||
|
||||
for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
|
||||
for thread in self
|
||||
.thread_store
|
||||
.update(cx, |this, _cx| this.reverse_chronological_threads())
|
||||
{
|
||||
history_entries.push(HistoryEntry::Thread(thread));
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ use project::LspAction;
|
||||
use project::Project;
|
||||
use project::{CodeAction, ProjectTransaction};
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
|
||||
@@ -245,9 +246,13 @@ impl InlineAssistant {
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx))
|
||||
};
|
||||
|
||||
let thread_store = workspace
|
||||
let assistant_panel = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
|
||||
.map(|assistant_panel| assistant_panel.read(cx));
|
||||
let prompt_store = assistant_panel
|
||||
.and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
|
||||
let thread_store =
|
||||
assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
|
||||
|
||||
let handle_assist =
|
||||
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
|
||||
@@ -257,6 +262,7 @@ impl InlineAssistant {
|
||||
&active_editor,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -269,6 +275,7 @@ impl InlineAssistant {
|
||||
&active_terminal,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -323,6 +330,7 @@ impl InlineAssistant {
|
||||
editor: &Entity<Editor>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -437,6 +445,8 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
None,
|
||||
context_store.clone(),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -525,6 +535,7 @@ impl InlineAssistant {
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
focus: bool,
|
||||
workspace: Entity<Workspace>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -543,7 +554,7 @@ impl InlineAssistant {
|
||||
}
|
||||
|
||||
let project = workspace.read(cx).project().downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
|
||||
|
||||
let codegen = cx.new(|cx| {
|
||||
BufferCodegen::new(
|
||||
@@ -551,6 +562,8 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
initial_transaction_id,
|
||||
context_store.clone(),
|
||||
project,
|
||||
prompt_store,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -1789,6 +1802,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
let editor = self.editor.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
window.spawn(cx, async move |cx| {
|
||||
let workspace = workspace.upgrade().context("workspace was released")?;
|
||||
let editor = editor.upgrade().context("editor was released")?;
|
||||
@@ -1829,6 +1843,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
})?
|
||||
.context("invalid range")?;
|
||||
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
|
||||
let assist_id = assistant.suggest_assist(
|
||||
&editor,
|
||||
@@ -1837,6 +1852,7 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
None,
|
||||
true,
|
||||
workspace,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
|
||||
@@ -13,7 +13,7 @@ use editor::{
|
||||
Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer,
|
||||
actions::{MoveDown, MoveUp},
|
||||
};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedPro};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
|
||||
@@ -132,7 +132,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -931,7 +931,7 @@ impl PromptEditor<BufferCodegen> {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedPro>()
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::assistant_model_selector::ModelType;
|
||||
use crate::context::{AssistantContext, format_context_as_string};
|
||||
use crate::context::{ContextLoadResult, load_context};
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::HashSet;
|
||||
@@ -13,6 +13,8 @@ use editor::{
|
||||
};
|
||||
use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt as _, future};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
|
||||
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
|
||||
@@ -22,6 +24,7 @@ use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelReques
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use multi_buffer;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use theme::ThemeSettings;
|
||||
@@ -31,7 +34,7 @@ use workspace::Workspace;
|
||||
|
||||
use crate::assistant_model_selector::AssistantModelSelector;
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
|
||||
use crate::context_store::{ContextStore, refresh_context_store_text};
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::profile_selector::ProfileSelector;
|
||||
use crate::thread::{Thread, TokenUsageRatio};
|
||||
@@ -49,9 +52,12 @@ pub struct MessageEditor {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
model_selector: Entity<AssistantModelSelector>,
|
||||
last_loaded_context: Option<ContextLoadResult>,
|
||||
context_load_task: Option<Shared<Task<()>>>,
|
||||
profile_selector: Entity<ProfileSelector>,
|
||||
edits_expanded: bool,
|
||||
editor_is_expanded: bool,
|
||||
@@ -68,6 +74,7 @@ impl MessageEditor {
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
thread: Entity<Thread>,
|
||||
window: &mut Window,
|
||||
@@ -135,13 +142,11 @@ impl MessageEditor {
|
||||
let subscriptions = vec![
|
||||
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
|
||||
cx.subscribe(&editor, |this, _, event, cx| match event {
|
||||
EditorEvent::BufferEdited => {
|
||||
this.message_or_context_changed(true, cx);
|
||||
}
|
||||
EditorEvent::BufferEdited => this.handle_message_changed(cx),
|
||||
_ => {}
|
||||
}),
|
||||
cx.observe(&context_store, |this, _, cx| {
|
||||
this.message_or_context_changed(false, cx);
|
||||
let _ = this.start_context_load(cx);
|
||||
}),
|
||||
];
|
||||
|
||||
@@ -152,8 +157,11 @@ impl MessageEditor {
|
||||
incompatible_tools_state: incompatible_tools.clone(),
|
||||
workspace,
|
||||
context_store,
|
||||
prompt_store,
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
context_load_task: None,
|
||||
last_loaded_context: None,
|
||||
model_selector: cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
fs.clone(),
|
||||
@@ -175,6 +183,10 @@ impl MessageEditor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_store(&self) -> &Entity<ContextStore> {
|
||||
&self.context_store
|
||||
}
|
||||
|
||||
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.notify();
|
||||
}
|
||||
@@ -214,6 +226,7 @@ impl MessageEditor {
|
||||
) {
|
||||
self.context_picker_menu_handle.toggle(window, cx);
|
||||
}
|
||||
|
||||
pub fn remove_all_context(
|
||||
&mut self,
|
||||
_: &RemoveAllContext,
|
||||
@@ -270,68 +283,22 @@ impl MessageEditor {
|
||||
self.last_estimated_token_count.take();
|
||||
cx.emit(MessageEditorEvent::EstimatedTokenCount);
|
||||
|
||||
let refresh_task =
|
||||
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
|
||||
let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
|
||||
|
||||
let thread = self.thread.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
|
||||
let context_task = self.load_context(cx);
|
||||
let window_handle = window.window_handle();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let checkpoint = checkpoint.await.ok();
|
||||
refresh_task.await;
|
||||
wait_for_images.await;
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
|
||||
let loaded_context = loaded_context.unwrap_or_default();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let context = context_store.read(cx).context().clone();
|
||||
thread.insert_user_message(user_message, context, checkpoint, cx);
|
||||
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
let excerpt_ids = context_store
|
||||
.context()
|
||||
.iter()
|
||||
.filter(|ctx| {
|
||||
matches!(
|
||||
ctx,
|
||||
AssistantContext::Selection(_) | AssistantContext::Image(_)
|
||||
)
|
||||
})
|
||||
.map(|ctx| ctx.id())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for id in excerpt_ids {
|
||||
context_store.remove_context(id, cx);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(wait_for_summaries) = context_store
|
||||
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
|
||||
.log_err()
|
||||
{
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = true;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
|
||||
wait_for_summaries.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = false;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
// Send to model after summaries are done
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
@@ -342,6 +309,30 @@ impl MessageEditor {
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn wait_for_summaries(&mut self, cx: &mut Context<Self>) -> Task<()> {
|
||||
let context_store = self.context_store.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
if let Some(wait_for_summaries) = context_store
|
||||
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
|
||||
.ok()
|
||||
{
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = true;
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
|
||||
wait_for_summaries.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = false;
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let cancelled = self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_last_completion(Some(window.window_handle()), cx)
|
||||
@@ -1015,6 +1006,48 @@ impl MessageEditor {
|
||||
self.update_token_count_task.is_some()
|
||||
}
|
||||
|
||||
fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
|
||||
self.message_or_context_changed(true, cx);
|
||||
}
|
||||
|
||||
fn start_context_load(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
|
||||
let summaries_task = self.wait_for_summaries(cx);
|
||||
let load_task = cx.spawn(async move |this, cx| {
|
||||
// Waits for detailed summaries before `load_context`, as it directly reads these from
|
||||
// the thread. TODO: Would be cleaner to have context loading await on summarization.
|
||||
summaries_task.await;
|
||||
let Ok(load_task) = this.update(cx, |this, cx| {
|
||||
let new_context = this.context_store.read_with(cx, |context_store, cx| {
|
||||
context_store.new_context_for_thread(this.thread.read(cx))
|
||||
});
|
||||
load_context(new_context, &this.project, &this.prompt_store, cx)
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
let result = load_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
this.last_loaded_context = Some(result);
|
||||
this.context_load_task = None;
|
||||
this.message_or_context_changed(false, cx);
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
// Replace existing load task, if any, causing it to be cancelled.
|
||||
let load_task = load_task.shared();
|
||||
self.context_load_task = Some(load_task.clone());
|
||||
load_task
|
||||
}
|
||||
|
||||
fn load_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
|
||||
let context_load_task = self.start_context_load(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
context_load_task.await;
|
||||
this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
|
||||
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Changed);
|
||||
self.update_token_count_task.take();
|
||||
@@ -1024,9 +1057,7 @@ impl MessageEditor {
|
||||
return;
|
||||
};
|
||||
|
||||
let context_store = self.context_store.clone();
|
||||
let editor = self.editor.clone();
|
||||
let thread = self.thread.clone();
|
||||
|
||||
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
|
||||
if debounce {
|
||||
@@ -1035,27 +1066,33 @@ impl MessageEditor {
|
||||
.await;
|
||||
}
|
||||
|
||||
let token_count = if let Some(task) = cx.update(|cx| {
|
||||
let context = context_store.read(cx).context().iter();
|
||||
let new_context = thread.read(cx).filter_new_context(context);
|
||||
let context_text =
|
||||
format_context_as_string(new_context, cx).unwrap_or(String::new());
|
||||
let token_count = if let Some(task) = this.update(cx, |this, cx| {
|
||||
let loaded_context = this
|
||||
.last_loaded_context
|
||||
.as_ref()
|
||||
.map(|context_load_result| &context_load_result.loaded_context);
|
||||
let message_text = editor.read(cx).text(cx);
|
||||
|
||||
let content = context_text + &message_text;
|
||||
|
||||
if content.is_empty() {
|
||||
if message_text.is_empty()
|
||||
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
if let Some(loaded_context) = loaded_context {
|
||||
loaded_context.add_to_request_message(&mut request_message);
|
||||
}
|
||||
|
||||
let request = language_model::LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![content.into()],
|
||||
cache: false,
|
||||
}],
|
||||
messages: vec![request_message],
|
||||
tools: vec![],
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
|
||||
@@ -32,7 +32,7 @@ impl TerminalCodegen {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
|
||||
pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
@@ -45,6 +45,7 @@ impl TerminalCodegen {
|
||||
self.status = CodegenStatus::Pending;
|
||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||
self.generation = cx.spawn(async move |this, cx| {
|
||||
let prompt = prompt_task.await;
|
||||
let model_telemetry_id = model.telemetry_id();
|
||||
let model_provider_id = model.provider_id();
|
||||
let response = model.stream_completion_text(prompt, &cx).await;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context::load_context;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::inline_prompt_editor::{
|
||||
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
|
||||
@@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
|
||||
use collections::{HashMap, VecDeque};
|
||||
use editor::{MultiBuffer, actions::SelectAll};
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
|
||||
use language::Buffer;
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
Role, report_assistant_event,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::{PromptBuilder, PromptStore};
|
||||
use std::sync::Arc;
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::TerminalView;
|
||||
@@ -69,6 +69,7 @@ impl TerminalInlineAssistant {
|
||||
terminal_view: &Entity<TerminalView>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -109,6 +110,7 @@ impl TerminalInlineAssistant {
|
||||
prompt_editor,
|
||||
workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -196,11 +198,11 @@ impl TerminalInlineAssistant {
|
||||
.log_err();
|
||||
|
||||
let codegen = assist.codegen.clone();
|
||||
let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
return;
|
||||
};
|
||||
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
|
||||
}
|
||||
|
||||
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
|
||||
@@ -217,7 +219,7 @@ impl TerminalInlineAssistant {
|
||||
&self,
|
||||
assist_id: TerminalInlineAssistId,
|
||||
cx: &mut App,
|
||||
) -> Result<LanguageModelRequest> {
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
let assist = self.assists.get(&assist_id).context("invalid assist")?;
|
||||
|
||||
let shell = std::env::var("SHELL").ok();
|
||||
@@ -246,28 +248,40 @@ impl TerminalInlineAssistant {
|
||||
&latest_output,
|
||||
)?;
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
let contexts = assist
|
||||
.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let context_load_task = assist.workspace.update(cx, |workspace, cx| {
|
||||
let project = workspace.project();
|
||||
load_context(contexts, project, &assist.prompt_store, cx)
|
||||
})?;
|
||||
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
assist.context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
Ok(cx.background_spawn(async move {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
context_load_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
})
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
fn finish_assist(
|
||||
@@ -380,6 +394,7 @@ struct TerminalInlineAssist {
|
||||
codegen: Entity<TerminalCodegen>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -390,6 +405,7 @@ impl TerminalInlineAssist {
|
||||
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
@@ -400,6 +416,7 @@ impl TerminalInlineAssist {
|
||||
codegen: codegen.clone(),
|
||||
workspace: workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
_subscriptions: vec![
|
||||
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
|
||||
TerminalInlineAssistant::update_global(cx, |this, cx| {
|
||||
|
||||
@@ -8,7 +8,7 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use collections::HashMap;
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt, StreamExt as _};
|
||||
@@ -17,8 +17,8 @@ use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelImage, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
|
||||
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
|
||||
ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
|
||||
@@ -35,7 +35,7 @@ use thiserror::Error;
|
||||
use util::{ResultExt as _, TryFutureExt as _, post_inc};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::context::{AssistantContext, ContextId, format_context_as_string};
|
||||
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
|
||||
use crate::thread_store::{
|
||||
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
|
||||
SerializedToolUse, SharedProjectContext,
|
||||
@@ -98,8 +98,7 @@ pub struct Message {
|
||||
pub id: MessageId,
|
||||
pub role: Role,
|
||||
pub segments: Vec<MessageSegment>,
|
||||
pub context: String,
|
||||
pub images: Vec<LanguageModelImage>,
|
||||
pub loaded_context: LoadedContext,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
@@ -138,8 +137,8 @@ impl Message {
|
||||
pub fn to_string(&self) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
if !self.context.is_empty() {
|
||||
result.push_str(&self.context);
|
||||
if !self.loaded_context.text.is_empty() {
|
||||
result.push_str(&self.loaded_context.text);
|
||||
}
|
||||
|
||||
for segment in &self.segments {
|
||||
@@ -294,8 +293,6 @@ pub struct Thread {
|
||||
messages: Vec<Message>,
|
||||
next_message_id: MessageId,
|
||||
last_prompt_id: PromptId,
|
||||
context: BTreeMap<ContextId, AssistantContext>,
|
||||
context_by_message: HashMap<MessageId, Vec<ContextId>>,
|
||||
project_context: SharedProjectContext,
|
||||
checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
|
||||
completion_count: usize,
|
||||
@@ -345,8 +342,6 @@ impl Thread {
|
||||
messages: Vec::new(),
|
||||
next_message_id: MessageId(0),
|
||||
last_prompt_id: PromptId::new(),
|
||||
context: BTreeMap::default(),
|
||||
context_by_message: HashMap::default(),
|
||||
project_context: system_prompt,
|
||||
checkpoints_by_message: HashMap::default(),
|
||||
completion_count: 0,
|
||||
@@ -391,8 +386,7 @@ impl Thread {
|
||||
.map(|message| message.id.0 + 1)
|
||||
.unwrap_or(0),
|
||||
);
|
||||
let tool_use =
|
||||
ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
|
||||
let tool_use = ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages);
|
||||
|
||||
Self {
|
||||
id,
|
||||
@@ -419,14 +413,15 @@ impl Thread {
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
context: message.context,
|
||||
images: Vec::new(),
|
||||
loaded_context: LoadedContext {
|
||||
contexts: Vec::new(),
|
||||
text: message.context,
|
||||
images: Vec::new(),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
next_message_id,
|
||||
last_prompt_id: PromptId::new(),
|
||||
context: BTreeMap::default(),
|
||||
context_by_message: HashMap::default(),
|
||||
project_context,
|
||||
checkpoints_by_message: HashMap::default(),
|
||||
completion_count: 0,
|
||||
@@ -524,7 +519,12 @@ impl Thread {
|
||||
}
|
||||
|
||||
pub fn message(&self, id: MessageId) -> Option<&Message> {
|
||||
self.messages.iter().find(|message| message.id == id)
|
||||
let index = self
|
||||
.messages
|
||||
.binary_search_by(|message| message.id.cmp(&id))
|
||||
.ok()?;
|
||||
|
||||
self.messages.get(index)
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> impl ExactSizeIterator<Item = &Message> {
|
||||
@@ -656,21 +656,43 @@ impl Thread {
|
||||
return;
|
||||
};
|
||||
for deleted_message in self.messages.drain(message_ix..) {
|
||||
self.context_by_message.remove(&deleted_message.id);
|
||||
self.checkpoints_by_message.remove(&deleted_message.id);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
|
||||
self.context_by_message
|
||||
.get(&id)
|
||||
pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AgentContext> {
|
||||
self.messages
|
||||
.iter()
|
||||
.find(|message| message.id == id)
|
||||
.into_iter()
|
||||
.flat_map(|context| {
|
||||
context
|
||||
.iter()
|
||||
.filter_map(|context_id| self.context.get(&context_id))
|
||||
.flat_map(|message| message.loaded_context.contexts.iter())
|
||||
}
|
||||
|
||||
pub fn is_turn_end(&self, ix: usize) -> bool {
|
||||
if self.messages.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if !self.is_generating() && ix == self.messages.len() - 1 {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(message) = self.messages.get(ix) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
if message.role != Role::Assistant {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.messages
|
||||
.get(ix + 1)
|
||||
.and_then(|message| {
|
||||
self.message(message.id)
|
||||
.map(|next_message| next_message.role == Role::User)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Returns whether all of the tool uses have finished running.
|
||||
@@ -687,8 +709,11 @@ impl Thread {
|
||||
self.tool_use.tool_uses_for_message(id, cx)
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||
self.tool_use.tool_results_for_message(id)
|
||||
pub fn tool_results_for_message(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Vec<&LanguageModelToolResult> {
|
||||
self.tool_use.tool_results_for_message(assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
|
||||
@@ -703,95 +728,27 @@ impl Thread {
|
||||
self.tool_use.tool_result_card(id).cloned()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_use.message_has_tool_results(message_id)
|
||||
}
|
||||
|
||||
/// Filter out contexts that have already been included in previous messages
|
||||
pub fn filter_new_context<'a>(
|
||||
&self,
|
||||
context: impl Iterator<Item = &'a AssistantContext>,
|
||||
) -> impl Iterator<Item = &'a AssistantContext> {
|
||||
context.filter(|ctx| self.is_context_new(ctx))
|
||||
}
|
||||
|
||||
fn is_context_new(&self, context: &AssistantContext) -> bool {
|
||||
!self.context.contains_key(&context.id())
|
||||
}
|
||||
|
||||
pub fn insert_user_message(
|
||||
&mut self,
|
||||
text: impl Into<String>,
|
||||
context: Vec<AssistantContext>,
|
||||
loaded_context: ContextLoadResult,
|
||||
git_checkpoint: Option<GitStoreCheckpoint>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let text = text.into();
|
||||
|
||||
let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
|
||||
|
||||
let new_context: Vec<_> = context
|
||||
.into_iter()
|
||||
.filter(|ctx| self.is_context_new(ctx))
|
||||
.collect();
|
||||
|
||||
if !new_context.is_empty() {
|
||||
if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
|
||||
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
|
||||
message.context = context_string;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
|
||||
message.images = new_context
|
||||
.iter()
|
||||
.filter_map(|context| {
|
||||
if let AssistantContext::Image(image_context) = context {
|
||||
image_context.image_task.clone().now_or_never().flatten()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
}
|
||||
|
||||
if !loaded_context.referenced_buffers.is_empty() {
|
||||
self.action_log.update(cx, |log, cx| {
|
||||
// Track all buffers added as context
|
||||
for ctx in &new_context {
|
||||
match ctx {
|
||||
AssistantContext::File(file_ctx) => {
|
||||
log.track_buffer(file_ctx.context_buffer.buffer.clone(), cx);
|
||||
}
|
||||
AssistantContext::Directory(dir_ctx) => {
|
||||
for context_buffer in &dir_ctx.context_buffers {
|
||||
log.track_buffer(context_buffer.buffer.clone(), cx);
|
||||
}
|
||||
}
|
||||
AssistantContext::Symbol(symbol_ctx) => {
|
||||
log.track_buffer(symbol_ctx.context_symbol.buffer.clone(), cx);
|
||||
}
|
||||
AssistantContext::Selection(selection_context) => {
|
||||
log.track_buffer(selection_context.context_buffer.buffer.clone(), cx);
|
||||
}
|
||||
AssistantContext::FetchedUrl(_)
|
||||
| AssistantContext::Thread(_)
|
||||
| AssistantContext::Rules(_)
|
||||
| AssistantContext::Image(_) => {}
|
||||
}
|
||||
for buffer in loaded_context.referenced_buffers {
|
||||
log.track_buffer(buffer, cx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let context_ids = new_context
|
||||
.iter()
|
||||
.map(|context| context.id())
|
||||
.collect::<Vec<_>>();
|
||||
self.context.extend(
|
||||
new_context
|
||||
.into_iter()
|
||||
.map(|context| (context.id(), context)),
|
||||
let message_id = self.insert_message(
|
||||
Role::User,
|
||||
vec![MessageSegment::Text(text.into())],
|
||||
loaded_context.loaded_context,
|
||||
cx,
|
||||
);
|
||||
self.context_by_message.insert(message_id, context_ids);
|
||||
|
||||
if let Some(git_checkpoint) = git_checkpoint {
|
||||
self.pending_checkpoint = Some(ThreadCheckpoint {
|
||||
@@ -805,10 +762,19 @@ impl Thread {
|
||||
message_id
|
||||
}
|
||||
|
||||
pub fn insert_assistant_message(
|
||||
&mut self,
|
||||
segments: Vec<MessageSegment>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
self.insert_message(Role::Assistant, segments, LoadedContext::default(), cx)
|
||||
}
|
||||
|
||||
pub fn insert_message(
|
||||
&mut self,
|
||||
role: Role,
|
||||
segments: Vec<MessageSegment>,
|
||||
loaded_context: LoadedContext,
|
||||
cx: &mut Context<Self>,
|
||||
) -> MessageId {
|
||||
let id = self.next_message_id.post_inc();
|
||||
@@ -816,8 +782,7 @@ impl Thread {
|
||||
id,
|
||||
role,
|
||||
segments,
|
||||
context: String::new(),
|
||||
images: Vec::new(),
|
||||
loaded_context,
|
||||
});
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageAdded(id));
|
||||
@@ -846,7 +811,6 @@ impl Thread {
|
||||
return false;
|
||||
};
|
||||
self.messages.remove(index);
|
||||
self.context_by_message.remove(&id);
|
||||
self.touch_updated_at();
|
||||
cx.emit(ThreadEvent::MessageDeleted(id));
|
||||
true
|
||||
@@ -933,7 +897,7 @@ impl Thread {
|
||||
content: tool_result.content.clone(),
|
||||
})
|
||||
.collect(),
|
||||
context: message.context.clone(),
|
||||
context: message.loaded_context.text.clone(),
|
||||
})
|
||||
.collect(),
|
||||
initial_project_snapshot,
|
||||
@@ -967,26 +931,21 @@ impl Thread {
|
||||
|
||||
let mut request = self.to_completion_request(cx);
|
||||
if model.supports_tools() {
|
||||
request.tools = {
|
||||
let mut tools = Vec::new();
|
||||
tools.extend(
|
||||
self.tools()
|
||||
.read(cx)
|
||||
.enabled_tools(cx)
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
// Skip tools that cannot be supported
|
||||
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema,
|
||||
})
|
||||
}),
|
||||
);
|
||||
|
||||
tools
|
||||
};
|
||||
request.tools = self
|
||||
.tools()
|
||||
.read(cx)
|
||||
.enabled_tools(cx)
|
||||
.into_iter()
|
||||
.filter_map(|tool| {
|
||||
// Skip tools that cannot be supported
|
||||
let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
|
||||
Some(LanguageModelRequestTool {
|
||||
name: tool.name(),
|
||||
description: tool.description(),
|
||||
input_schema,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
self.stream_completion(request, model, window, cx);
|
||||
@@ -1051,29 +1010,9 @@ impl Thread {
|
||||
cache: false,
|
||||
};
|
||||
|
||||
self.tool_use
|
||||
.attach_tool_results(message.id, &mut request_message);
|
||||
|
||||
if !message.context.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text(message.context.to_string()));
|
||||
}
|
||||
|
||||
if !message.images.is_empty() {
|
||||
// Some providers only support image parts after an initial text part
|
||||
if request_message.content.is_empty() {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Text("Images attached by user:".to_string()));
|
||||
}
|
||||
|
||||
for image in &message.images {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::Image(image.clone()))
|
||||
}
|
||||
}
|
||||
message
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
@@ -1103,10 +1042,11 @@ impl Thread {
|
||||
self.tool_use
|
||||
.attach_tool_uses(message.id, &mut request_message);
|
||||
|
||||
// Skip empty messages to avoid sending them to the LLM model
|
||||
// if !request_message.content.is_empty() {
|
||||
request.messages.push(request_message);
|
||||
// }
|
||||
|
||||
if let Some(tool_results_message) = self.tool_use.tool_results_message(message.id) {
|
||||
request.messages.push(tool_results_message);
|
||||
}
|
||||
}
|
||||
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
||||
@@ -1136,11 +1076,6 @@ impl Thread {
|
||||
cache: false,
|
||||
};
|
||||
|
||||
// Skip tool results during summarization.
|
||||
if self.tool_use.message_has_tool_results(message.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for segment in &message.segments {
|
||||
match segment {
|
||||
MessageSegment::Text(text) => request_message
|
||||
@@ -1245,22 +1180,45 @@ impl Thread {
|
||||
.ok();
|
||||
}
|
||||
|
||||
let mut request_assistant_message_id = None;
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
if let Some((_, response_events)) = request_callback_parameters.as_mut() {
|
||||
response_events
|
||||
.push(event.as_ref().map_err(|error| error.to_string()).cloned());
|
||||
}
|
||||
|
||||
let event = event?;
|
||||
|
||||
thread.update(cx, |thread, cx| {
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text(String::new())],
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id,
|
||||
tool_name,
|
||||
raw_input: invalid_input_json,
|
||||
json_parse_error,
|
||||
}) => {
|
||||
thread.receive_invalid_tool_json(
|
||||
id,
|
||||
tool_name,
|
||||
invalid_input_json,
|
||||
json_parse_error,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
Err(LanguageModelCompletionError::Other(error)) => {
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
|
||||
match event {
|
||||
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||
request_assistant_message_id =
|
||||
Some(thread.insert_assistant_message(
|
||||
vec![MessageSegment::Text(String::new())],
|
||||
cx,
|
||||
));
|
||||
}
|
||||
LanguageModelCompletionEvent::Stop(reason) => {
|
||||
stop_reason = reason;
|
||||
@@ -1275,7 +1233,9 @@ impl Thread {
|
||||
LanguageModelCompletionEvent::Text(chunk) => {
|
||||
cx.emit(ThreadEvent::ReceivedTextChunk);
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
if last_message.role == Role::Assistant
|
||||
&& !thread.tool_use.has_tool_results(last_message.id)
|
||||
{
|
||||
last_message.push_text(&chunk);
|
||||
cx.emit(ThreadEvent::StreamedAssistantText(
|
||||
last_message.id,
|
||||
@@ -1287,11 +1247,11 @@ impl Thread {
|
||||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Text(chunk.to_string())],
|
||||
cx,
|
||||
);
|
||||
request_assistant_message_id =
|
||||
Some(thread.insert_assistant_message(
|
||||
vec![MessageSegment::Text(chunk.to_string())],
|
||||
cx,
|
||||
));
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1300,7 +1260,9 @@ impl Thread {
|
||||
signature,
|
||||
} => {
|
||||
if let Some(last_message) = thread.messages.last_mut() {
|
||||
if last_message.role == Role::Assistant {
|
||||
if last_message.role == Role::Assistant
|
||||
&& !thread.tool_use.has_tool_results(last_message.id)
|
||||
{
|
||||
last_message.push_thinking(&chunk, signature);
|
||||
cx.emit(ThreadEvent::StreamedAssistantThinking(
|
||||
last_message.id,
|
||||
@@ -1312,25 +1274,25 @@ impl Thread {
|
||||
//
|
||||
// Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
|
||||
// will result in duplicating the text of the chunk in the rendered Markdown.
|
||||
thread.insert_message(
|
||||
Role::Assistant,
|
||||
vec![MessageSegment::Thinking {
|
||||
text: chunk.to_string(),
|
||||
signature,
|
||||
}],
|
||||
cx,
|
||||
);
|
||||
request_assistant_message_id =
|
||||
Some(thread.insert_assistant_message(
|
||||
vec![MessageSegment::Thinking {
|
||||
text: chunk.to_string(),
|
||||
signature,
|
||||
}],
|
||||
cx,
|
||||
));
|
||||
};
|
||||
}
|
||||
}
|
||||
LanguageModelCompletionEvent::ToolUse(tool_use) => {
|
||||
let last_assistant_message_id = thread
|
||||
.messages
|
||||
.iter_mut()
|
||||
.rfind(|message| message.role == Role::Assistant)
|
||||
.map(|message| message.id)
|
||||
let last_assistant_message_id = request_assistant_message_id
|
||||
.unwrap_or_else(|| {
|
||||
thread.insert_message(Role::Assistant, vec![], cx)
|
||||
let new_assistant_message_id =
|
||||
thread.insert_assistant_message(vec![], cx);
|
||||
request_assistant_message_id =
|
||||
Some(new_assistant_message_id);
|
||||
new_assistant_message_id
|
||||
});
|
||||
|
||||
let tool_use_id = tool_use.id.clone();
|
||||
@@ -1362,7 +1324,8 @@ impl Thread {
|
||||
cx.notify();
|
||||
|
||||
thread.auto_capture_telemetry(cx);
|
||||
})?;
|
||||
Ok(())
|
||||
})??;
|
||||
|
||||
smol::future::yield_now().await;
|
||||
}
|
||||
@@ -1653,6 +1616,41 @@ impl Thread {
|
||||
pending_tool_uses
|
||||
}
|
||||
|
||||
pub fn receive_invalid_tool_json(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
invalid_json: Arc<str>,
|
||||
error: String,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Thread>,
|
||||
) {
|
||||
log::error!("The model returned invalid input JSON: {invalid_json}");
|
||||
|
||||
let pending_tool_use = self.tool_use.insert_tool_output(
|
||||
tool_use_id.clone(),
|
||||
tool_name,
|
||||
Err(anyhow!("Error parsing input JSON: {error}")),
|
||||
cx,
|
||||
);
|
||||
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
|
||||
pending_tool_use.ui_text.clone()
|
||||
} else {
|
||||
log::error!(
|
||||
"There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)."
|
||||
);
|
||||
format!("Unknown tool {}", tool_use_id).into()
|
||||
};
|
||||
|
||||
cx.emit(ThreadEvent::InvalidToolInput {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
ui_text,
|
||||
invalid_input_json: invalid_json,
|
||||
});
|
||||
|
||||
self.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
|
||||
}
|
||||
|
||||
pub fn run_tool(
|
||||
&mut self,
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
@@ -1726,32 +1724,21 @@ impl Thread {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.all_tools_finished() {
|
||||
dbg!("All tools finished");
|
||||
let model_registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
|
||||
dbg!("Attach tool results");
|
||||
self.attach_tool_results(cx);
|
||||
if !canceled {
|
||||
self.send_to_model(model, window, cx);
|
||||
}
|
||||
self.auto_capture_telemetry(cx);
|
||||
}
|
||||
}
|
||||
|
||||
println!("ToolFinished {}", tool_use_id);
|
||||
cx.emit(ThreadEvent::ToolFinished {
|
||||
tool_use_id,
|
||||
pending_tool_use,
|
||||
});
|
||||
}
|
||||
|
||||
/// Insert an empty message to be populated with tool results upon send.
|
||||
pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
|
||||
// Tool results are assumed to be waiting on the next message id, so they will populate
|
||||
// this empty message before sending to model. Would prefer this to be more straightforward.
|
||||
self.insert_message(Role::User, vec![], cx);
|
||||
self.auto_capture_telemetry(cx);
|
||||
}
|
||||
|
||||
/// Cancels the last pending completion, if there are any pending.
|
||||
///
|
||||
/// Returns whether a completion was canceled.
|
||||
@@ -1760,22 +1747,19 @@ impl Thread {
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> bool {
|
||||
let canceled = if self.pending_completions.pop().is_some() {
|
||||
true
|
||||
} else {
|
||||
let mut canceled = false;
|
||||
for pending_tool_use in self.tool_use.cancel_pending() {
|
||||
canceled = true;
|
||||
self.tool_finished(
|
||||
pending_tool_use.id.clone(),
|
||||
Some(pending_tool_use),
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
canceled
|
||||
};
|
||||
let mut canceled = self.pending_completions.pop().is_some();
|
||||
|
||||
for pending_tool_use in self.tool_use.cancel_pending() {
|
||||
canceled = true;
|
||||
self.tool_finished(
|
||||
pending_tool_use.id.clone(),
|
||||
Some(pending_tool_use),
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
self.finalize_pending_checkpoint(cx);
|
||||
canceled
|
||||
}
|
||||
@@ -2026,8 +2010,16 @@ impl Thread {
|
||||
}
|
||||
)?;
|
||||
|
||||
if !message.context.is_empty() {
|
||||
writeln!(markdown, "{}", message.context)?;
|
||||
if !message.loaded_context.text.is_empty() {
|
||||
writeln!(markdown, "{}", message.loaded_context.text)?;
|
||||
}
|
||||
|
||||
if !message.loaded_context.images.is_empty() {
|
||||
writeln!(
|
||||
markdown,
|
||||
"\n{} images attached as context.\n",
|
||||
message.loaded_context.images.len()
|
||||
)?;
|
||||
}
|
||||
|
||||
for segment in &message.segments {
|
||||
@@ -2056,7 +2048,7 @@ impl Thread {
|
||||
}
|
||||
|
||||
for tool_result in self.tool_results_for_message(message.id) {
|
||||
write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
|
||||
write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?;
|
||||
if tool_result.is_error {
|
||||
write!(markdown, " (Error)")?;
|
||||
}
|
||||
@@ -2105,7 +2097,7 @@ impl Thread {
|
||||
}
|
||||
|
||||
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
|
||||
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
|
||||
if !cx.has_flag::<feature_flags::ThreadAutoCaptureFeatureFlag>() {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2268,6 +2260,11 @@ pub enum ThreadEvent {
|
||||
ui_text: Arc<str>,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
InvalidToolInput {
|
||||
tool_use_id: LanguageModelToolUseId,
|
||||
ui_text: Arc<str>,
|
||||
invalid_input_json: Arc<str>,
|
||||
},
|
||||
Stopped(Result<StopReason, Arc<anyhow::Error>>),
|
||||
MessageAdded(MessageId),
|
||||
MessageEdited(MessageId),
|
||||
@@ -2297,7 +2294,7 @@ struct PendingCompletion {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ThreadStore, context_store::ContextStore, thread_store};
|
||||
use crate::{ThreadStore, context::load_context, context_store::ContextStore, thread_store};
|
||||
use assistant_settings::AssistantSettings;
|
||||
use context_server::ContextServerSettings;
|
||||
use editor::EditorSettings;
|
||||
@@ -2328,12 +2325,14 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context =
|
||||
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
|
||||
let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(vec![context], &project, &None, cx))
|
||||
.await;
|
||||
|
||||
// Insert user message with context
|
||||
let message_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Please explain this code", vec![context], None, cx)
|
||||
thread.insert_user_message("Please explain this code", loaded_context, None, cx)
|
||||
});
|
||||
|
||||
// Check content and context in message object
|
||||
@@ -2367,7 +2366,7 @@ fn main() {{
|
||||
message.segments[0],
|
||||
MessageSegment::Text("Please explain this code".to_string())
|
||||
);
|
||||
assert_eq!(message.context, expected_context);
|
||||
assert_eq!(message.loaded_context.text, expected_context);
|
||||
|
||||
// Check message in request
|
||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
||||
@@ -2394,48 +2393,50 @@ fn main() {{
|
||||
let (_, _thread_store, thread, context_store) =
|
||||
setup_test_environment(cx, project.clone()).await;
|
||||
|
||||
// Open files individually
|
||||
// First message with context 1
|
||||
add_file_to_context(&project, &context_store, "test/file1.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Get the context objects
|
||||
let contexts = context_store.update(cx, |store, _| store.context().clone());
|
||||
assert_eq!(contexts.len(), 3);
|
||||
|
||||
// First message with context 1
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
store.new_context_for_thread(thread.read(cx))
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 1);
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(new_contexts, &project, &None, cx))
|
||||
.await;
|
||||
let message1_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
|
||||
thread.insert_user_message("Message 1", loaded_context, None, cx)
|
||||
});
|
||||
|
||||
// Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
|
||||
add_file_to_context(&project, &context_store, "test/file2.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
store.new_context_for_thread(thread.read(cx))
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 1);
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(new_contexts, &project, &None, cx))
|
||||
.await;
|
||||
let message2_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
"Message 2",
|
||||
vec![contexts[0].clone(), contexts[1].clone()],
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
thread.insert_user_message("Message 2", loaded_context, None, cx)
|
||||
});
|
||||
|
||||
// Third message with all three contexts (contexts 1 and 2 should be skipped)
|
||||
//
|
||||
add_file_to_context(&project, &context_store, "test/file3.rs", cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let new_contexts = context_store.update(cx, |store, cx| {
|
||||
store.new_context_for_thread(thread.read(cx))
|
||||
});
|
||||
assert_eq!(new_contexts.len(), 1);
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(new_contexts, &project, &None, cx))
|
||||
.await;
|
||||
let message3_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(
|
||||
"Message 3",
|
||||
vec![
|
||||
contexts[0].clone(),
|
||||
contexts[1].clone(),
|
||||
contexts[2].clone(),
|
||||
],
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
thread.insert_user_message("Message 3", loaded_context, None, cx)
|
||||
});
|
||||
|
||||
// Check what contexts are included in each message
|
||||
@@ -2448,16 +2449,16 @@ fn main() {{
|
||||
});
|
||||
|
||||
// First message should include context 1
|
||||
assert!(message1.context.contains("file1.rs"));
|
||||
assert!(message1.loaded_context.text.contains("file1.rs"));
|
||||
|
||||
// Second message should include only context 2 (not 1)
|
||||
assert!(!message2.context.contains("file1.rs"));
|
||||
assert!(message2.context.contains("file2.rs"));
|
||||
assert!(!message2.loaded_context.text.contains("file1.rs"));
|
||||
assert!(message2.loaded_context.text.contains("file2.rs"));
|
||||
|
||||
// Third message should include only context 3 (not 1 or 2)
|
||||
assert!(!message3.context.contains("file1.rs"));
|
||||
assert!(!message3.context.contains("file2.rs"));
|
||||
assert!(message3.context.contains("file3.rs"));
|
||||
assert!(!message3.loaded_context.text.contains("file1.rs"));
|
||||
assert!(!message3.loaded_context.text.contains("file2.rs"));
|
||||
assert!(message3.loaded_context.text.contains("file3.rs"));
|
||||
|
||||
// Check entire request to make sure all contexts are properly included
|
||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
||||
@@ -2494,7 +2495,12 @@ fn main() {{
|
||||
|
||||
// Insert user message without any context (empty context vector)
|
||||
let message_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
|
||||
thread.insert_user_message(
|
||||
"What is the best way to learn Rust?",
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Check content and context in message object
|
||||
@@ -2507,7 +2513,7 @@ fn main() {{
|
||||
message.segments[0],
|
||||
MessageSegment::Text("What is the best way to learn Rust?".to_string())
|
||||
);
|
||||
assert_eq!(message.context, "");
|
||||
assert_eq!(message.loaded_context.text, "");
|
||||
|
||||
// Check message in request
|
||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
||||
@@ -2520,12 +2526,17 @@ fn main() {{
|
||||
|
||||
// Add second message, also without context
|
||||
let message2_id = thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Are there any good books?", vec![], None, cx)
|
||||
thread.insert_user_message(
|
||||
"Are there any good books?",
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let message2 =
|
||||
thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
|
||||
assert_eq!(message2.context, "");
|
||||
assert_eq!(message2.loaded_context.text, "");
|
||||
|
||||
// Check that both messages appear in the request
|
||||
let request = thread.update(cx, |thread, cx| thread.to_completion_request(cx));
|
||||
@@ -2559,12 +2570,14 @@ fn main() {{
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let context =
|
||||
context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
|
||||
let context = context_store.update(cx, |store, _| store.context().next().cloned().unwrap());
|
||||
let loaded_context = cx
|
||||
.update(|cx| load_context(vec![context], &project, &None, cx))
|
||||
.await;
|
||||
|
||||
// Insert user message with the buffer as context
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("Explain this code", vec![context], None, cx)
|
||||
thread.insert_user_message("Explain this code", loaded_context, None, cx)
|
||||
});
|
||||
|
||||
// Create a request and check that it doesn't have a stale buffer warning yet
|
||||
@@ -2592,7 +2605,12 @@ fn main() {{
|
||||
|
||||
// Insert another user message without context
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.insert_user_message("What does the code do now?", vec![], None, cx)
|
||||
thread.insert_user_message(
|
||||
"What does the code do now?",
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Create a new request and check for the stale buffer warning
|
||||
@@ -2659,6 +2677,7 @@ fn main() {{
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
None,
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
@@ -2683,15 +2702,15 @@ fn main() {{
|
||||
.unwrap();
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(buffer_path, cx))
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(buffer_path.clone(), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
context_store
|
||||
.update(cx, |store, cx| {
|
||||
store.add_file_from_buffer(buffer.clone(), cx)
|
||||
})
|
||||
.await?;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx);
|
||||
});
|
||||
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
@@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
|
||||
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
|
||||
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
|
||||
UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
@@ -82,12 +82,11 @@ impl ThreadStore {
|
||||
pub fn load(
|
||||
project: Entity<Project>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
let (thread_store, ready_rx) = cx.update(|cx| {
|
||||
let mut option_ready_rx = None;
|
||||
let thread_store = cx.new(|cx| {
|
||||
@@ -349,25 +348,8 @@ impl ThreadStore {
|
||||
self.context_server_manager.clone()
|
||||
}
|
||||
|
||||
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
|
||||
self.prompt_store.clone()
|
||||
}
|
||||
|
||||
pub fn load_rules(
|
||||
&self,
|
||||
prompt_id: UserPromptId,
|
||||
cx: &App,
|
||||
) -> Task<Result<(PromptMetadata, String)>> {
|
||||
let prompt_id = PromptId::User { uuid: prompt_id };
|
||||
let Some(prompt_store) = self.prompt_store.as_ref() else {
|
||||
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
|
||||
};
|
||||
let prompt_store = prompt_store.read(cx);
|
||||
let Some(metadata) = prompt_store.metadata(prompt_id) else {
|
||||
return Task::ready(Err(anyhow!("User rules not found in library.")));
|
||||
};
|
||||
let text_task = prompt_store.load(prompt_id, cx);
|
||||
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
|
||||
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> Entity<ToolWorkingSet> {
|
||||
@@ -379,16 +361,12 @@ impl ThreadStore {
|
||||
self.threads.len()
|
||||
}
|
||||
|
||||
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
|
||||
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
|
||||
threads
|
||||
}
|
||||
|
||||
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
|
||||
self.threads().into_iter().take(limit).collect()
|
||||
}
|
||||
|
||||
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
|
||||
cx.new(|cx| {
|
||||
Thread::new(
|
||||
@@ -639,12 +617,17 @@ pub struct SerializedThread {
|
||||
}
|
||||
|
||||
impl SerializedThread {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
pub const VERSION: &'static str = "0.2.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
SerializedThreadV0_1_0::VERSION => {
|
||||
let saved_thread =
|
||||
serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
|
||||
Ok(saved_thread.upgrade())
|
||||
}
|
||||
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
|
||||
saved_thread_json,
|
||||
)?),
|
||||
@@ -666,6 +649,38 @@ impl SerializedThread {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SerializedThreadV0_1_0(
|
||||
// The structure did not change, so we are reusing the latest SerializedThread.
|
||||
// When making the next version, make sure this points to SerializedThreadV0_2_0
|
||||
SerializedThread,
|
||||
);
|
||||
|
||||
impl SerializedThreadV0_1_0 {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
|
||||
pub fn upgrade(self) -> SerializedThread {
|
||||
debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
|
||||
|
||||
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
|
||||
|
||||
for message in self.0.messages {
|
||||
if message.role == Role::User && !message.tool_results.is_empty() {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
|
||||
last_message.tool_results = message.tool_results;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
SerializedThread { messages, ..self.0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SerializedMessage {
|
||||
pub id: MessageId,
|
||||
|
||||
@@ -30,7 +30,6 @@ pub struct ToolUse {
|
||||
pub struct ToolUseState {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||
@@ -42,7 +41,6 @@ impl ToolUseState {
|
||||
Self {
|
||||
tools,
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
tool_result_cards: HashMap::default(),
|
||||
@@ -56,7 +54,6 @@ impl ToolUseState {
|
||||
pub fn from_serialized_messages(
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
messages: &[SerializedMessage],
|
||||
mut filter_by_tool_name: impl FnMut(&str) -> bool,
|
||||
) -> Self {
|
||||
let mut this = Self::new(tools);
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
@@ -68,7 +65,6 @@ impl ToolUseState {
|
||||
let tool_uses = message
|
||||
.tool_uses
|
||||
.iter()
|
||||
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
|
||||
.map(|tool_use| LanguageModelToolUse {
|
||||
id: tool_use.id.clone(),
|
||||
name: tool_use.name.clone().into(),
|
||||
@@ -86,14 +82,6 @@ impl ToolUseState {
|
||||
|
||||
this.tool_uses_by_assistant_message
|
||||
.insert(message.id, tool_uses);
|
||||
}
|
||||
}
|
||||
Role::User => {
|
||||
if !message.tool_results.is_empty() {
|
||||
let tool_uses_by_user_message = this
|
||||
.tool_uses_by_user_message
|
||||
.entry(message.id)
|
||||
.or_default();
|
||||
|
||||
for tool_result in &message.tool_results {
|
||||
let tool_use_id = tool_result.tool_use_id.clone();
|
||||
@@ -102,11 +90,6 @@ impl ToolUseState {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !(filter_by_tool_name)(tool_use.as_ref()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tool_uses_by_user_message.push(tool_use_id.clone());
|
||||
this.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
@@ -119,7 +102,7 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::System => {}
|
||||
Role::System | Role::User => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,20 +212,26 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||
let empty = Vec::new();
|
||||
pub fn tool_results_for_message(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Vec<&LanguageModelToolResult> {
|
||||
let Some(tool_uses) = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
.unwrap_or(&empty)
|
||||
tool_uses
|
||||
.iter()
|
||||
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
|
||||
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
.map_or(false, |results| !results.is_empty())
|
||||
}
|
||||
|
||||
@@ -294,14 +283,6 @@ impl ToolUseState {
|
||||
self.tool_use_metadata_by_id
|
||||
.insert(tool_use.id.clone(), metadata);
|
||||
|
||||
// The tool use is being requested by the Assistant, so we want to
|
||||
// attach the tool results to the next user message.
|
||||
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
|
||||
self.tool_uses_by_user_message
|
||||
.entry(next_user_message_id)
|
||||
.or_default()
|
||||
.push(tool_use.id.clone());
|
||||
|
||||
PendingToolUseStatus::Idle
|
||||
} else {
|
||||
PendingToolUseStatus::InputStillStreaming
|
||||
@@ -450,7 +431,6 @@ impl ToolUseState {
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
dbg!(&self.tool_uses_by_assistant_message, &message_id);
|
||||
if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
|
||||
for tool_use in tool_uses {
|
||||
if self.tool_results.contains_key(&tool_use.id) {
|
||||
@@ -468,32 +448,49 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attach_tool_results(
|
||||
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.contains_key(&assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_results_message(
|
||||
&self,
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
dbg!(&self.tool_uses_by_user_message, &message_id);
|
||||
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
|
||||
for tool_use_id in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
||||
request_message.content.push(MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
assistant_message_id: MessageId,
|
||||
) -> Option<LanguageModelRequestMessage> {
|
||||
let tool_uses = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)?;
|
||||
|
||||
if tool_uses.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
for tool_use in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
));
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Some(request_message)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
use std::sync::Arc;
|
||||
use std::{rc::Rc, time::Duration};
|
||||
|
||||
use file_icons::FileIcons;
|
||||
use futures::FutureExt;
|
||||
use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
|
||||
use gpui::{ClickEvent, Task};
|
||||
use language_model::LanguageModelImage;
|
||||
use gpui::{
|
||||
Animation, AnimationExt as _, AnyView, ClickEvent, Entity, MouseButton, pulsating_between,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
|
||||
|
||||
use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
|
||||
use crate::context::{AgentContext, ContextKind, ImageStatus};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub enum ContextPill {
|
||||
@@ -73,9 +74,7 @@ impl ContextPill {
|
||||
|
||||
pub fn id(&self) -> ElementId {
|
||||
match self {
|
||||
Self::Added { context, .. } => {
|
||||
ElementId::NamedInteger("context-pill".into(), context.id.0)
|
||||
}
|
||||
Self::Added { context, .. } => context.context.element_id("context-pill".into()),
|
||||
Self::Suggested { .. } => "suggested-context-pill".into(),
|
||||
}
|
||||
}
|
||||
@@ -170,14 +169,9 @@ impl RenderOnce for ContextPill {
|
||||
.when_some(
|
||||
context.render_preview.as_ref(),
|
||||
|element, render_preview| {
|
||||
element.hoverable_tooltip({
|
||||
let render_preview = render_preview.clone();
|
||||
move |_, cx| {
|
||||
cx.new(|_| ContextPillPreview {
|
||||
render_preview: render_preview.clone(),
|
||||
})
|
||||
.into()
|
||||
}
|
||||
let render_preview = render_preview.clone();
|
||||
element.hoverable_tooltip(move |window, cx| {
|
||||
render_preview(window, cx)
|
||||
})
|
||||
},
|
||||
)
|
||||
@@ -199,14 +193,17 @@ impl RenderOnce for ContextPill {
|
||||
)
|
||||
.when_some(on_remove.as_ref(), |element, on_remove| {
|
||||
element.child(
|
||||
IconButton::new(("remove", context.id.0), IconName::Close)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.tooltip(Tooltip::text("Remove Context"))
|
||||
.on_click({
|
||||
let on_remove = on_remove.clone();
|
||||
move |event, window, cx| on_remove(event, window, cx)
|
||||
}),
|
||||
IconButton::new(
|
||||
context.context.element_id("remove".into()),
|
||||
IconName::Close,
|
||||
)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.tooltip(Tooltip::text("Remove Context"))
|
||||
.on_click({
|
||||
let on_remove = on_remove.clone();
|
||||
move |event, window, cx| on_remove(event, window, cx)
|
||||
}),
|
||||
)
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
@@ -262,23 +259,34 @@ pub enum ContextStatus {
|
||||
Error { message: SharedString },
|
||||
}
|
||||
|
||||
#[derive(RegisterComponent)]
|
||||
// TODO: Component commented out due to new dependency on `Project`.
|
||||
//
|
||||
// #[derive(RegisterComponent)]
|
||||
pub struct AddedContext {
|
||||
pub id: ContextId,
|
||||
pub context: AgentContext,
|
||||
pub kind: ContextKind,
|
||||
pub name: SharedString,
|
||||
pub parent: Option<SharedString>,
|
||||
pub tooltip: Option<SharedString>,
|
||||
pub icon_path: Option<SharedString>,
|
||||
pub status: ContextStatus,
|
||||
pub render_preview: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement + 'static>>,
|
||||
pub render_preview: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyView + 'static>>,
|
||||
}
|
||||
|
||||
impl AddedContext {
|
||||
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
|
||||
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
|
||||
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
|
||||
///
|
||||
/// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
|
||||
pub fn new(
|
||||
context: AgentContext,
|
||||
prompt_store: Option<&Entity<PromptStore>>,
|
||||
project: &Project,
|
||||
cx: &App,
|
||||
) -> Option<AddedContext> {
|
||||
match context {
|
||||
AssistantContext::File(file_context) => {
|
||||
let full_path = file_context.context_buffer.full_path(cx);
|
||||
AgentContext::File(ref file_context) => {
|
||||
let full_path = file_context.buffer.read(cx).file()?.full_path(cx);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
@@ -289,8 +297,7 @@ impl AddedContext {
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
id: file_context.id,
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::File,
|
||||
name,
|
||||
parent,
|
||||
@@ -298,18 +305,16 @@ impl AddedContext {
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
}
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::Directory(directory_context) => {
|
||||
let worktree = directory_context.worktree.read(cx);
|
||||
// If the directory no longer exists, use its last known path.
|
||||
let full_path = worktree
|
||||
.entry_for_id(directory_context.entry_id)
|
||||
.map_or_else(
|
||||
|| directory_context.last_path.clone(),
|
||||
|entry| worktree.full_path(&entry.path).into(),
|
||||
);
|
||||
AgentContext::Directory(ref directory_context) => {
|
||||
let worktree = project
|
||||
.worktree_for_entry(directory_context.entry_id, cx)?
|
||||
.read(cx);
|
||||
let entry = worktree.entry_for_id(directory_context.entry_id)?;
|
||||
let full_path = worktree.full_path(&entry.path);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
@@ -320,8 +325,7 @@ impl AddedContext {
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
AddedContext {
|
||||
id: directory_context.id,
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Directory,
|
||||
name,
|
||||
parent,
|
||||
@@ -329,33 +333,34 @@ impl AddedContext {
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
}
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::Symbol(symbol_context) => AddedContext {
|
||||
id: symbol_context.id,
|
||||
AgentContext::Symbol(ref symbol_context) => Some(AddedContext {
|
||||
kind: ContextKind::Symbol,
|
||||
name: symbol_context.context_symbol.id.name.clone(),
|
||||
name: symbol_context.symbol.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
},
|
||||
context,
|
||||
}),
|
||||
|
||||
AssistantContext::Selection(selection_context) => {
|
||||
let full_path = selection_context.context_buffer.full_path(cx);
|
||||
AgentContext::Selection(ref selection_context) => {
|
||||
let buffer = selection_context.buffer.read(cx);
|
||||
let full_path = buffer.file()?.full_path(cx);
|
||||
let mut full_path_string = full_path.to_string_lossy().into_owned();
|
||||
let mut name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
|
||||
let line_range_text = format!(
|
||||
" ({}-{})",
|
||||
selection_context.line_range.start.row + 1,
|
||||
selection_context.line_range.end.row + 1
|
||||
);
|
||||
let line_range = selection_context.range.to_point(&buffer.snapshot());
|
||||
|
||||
let line_range_text =
|
||||
format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
|
||||
|
||||
full_path_string.push_str(&line_range_text);
|
||||
name.push_str(&line_range_text);
|
||||
@@ -365,8 +370,7 @@ impl AddedContext {
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
|
||||
AddedContext {
|
||||
id: selection_context.id,
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Selection,
|
||||
name: name.into(),
|
||||
parent,
|
||||
@@ -374,22 +378,31 @@ impl AddedContext {
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: Some(Rc::new({
|
||||
let content = selection_context.context_buffer.text.clone();
|
||||
let buffer = selection_context.buffer.clone();
|
||||
let range = selection_context.range.clone();
|
||||
move |_, cx| {
|
||||
div()
|
||||
.id("context-pill-selection-preview")
|
||||
.overflow_scroll()
|
||||
.max_w_128()
|
||||
.max_h_96()
|
||||
.child(Label::new(content.clone()).buffer_font(cx))
|
||||
.into_any_element()
|
||||
let text: SharedString = buffer
|
||||
.read(cx)
|
||||
.text_for_range(range.clone())
|
||||
.collect::<String>()
|
||||
.into();
|
||||
ContextPillPreview::new(cx, move |_, cx| {
|
||||
div()
|
||||
.id("context-pill-selection-preview")
|
||||
.overflow_scroll()
|
||||
.max_w_128()
|
||||
.max_h_96()
|
||||
.child(Label::new(text.clone()).buffer_font(cx))
|
||||
.into_any_element()
|
||||
})
|
||||
.into()
|
||||
}
|
||||
})),
|
||||
}
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
|
||||
id: fetched_url_context.id,
|
||||
AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext {
|
||||
kind: ContextKind::FetchedUrl,
|
||||
name: fetched_url_context.url.clone(),
|
||||
parent: None,
|
||||
@@ -397,12 +410,12 @@ impl AddedContext {
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
},
|
||||
context,
|
||||
}),
|
||||
|
||||
AssistantContext::Thread(thread_context) => AddedContext {
|
||||
id: thread_context.id,
|
||||
AgentContext::Thread(ref thread_context) => Some(AddedContext {
|
||||
kind: ContextKind::Thread,
|
||||
name: thread_context.summary(cx),
|
||||
name: thread_context.name(cx),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
@@ -418,53 +431,74 @@ impl AddedContext {
|
||||
ContextStatus::Ready
|
||||
},
|
||||
render_preview: None,
|
||||
},
|
||||
context,
|
||||
}),
|
||||
|
||||
AssistantContext::Rules(user_rules_context) => AddedContext {
|
||||
id: user_rules_context.id,
|
||||
kind: ContextKind::Rules,
|
||||
name: user_rules_context.title.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
},
|
||||
AgentContext::Rules(ref user_rules_context) => {
|
||||
let name = prompt_store
|
||||
.as_ref()?
|
||||
.read(cx)
|
||||
.metadata(user_rules_context.prompt_id.into())?
|
||||
.title?;
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Rules,
|
||||
name: name.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
})
|
||||
}
|
||||
|
||||
AssistantContext::Image(image_context) => AddedContext {
|
||||
id: image_context.id,
|
||||
AgentContext::Image(ref image_context) => Some(AddedContext {
|
||||
kind: ContextKind::Image,
|
||||
name: "Image".into(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: if image_context.is_loading() {
|
||||
ContextStatus::Loading {
|
||||
status: match image_context.status() {
|
||||
ImageStatus::Loading => ContextStatus::Loading {
|
||||
message: "Loading…".into(),
|
||||
}
|
||||
} else if image_context.is_error() {
|
||||
ContextStatus::Error {
|
||||
},
|
||||
ImageStatus::Error => ContextStatus::Error {
|
||||
message: "Failed to load image".into(),
|
||||
}
|
||||
} else {
|
||||
ContextStatus::Ready
|
||||
},
|
||||
ImageStatus::Ready => ContextStatus::Ready,
|
||||
},
|
||||
render_preview: Some(Rc::new({
|
||||
let image = image_context.original_image.clone();
|
||||
move |_, _| {
|
||||
gpui::img(image.clone())
|
||||
.max_w_96()
|
||||
.max_h_96()
|
||||
.into_any_element()
|
||||
move |_, cx| {
|
||||
let image = image.clone();
|
||||
ContextPillPreview::new(cx, move |_, _| {
|
||||
gpui::img(image.clone())
|
||||
.max_w_96()
|
||||
.max_h_96()
|
||||
.into_any_element()
|
||||
})
|
||||
.into()
|
||||
}
|
||||
})),
|
||||
},
|
||||
context,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextPillPreview {
|
||||
render_preview: Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>,
|
||||
render_preview: Box<dyn Fn(&mut Window, &mut App) -> AnyElement>,
|
||||
}
|
||||
|
||||
impl ContextPillPreview {
|
||||
fn new(
|
||||
cx: &mut App,
|
||||
render_preview: impl Fn(&mut Window, &mut App) -> AnyElement + 'static,
|
||||
) -> Entity<Self> {
|
||||
cx.new(|_| Self {
|
||||
render_preview: Box::new(render_preview),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ContextPillPreview {
|
||||
@@ -478,6 +512,8 @@ impl Render for ContextPillPreview {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Component commented out due to new dependency on `Project`.
|
||||
/*
|
||||
impl Component for AddedContext {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
@@ -487,12 +523,13 @@ impl Component for AddedContext {
|
||||
"AddedContext"
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
let next_context_id = ContextId::zero();
|
||||
let image_ready = (
|
||||
"Ready",
|
||||
AddedContext::new(
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(0),
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
|
||||
}),
|
||||
@@ -503,8 +540,8 @@ impl Component for AddedContext {
|
||||
let image_loading = (
|
||||
"Loading",
|
||||
AddedContext::new(
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(1),
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: cx
|
||||
.background_spawn(async move {
|
||||
@@ -520,8 +557,8 @@ impl Component for AddedContext {
|
||||
let image_error = (
|
||||
"Error",
|
||||
AddedContext::new(
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(2),
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(None).shared(),
|
||||
}),
|
||||
@@ -544,5 +581,8 @@ impl Component for AddedContext {
|
||||
)
|
||||
.into_any(),
|
||||
)
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
mod supported_countries;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -11,8 +9,6 @@ use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
|
||||
@@ -1,225 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Returns whether the given country code is supported by Anthropic.
|
||||
///
|
||||
/// <https://www.anthropic.com/supported-countries>
|
||||
pub fn is_supported_country(country_code: &str) -> bool {
|
||||
SUPPORTED_COUNTRIES.contains(&country_code)
|
||||
}
|
||||
|
||||
/// The list of country codes supported by Anthropic.
|
||||
///
|
||||
/// https://www.anthropic.com/supported-countries
|
||||
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
vec![
|
||||
"AL", // Albania
|
||||
"DZ", // Algeria
|
||||
"AS", // American Samoa (US)
|
||||
"AD", // Andorra
|
||||
"AO", // Angola
|
||||
"AI", // Anguilla (UK)
|
||||
"AG", // Antigua and Barbuda
|
||||
"AR", // Argentina
|
||||
"AM", // Armenia
|
||||
"AU", // Australia
|
||||
"AT", // Austria
|
||||
"AZ", // Azerbaijan
|
||||
"BS", // Bahamas
|
||||
"BH", // Bahrain
|
||||
"BD", // Bangladesh
|
||||
"BB", // Barbados
|
||||
"BE", // Belgium
|
||||
"BZ", // Belize
|
||||
"BJ", // Benin
|
||||
"BM", // Bermuda (UK)
|
||||
"BT", // Bhutan
|
||||
"BO", // Bolivia
|
||||
"BA", // Bosnia and Herzegovina
|
||||
"BW", // Botswana
|
||||
"BR", // Brazil
|
||||
"IO", // British Indian Ocean Territory (UK)
|
||||
"BN", // Brunei
|
||||
"BG", // Bulgaria
|
||||
"BF", // Burkina Faso
|
||||
"BI", // Burundi
|
||||
"CV", // Cabo Verde
|
||||
"KH", // Cambodia
|
||||
"CM", // Cameroon
|
||||
"CA", // Canada
|
||||
"KY", // Cayman Islands (UK)
|
||||
"TD", // Chad
|
||||
"CL", // Chile
|
||||
"CX", // Christmas Island (AU)
|
||||
"CC", // Cocos (Keeling) Islands (AU)
|
||||
"CO", // Colombia
|
||||
"KM", // Comoros
|
||||
"CG", // Congo (Brazzaville)
|
||||
"CK", // Cook Islands (NZ)
|
||||
"CR", // Costa Rica
|
||||
"CI", // Côte d'Ivoire
|
||||
"HR", // Croatia
|
||||
"CY", // Cyprus
|
||||
"CZ", // Czechia (Czech Republic)
|
||||
"DK", // Denmark
|
||||
"DJ", // Djibouti
|
||||
"DM", // Dominica
|
||||
"DO", // Dominican Republic
|
||||
"EC", // Ecuador
|
||||
"EG", // Egypt
|
||||
"SV", // El Salvador
|
||||
"GQ", // Equatorial Guinea
|
||||
"EE", // Estonia
|
||||
"SZ", // Eswatini
|
||||
"FK", // Falkland Islands (UK)
|
||||
"FJ", // Fiji
|
||||
"FI", // Finland
|
||||
"FR", // France
|
||||
"GF", // French Guiana (FR)
|
||||
"PF", // French Polynesia (FR)
|
||||
"TF", // French Southern Territories
|
||||
"GA", // Gabon
|
||||
"GM", // Gambia
|
||||
"GE", // Georgia
|
||||
"DE", // Germany
|
||||
"GH", // Ghana
|
||||
"GI", // Gibraltar (UK)
|
||||
"GR", // Greece
|
||||
"GD", // Grenada
|
||||
"GT", // Guatemala
|
||||
"GU", // Guam (US)
|
||||
"GN", // Guinea
|
||||
"GW", // Guinea-Bissau
|
||||
"GY", // Guyana
|
||||
"HT", // Haiti
|
||||
"HM", // Heard Island and McDonald Islands (AU)
|
||||
"HN", // Honduras
|
||||
"HU", // Hungary
|
||||
"IS", // Iceland
|
||||
"IN", // India
|
||||
"ID", // Indonesia
|
||||
"IQ", // Iraq
|
||||
"IE", // Ireland
|
||||
"IL", // Israel
|
||||
"IT", // Italy
|
||||
"JM", // Jamaica
|
||||
"JP", // Japan
|
||||
"JO", // Jordan
|
||||
"KZ", // Kazakhstan
|
||||
"KE", // Kenya
|
||||
"KI", // Kiribati
|
||||
"KW", // Kuwait
|
||||
"KG", // Kyrgyzstan
|
||||
"LA", // Laos
|
||||
"LV", // Latvia
|
||||
"LB", // Lebanon
|
||||
"LS", // Lesotho
|
||||
"LR", // Liberia
|
||||
"LI", // Liechtenstein
|
||||
"LT", // Lithuania
|
||||
"LU", // Luxembourg
|
||||
"MG", // Madagascar
|
||||
"MW", // Malawi
|
||||
"MY", // Malaysia
|
||||
"MV", // Maldives
|
||||
"MT", // Malta
|
||||
"MH", // Marshall Islands
|
||||
"MR", // Mauritania
|
||||
"MU", // Mauritius
|
||||
"MX", // Mexico
|
||||
"FM", // Micronesia
|
||||
"MD", // Moldova
|
||||
"MC", // Monaco
|
||||
"MN", // Mongolia
|
||||
"MS", // Montserrat (UK)
|
||||
"ME", // Montenegro
|
||||
"MA", // Morocco
|
||||
"MZ", // Mozambique
|
||||
"NA", // Namibia
|
||||
"NR", // Nauru
|
||||
"NP", // Nepal
|
||||
"NL", // Netherlands
|
||||
"NZ", // New Zealand
|
||||
"NE", // Niger
|
||||
"NG", // Nigeria
|
||||
"NF", // Norfolk Island (AU)
|
||||
"MK", // North Macedonia
|
||||
"MI", // Northern Mariana Islands (UK)
|
||||
"NO", // Norway
|
||||
"NU", // Niue (NZ)
|
||||
"OM", // Oman
|
||||
"PK", // Pakistan
|
||||
"PW", // Palau
|
||||
"PS", // Palestine
|
||||
"PA", // Panama
|
||||
"PG", // Papua New Guinea
|
||||
"PY", // Paraguay
|
||||
"PE", // Peru
|
||||
"PH", // Philippines
|
||||
"PN", // Pitcairn (UK)
|
||||
"PL", // Poland
|
||||
"PT", // Portugal
|
||||
"PR", // Puerto Rico (US)
|
||||
"QA", // Qatar
|
||||
"RO", // Romania
|
||||
"RW", // Rwanda
|
||||
"BL", // Saint Barthélemy (FR)
|
||||
"KN", // Saint Kitts and Nevis
|
||||
"LC", // Saint Lucia
|
||||
"MF", // Saint Martin (FR)
|
||||
"PM", // Saint Pierre and Miquelon (FR)
|
||||
"VC", // Saint Vincent and the Grenadines
|
||||
"WS", // Samoa
|
||||
"SM", // San Marino
|
||||
"ST", // São Tomé and Príncipe
|
||||
"SA", // Saudi Arabia
|
||||
"SN", // Senegal
|
||||
"RS", // Serbia
|
||||
"SC", // Seychelles
|
||||
"SH", // Saint Helena, Ascension and Tristan da Cunha (UK)
|
||||
"SL", // Sierra Leone
|
||||
"SG", // Singapore
|
||||
"SK", // Slovakia
|
||||
"SI", // Slovenia
|
||||
"SB", // Solomon Islands
|
||||
"ZA", // South Africa
|
||||
"KR", // South Korea
|
||||
"ES", // Spain
|
||||
"LK", // Sri Lanka
|
||||
"SR", // Suriname
|
||||
"SE", // Sweden
|
||||
"CH", // Switzerland
|
||||
"TW", // Taiwan
|
||||
"TJ", // Tajikistan
|
||||
"TZ", // Tanzania
|
||||
"TH", // Thailand
|
||||
"TL", // Timor-Leste
|
||||
"TG", // Togo
|
||||
"TK", // Tokelau (NZ)
|
||||
"TO", // Tonga
|
||||
"TT", // Trinidad and Tobago
|
||||
"TN", // Tunisia
|
||||
"TR", // Türkiye (Turkey)
|
||||
"TM", // Turkmenistan
|
||||
"TC", // Turks and Caicos Islands (UK)
|
||||
"TV", // Tuvalu
|
||||
"UG", // Uganda
|
||||
"UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
|
||||
"AE", // United Arab Emirates
|
||||
"GB", // United Kingdom
|
||||
"UM", // United States Minor Outlying Islands (US)
|
||||
"US", // United States of America
|
||||
"UY", // Uruguay
|
||||
"UZ", // Uzbekistan
|
||||
"VU", // Vanuatu
|
||||
"VA", // Vatican City
|
||||
"VN", // Vietnam
|
||||
"VI", // Virgin Islands (US)
|
||||
"VG", // Virgin Islands (UK)
|
||||
"WF", // Wallis and Futuna (FR)
|
||||
"ZM", // Zambia
|
||||
"ZW", // Zimbabwe
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
});
|
||||
@@ -49,7 +49,7 @@ menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
prompt_library.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
rope.workspace = true
|
||||
|
||||
@@ -101,7 +101,7 @@ pub fn init(
|
||||
SlashCommandSettings::register(cx);
|
||||
|
||||
assistant_context_editor::init(client.clone(), cx);
|
||||
prompt_library::init(cx);
|
||||
rules_library::init(cx);
|
||||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -25,8 +25,8 @@ use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
|
||||
use prompt_store::{PromptBuilder, UserPromptId};
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
|
||||
use search::{BufferSearchBar, buffer_search::DivRegistrar};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -43,7 +43,7 @@ use workspace::{
|
||||
dock::{DockPosition, Panel, PanelEvent},
|
||||
pane,
|
||||
};
|
||||
use zed_actions::assistant::{InlineAssist, OpenPromptLibrary, ShowConfiguration, ToggleFocus};
|
||||
use zed_actions::assistant::{InlineAssist, OpenRulesLibrary, ShowConfiguration, ToggleFocus};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
|
||||
@@ -57,11 +57,11 @@ pub fn init(cx: &mut App) {
|
||||
.register_action(AssistantPanel::show_configuration)
|
||||
.register_action(AssistantPanel::create_new_context)
|
||||
.register_action(AssistantPanel::restart_context_servers)
|
||||
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_prompt_library(action, window, cx)
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -272,8 +272,8 @@ impl AssistantPanel {
|
||||
.action("New Chat", Box::new(NewChat))
|
||||
.action("History", Box::new(DeployHistory))
|
||||
.action(
|
||||
"Prompt Library",
|
||||
Box::new(OpenPromptLibrary::default()),
|
||||
"Rules Library",
|
||||
Box::new(OpenRulesLibrary::default()),
|
||||
)
|
||||
.action("Configure", Box::new(ShowConfiguration))
|
||||
.action(zoom_label, Box::new(ToggleZoom))
|
||||
@@ -1043,13 +1043,13 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn deploy_prompt_library(
|
||||
fn deploy_rules_library(
|
||||
&mut self,
|
||||
action: &OpenPromptLibrary,
|
||||
action: &OpenRulesLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_prompt_library(
|
||||
open_rules_library(
|
||||
self.languages.clone(),
|
||||
Box::new(PromptLibraryInlineAssist),
|
||||
Arc::new(|| {
|
||||
@@ -1059,9 +1059,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action.prompt_to_select.map(|uuid| PromptId::User {
|
||||
uuid: UserPromptId(uuid),
|
||||
}),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -1235,7 +1235,7 @@ impl Render for AssistantPanel {
|
||||
this.show_configuration_tab(window, cx)
|
||||
}))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_history))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_prompt_library))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_rules_library))
|
||||
.child(registrar.size_full().child(self.pane.clone()))
|
||||
.into_any_element()
|
||||
}
|
||||
@@ -1350,13 +1350,13 @@ impl Focusable for AssistantPanel {
|
||||
|
||||
struct PromptLibraryInlineAssist;
|
||||
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assistant.assist(&prompt_editor, None, None, initial_prompt, window, cx)
|
||||
|
||||
@@ -18,11 +18,11 @@ use editor::{
|
||||
},
|
||||
};
|
||||
use feature_flags::{
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedPro,
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedProFeatureFlag,
|
||||
};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt,
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _,
|
||||
channel::mpsc,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
join,
|
||||
@@ -1652,7 +1652,7 @@ impl Render for PromptEditor {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -1966,7 +1966,7 @@ impl PromptEditor {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedPro>()
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
@@ -3056,7 +3056,8 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
let chunks =
|
||||
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -10,6 +10,11 @@ pub fn adapt_schema_to_format(
|
||||
json: &mut Value,
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<()> {
|
||||
if let Value::Object(obj) = json {
|
||||
obj.remove("$schema");
|
||||
obj.remove("title");
|
||||
}
|
||||
|
||||
match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
|
||||
@@ -30,10 +35,7 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
|
||||
for key in KEYS_TO_REMOVE {
|
||||
obj.remove(key);
|
||||
}
|
||||
obj.remove("format");
|
||||
|
||||
if let Some(default) = obj.get("default") {
|
||||
let is_null = default.is_null();
|
||||
|
||||
@@ -59,7 +59,6 @@ use crate::thinking_tool::ThinkingTool;
|
||||
pub use create_file_tool::CreateFileToolInput;
|
||||
pub use edit_file_tool::EditFileToolInput;
|
||||
pub use find_path_tool::FindPathToolInput;
|
||||
pub use list_directory_tool::ListDirectoryToolInput;
|
||||
pub use read_file_tool::ReadFileToolInput;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
@@ -111,11 +110,38 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use client::Client;
|
||||
use clock::FakeSystemClock;
|
||||
use http_client::FakeHttpClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Serialize;
|
||||
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_json_schema() {
|
||||
#[derive(Serialize, JsonSchema)]
|
||||
struct GetWeatherTool {
|
||||
location: String,
|
||||
}
|
||||
|
||||
let schema = schema::json_schema_for::<GetWeatherTool>(
|
||||
language_model::LanguageModelToolSchemaFormat::JsonSchema,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
schema,
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||
|
||||
@@ -14,7 +14,6 @@ pub mod messages;
|
||||
pub mod notifications;
|
||||
pub mod processed_stripe_events;
|
||||
pub mod projects;
|
||||
pub mod rate_buckets;
|
||||
pub mod rooms;
|
||||
pub mod servers;
|
||||
pub mod users;
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
use super::*;
|
||||
use crate::db::tables::rate_buckets;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
|
||||
impl Database {
|
||||
/// Saves the rate limit for the given user and rate limit name if the last_refill is later
|
||||
/// than the currently saved timestamp.
|
||||
pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
|
||||
if buckets.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
|
||||
rate_buckets::ActiveModel {
|
||||
user_id: ActiveValue::Set(bucket.user_id),
|
||||
rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
|
||||
token_count: ActiveValue::Set(bucket.token_count),
|
||||
last_refill: ActiveValue::Set(bucket.last_refill),
|
||||
}
|
||||
}))
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
rate_buckets::Column::UserId,
|
||||
rate_buckets::Column::RateLimitName,
|
||||
])
|
||||
.update_columns([
|
||||
rate_buckets::Column::TokenCount,
|
||||
rate_buckets::Column::LastRefill,
|
||||
])
|
||||
.to_owned(),
|
||||
)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieves the rate limit for the given user and rate limit name.
|
||||
pub async fn get_rate_bucket(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
rate_limit_name: &str,
|
||||
) -> Result<Option<rate_buckets::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
let rate_limit = rate_buckets::Entity::find()
|
||||
.filter(rate_buckets::Column::UserId.eq(user_id))
|
||||
.filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(rate_limit)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,6 @@ pub mod project;
|
||||
pub mod project_collaborator;
|
||||
pub mod project_repository;
|
||||
pub mod project_repository_statuses;
|
||||
pub mod rate_buckets;
|
||||
pub mod room;
|
||||
pub mod room_participant;
|
||||
pub mod server;
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
use crate::db::UserId;
|
||||
use sea_orm::entity::prelude::*;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
|
||||
#[sea_orm(table_name = "rate_buckets")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub user_id: UserId,
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub rate_limit_name: String,
|
||||
pub token_count: i32,
|
||||
pub last_refill: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {
|
||||
#[sea_orm(
|
||||
belongs_to = "super::user::Entity",
|
||||
from = "Column::UserId",
|
||||
to = "super::user::Column::Id"
|
||||
)]
|
||||
User,
|
||||
}
|
||||
|
||||
impl Related<super::user::Entity> for Entity {
|
||||
fn to() -> RelationDef {
|
||||
Relation::User.def()
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveModelBehavior for ActiveModel {}
|
||||
@@ -6,7 +6,6 @@ pub mod env;
|
||||
pub mod executor;
|
||||
pub mod llm;
|
||||
pub mod migrations;
|
||||
mod rate_limiter;
|
||||
pub mod rpc;
|
||||
pub mod seed;
|
||||
pub mod stripe_billing;
|
||||
@@ -25,7 +24,6 @@ pub use cents::*;
|
||||
use db::{ChannelId, Database};
|
||||
use executor::Executor;
|
||||
use llm::db::LlmDatabase;
|
||||
pub use rate_limiter::*;
|
||||
use serde::Deserialize;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use util::ResultExt;
|
||||
@@ -295,7 +293,6 @@ pub struct AppState {
|
||||
pub blob_store_client: Option<aws_sdk_s3::Client>,
|
||||
pub stripe_client: Option<Arc<stripe::Client>>,
|
||||
pub stripe_billing: Option<Arc<StripeBilling>>,
|
||||
pub rate_limiter: Arc<RateLimiter>,
|
||||
pub executor: Executor,
|
||||
pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
|
||||
pub config: Config,
|
||||
@@ -348,7 +345,6 @@ impl AppState {
|
||||
.clone()
|
||||
.map(|stripe_client| Arc::new(StripeBilling::new(stripe_client))),
|
||||
stripe_client,
|
||||
rate_limiter: Arc::new(RateLimiter::new(db)),
|
||||
executor,
|
||||
kinesis_client: if config.kinesis_access_key.is_some() {
|
||||
build_kinesis_client(&config).await.log_err()
|
||||
|
||||
@@ -13,8 +13,8 @@ use collab::llm::db::LlmDatabase;
|
||||
use collab::migrations::run_database_migrations;
|
||||
use collab::user_backfiller::spawn_user_backfiller;
|
||||
use collab::{
|
||||
AppState, Config, RateLimiter, Result, api::fetch_extensions_from_blob_store_periodically, db,
|
||||
env, executor::Executor, rpc::ResultExt,
|
||||
AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env,
|
||||
executor::Executor, rpc::ResultExt,
|
||||
};
|
||||
use collab::{ServiceMode, api::billing::poll_stripe_events_periodically};
|
||||
use db::Database;
|
||||
@@ -111,10 +111,6 @@ async fn main() -> Result<()> {
|
||||
|
||||
if mode.is_collab() {
|
||||
state.db.purge_old_embeddings().await.trace_err();
|
||||
RateLimiter::save_periodically(
|
||||
state.rate_limiter.clone(),
|
||||
state.executor.clone(),
|
||||
);
|
||||
|
||||
let epoch = state
|
||||
.db
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
use crate::{Database, Error, Result, db::UserId, executor::Executor};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use rpc::ErrorCodeExt;
|
||||
use sea_orm::prelude::DateTimeUtc;
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
pub trait RateLimit: Send + Sync {
|
||||
fn capacity(&self) -> usize;
|
||||
fn refill_duration(&self) -> Duration;
|
||||
fn db_name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
/// Used to enforce per-user rate limits
|
||||
pub struct RateLimiter {
|
||||
buckets: DashMap<(UserId, String), RateBucket>,
|
||||
dirty_buckets: DashSet<(UserId, String)>,
|
||||
db: Arc<Database>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(db: Arc<Database>) -> Self {
|
||||
RateLimiter {
|
||||
buckets: DashMap::new(),
|
||||
dirty_buckets: DashSet::new(),
|
||||
db,
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a new task that periodically saves rate limit data to the database.
|
||||
pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
|
||||
const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
executor.clone().spawn_detached(async move {
|
||||
loop {
|
||||
executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
|
||||
rate_limiter.save().await.log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Returns an error if the user has exceeded the specified `RateLimit`.
|
||||
/// Attempts to read the from the database if no cached RateBucket currently exists.
|
||||
pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> {
|
||||
self.check_internal(limit, user_id, Utc::now()).await
|
||||
}
|
||||
|
||||
async fn check_internal(
|
||||
&self,
|
||||
limit: &dyn RateLimit,
|
||||
user_id: UserId,
|
||||
now: DateTimeUtc,
|
||||
) -> Result<()> {
|
||||
let bucket_key = (user_id, limit.db_name().to_string());
|
||||
|
||||
// Attempt to fetch the bucket from the database if it hasn't been cached.
|
||||
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
|
||||
// but this enforces limits across restarts so long as the database is reachable.
|
||||
if !self.buckets.contains_key(&bucket_key) {
|
||||
if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
|
||||
self.buckets.insert(bucket_key.clone(), bucket);
|
||||
self.dirty_buckets.insert(bucket_key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut bucket = self
|
||||
.buckets
|
||||
.entry(bucket_key.clone())
|
||||
.or_insert_with(|| RateBucket::new(limit, now));
|
||||
|
||||
if bucket.value_mut().allow(now) {
|
||||
self.dirty_buckets.insert(bucket_key);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(rpc::proto::ErrorCode::RateLimitExceeded
|
||||
.message("rate limit exceeded".into())
|
||||
.anyhow())?
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_bucket(
|
||||
&self,
|
||||
limit: &dyn RateLimit,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<RateBucket>, Error> {
|
||||
Ok(self
|
||||
.db
|
||||
.get_rate_bucket(user_id, limit.db_name())
|
||||
.await?
|
||||
.map(|saved_bucket| {
|
||||
RateBucket::from_db(
|
||||
limit,
|
||||
saved_bucket.token_count as usize,
|
||||
DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn save(&self) -> Result<()> {
|
||||
let mut buckets = Vec::new();
|
||||
self.dirty_buckets.retain(|key| {
|
||||
if let Some(bucket) = self.buckets.get(key) {
|
||||
buckets.push(crate::db::rate_buckets::Model {
|
||||
user_id: key.0,
|
||||
rate_limit_name: key.1.clone(),
|
||||
token_count: bucket.token_count as i32,
|
||||
last_refill: bucket.last_refill.naive_utc(),
|
||||
});
|
||||
}
|
||||
false
|
||||
});
|
||||
|
||||
match self.db.save_rate_buckets(&buckets).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(err) => {
|
||||
for bucket in buckets {
|
||||
self.dirty_buckets
|
||||
.insert((bucket.user_id, bucket.rate_limit_name));
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct RateBucket {
|
||||
capacity: usize,
|
||||
token_count: usize,
|
||||
refill_time_per_token: Duration,
|
||||
last_refill: DateTimeUtc,
|
||||
}
|
||||
|
||||
impl RateBucket {
|
||||
fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self {
|
||||
Self {
|
||||
capacity: limit.capacity(),
|
||||
token_count: limit.capacity(),
|
||||
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
|
||||
last_refill: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self {
|
||||
Self {
|
||||
capacity: limit.capacity(),
|
||||
token_count,
|
||||
refill_time_per_token: limit.refill_duration() / limit.capacity() as i32,
|
||||
last_refill,
|
||||
}
|
||||
}
|
||||
|
||||
fn allow(&mut self, now: DateTimeUtc) -> bool {
|
||||
self.refill(now);
|
||||
if self.token_count > 0 {
|
||||
self.token_count -= 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self, now: DateTimeUtc) {
|
||||
let elapsed = now - self.last_refill;
|
||||
if elapsed >= self.refill_time_per_token {
|
||||
let new_tokens =
|
||||
elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
|
||||
self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
|
||||
|
||||
let unused_refill_time = Duration::milliseconds(
|
||||
elapsed.num_milliseconds() % self.refill_time_per_token.num_milliseconds(),
|
||||
);
|
||||
self.last_refill = now - unused_refill_time;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::db::{NewUserParams, TestDb};
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_rate_limiter(cx: &mut TestAppContext) {
|
||||
let test_db = TestDb::sqlite(cx.executor().clone());
|
||||
let db = test_db.db().clone();
|
||||
let user_1 = db
|
||||
.create_user(
|
||||
"user-1@zed.dev",
|
||||
None,
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user-1".into(),
|
||||
github_user_id: 1,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
let user_2 = db
|
||||
.create_user(
|
||||
"user-2@zed.dev",
|
||||
None,
|
||||
false,
|
||||
NewUserParams {
|
||||
github_login: "user-2".into(),
|
||||
github_user_id: 2,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.user_id;
|
||||
|
||||
let mut now = Utc::now();
|
||||
|
||||
let rate_limiter = RateLimiter::new(db.clone());
|
||||
let rate_limit_a = Box::new(RateLimitA);
|
||||
let rate_limit_b = Box::new(RateLimitB);
|
||||
|
||||
// User 1 can access resource A two times before being rate-limited.
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// User 2 can access resource A and user 1 can access resource B.
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_b, user_2, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_b, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// After 1.5s, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::milliseconds(1500);
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// After 500ms, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::milliseconds(500);
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
rate_limiter.save().await.unwrap();
|
||||
|
||||
// Rate limits are reloaded from the database, so user A is still rate-limited
|
||||
// for resource A.
|
||||
let rate_limiter = RateLimiter::new(db.clone());
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
// After 1s, user 1 can make another request before being rate-limited again.
|
||||
now += Duration::seconds(1);
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap();
|
||||
rate_limiter
|
||||
.check_internal(&*rate_limit_a, user_1, now)
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
struct RateLimitA;
|
||||
|
||||
impl RateLimit for RateLimitA {
|
||||
fn capacity(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> Duration {
|
||||
Duration::seconds(2)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"rate-limit-a"
|
||||
}
|
||||
}
|
||||
|
||||
struct RateLimitB;
|
||||
|
||||
impl RateLimit for RateLimitB {
|
||||
fn capacity(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
fn refill_duration(&self) -> Duration {
|
||||
Duration::seconds(3)
|
||||
}
|
||||
|
||||
fn db_name(&self) -> &'static str {
|
||||
"rate-limit-b"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
mod connection_pool;
|
||||
|
||||
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
|
||||
use crate::db::billing_subscription::SubscriptionKind;
|
||||
use crate::llm::LlmTokenClaims;
|
||||
use crate::{
|
||||
AppState, Error, Result, auth,
|
||||
@@ -178,15 +179,23 @@ impl Session {
|
||||
Ok(db.has_active_billing_subscription(user_id).await?)
|
||||
}
|
||||
|
||||
pub async fn current_plan(
|
||||
&self,
|
||||
_db: &MutexGuard<'_, DbHandle>,
|
||||
) -> anyhow::Result<proto::Plan> {
|
||||
if self.is_staff() {
|
||||
Ok(proto::Plan::ZedPro)
|
||||
pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
|
||||
let user_id = self.user_id();
|
||||
|
||||
let subscription = db.get_active_billing_subscription(user_id).await?;
|
||||
let subscription_kind = subscription.and_then(|subscription| subscription.kind);
|
||||
|
||||
let plan = if let Some(subscription_kind) = subscription_kind {
|
||||
match subscription_kind {
|
||||
SubscriptionKind::ZedPro => proto::Plan::ZedPro,
|
||||
SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
SubscriptionKind::ZedFree => proto::Plan::Free,
|
||||
}
|
||||
} else {
|
||||
Ok(proto::Plan::Free)
|
||||
}
|
||||
proto::Plan::Free
|
||||
};
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
fn user_id(&self) -> UserId {
|
||||
|
||||
@@ -46,7 +46,7 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
|
||||
let mut first_user = None;
|
||||
let mut others = vec![];
|
||||
|
||||
let flag_names = ["remoting", "language-models"];
|
||||
let flag_names = ["language-models"];
|
||||
let mut flags = Vec::new();
|
||||
|
||||
let existing_feature_flags = db.list_feature_flags().await?;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
AppState, Config, RateLimiter,
|
||||
AppState, Config,
|
||||
db::{NewUserParams, UserId, tests::TestDb},
|
||||
executor::Executor,
|
||||
rpc::{CLEANUP_TIMEOUT, Principal, RECONNECT_TIMEOUT, Server, ZedVersion},
|
||||
@@ -517,7 +517,6 @@ impl TestServer {
|
||||
blob_store_client: None,
|
||||
stripe_client: None,
|
||||
stripe_billing: None,
|
||||
rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
|
||||
executor,
|
||||
kinesis_client: None,
|
||||
config: Config {
|
||||
|
||||
@@ -552,7 +552,9 @@ impl TcpTransport {
|
||||
let host = connection_args.host;
|
||||
let port = connection_args.port;
|
||||
|
||||
let mut command = util::command::new_smol_command(&binary.command);
|
||||
let mut command = util::command::new_std_command(&binary.command);
|
||||
util::set_pre_exec_to_start_new_session(&mut command);
|
||||
let mut command = smol::process::Command::from(command);
|
||||
|
||||
if let Some(cwd) = &binary.cwd {
|
||||
command.current_dir(cwd);
|
||||
@@ -636,7 +638,9 @@ pub struct StdioTransport {
|
||||
impl StdioTransport {
|
||||
#[allow(dead_code, reason = "This is used in non test builds of Zed")]
|
||||
async fn start(binary: &DebugAdapterBinary, _: AsyncApp) -> Result<(TransportPipe, Self)> {
|
||||
let mut command = util::command::new_smol_command(&binary.command);
|
||||
let mut command = util::command::new_std_command(&binary.command);
|
||||
util::set_pre_exec_to_start_new_session(&mut command);
|
||||
let mut command = smol::process::Command::from(command);
|
||||
|
||||
if let Some(cwd) = &binary.cwd {
|
||||
command.current_dir(cwd);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::persistence::DebuggerPaneItem;
|
||||
use crate::{
|
||||
ClearAllBreakpoints, Continue, CreateDebuggingSession, Disconnect, Pause, Restart, StepBack,
|
||||
StepInto, StepOut, StepOver, Stop, ToggleIgnoreBreakpoints, persistence,
|
||||
@@ -32,8 +33,10 @@ use settings::Settings;
|
||||
use std::any::TypeId;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use task::{DebugTaskDefinition, DebugTaskTemplate};
|
||||
use terminal_view::terminal_panel::TerminalPanel;
|
||||
use task::{
|
||||
DebugTaskDefinition, DebugTaskTemplate, HideStrategy, RevealStrategy, RevealTarget, TaskId,
|
||||
};
|
||||
use terminal_view::TerminalView;
|
||||
use ui::{ContextMenu, Divider, DropdownMenu, Tooltip, prelude::*};
|
||||
use workspace::{
|
||||
Workspace,
|
||||
@@ -293,23 +296,21 @@ impl DebugPanel {
|
||||
|
||||
(session.clone(), dap_store.boot_session(session, cx))
|
||||
})?;
|
||||
Self::register_session(this.clone(), session.clone(), cx).await?;
|
||||
|
||||
match task.await {
|
||||
Err(e) => {
|
||||
this.update(cx, |this, cx| {
|
||||
this.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.show_error(&e, cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.ok();
|
||||
if let Err(e) = task.await {
|
||||
this.update(cx, |this, cx| {
|
||||
this.workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
workspace.show_error(&e, cx);
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.ok();
|
||||
|
||||
session
|
||||
.update(cx, |session, cx| session.shutdown(cx))?
|
||||
.await;
|
||||
}
|
||||
Ok(_) => Self::register_session(this, session, cx).await?,
|
||||
session
|
||||
.update(cx, |session, cx| session.shutdown(cx))?
|
||||
.await;
|
||||
}
|
||||
|
||||
anyhow::Ok(())
|
||||
@@ -467,6 +468,7 @@ impl DebugPanel {
|
||||
) {
|
||||
match event {
|
||||
dap_store::DapStoreEvent::RunInTerminal {
|
||||
session_id,
|
||||
title,
|
||||
cwd,
|
||||
command,
|
||||
@@ -476,6 +478,7 @@ impl DebugPanel {
|
||||
..
|
||||
} => {
|
||||
self.handle_run_in_terminal_request(
|
||||
*session_id,
|
||||
title.clone(),
|
||||
cwd.clone(),
|
||||
command.clone(),
|
||||
@@ -499,6 +502,7 @@ impl DebugPanel {
|
||||
|
||||
fn handle_run_in_terminal_request(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
title: Option<String>,
|
||||
cwd: Option<Arc<Path>>,
|
||||
command: Option<String>,
|
||||
@@ -506,56 +510,83 @@ impl DebugPanel {
|
||||
envs: HashMap<String, String>,
|
||||
mut sender: mpsc::Sender<Result<u32>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
let terminal_task = self.workspace.update(cx, |workspace, cx| {
|
||||
let terminal_panel = workspace.panel::<TerminalPanel>(cx).ok_or_else(|| {
|
||||
anyhow!("RunInTerminal DAP request failed because TerminalPanel wasn't found")
|
||||
});
|
||||
|
||||
let terminal_panel = match terminal_panel {
|
||||
Ok(panel) => panel,
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
};
|
||||
|
||||
terminal_panel.update(cx, |terminal_panel, cx| {
|
||||
let terminal_task = terminal_panel.add_terminal(
|
||||
TerminalKind::Debug {
|
||||
command,
|
||||
args,
|
||||
envs,
|
||||
cwd,
|
||||
title,
|
||||
},
|
||||
task::RevealStrategy::Never,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let pid_task = async move {
|
||||
let terminal = terminal_task.await?;
|
||||
|
||||
terminal.read_with(cx, |terminal, _| terminal.pty_info.pid())
|
||||
};
|
||||
|
||||
pid_task.await
|
||||
})
|
||||
let Some(session) = self
|
||||
.sessions
|
||||
.iter()
|
||||
.find(|s| s.read(cx).session_id(cx) == session_id)
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("no session {:?} found", session_id)));
|
||||
};
|
||||
let running = session.read(cx).running_state();
|
||||
let cwd = cwd.map(|p| p.to_path_buf());
|
||||
let shell = self
|
||||
.project
|
||||
.read(cx)
|
||||
.terminal_settings(&cwd, cx)
|
||||
.shell
|
||||
.clone();
|
||||
let kind = if let Some(command) = command {
|
||||
let title = title.clone().unwrap_or(command.clone());
|
||||
TerminalKind::Task(task::SpawnInTerminal {
|
||||
id: TaskId("debug".to_string()),
|
||||
full_label: title.clone(),
|
||||
label: title.clone(),
|
||||
command: command.clone(),
|
||||
args,
|
||||
command_label: title.clone(),
|
||||
cwd,
|
||||
env: envs,
|
||||
use_new_terminal: true,
|
||||
allow_concurrent_runs: true,
|
||||
reveal: RevealStrategy::NoFocus,
|
||||
reveal_target: RevealTarget::Dock,
|
||||
hide: HideStrategy::Never,
|
||||
shell,
|
||||
show_summary: false,
|
||||
show_command: false,
|
||||
show_rerun: false,
|
||||
})
|
||||
} else {
|
||||
TerminalKind::Shell(cwd.map(|c| c.to_path_buf()))
|
||||
};
|
||||
|
||||
let workspace = self.workspace.clone();
|
||||
let project = self.project.downgrade();
|
||||
|
||||
let terminal_task = self.project.update(cx, |project, cx| {
|
||||
project.create_terminal(kind, window.window_handle(), cx)
|
||||
});
|
||||
let terminal_task = cx.spawn_in(window, async move |_, cx| {
|
||||
let terminal = terminal_task.await?;
|
||||
|
||||
let terminal_view = cx.new_window_entity(|window, cx| {
|
||||
TerminalView::new(terminal.clone(), workspace, None, project, window, cx)
|
||||
})?;
|
||||
|
||||
running.update_in(cx, |running, window, cx| {
|
||||
running.ensure_pane_item(DebuggerPaneItem::Terminal, window, cx);
|
||||
running.debug_terminal.update(cx, |debug_terminal, cx| {
|
||||
debug_terminal.terminal = Some(terminal_view);
|
||||
cx.notify();
|
||||
});
|
||||
})?;
|
||||
|
||||
anyhow::Ok(terminal.read_with(cx, |terminal, _| terminal.pty_info.pid())?)
|
||||
});
|
||||
|
||||
cx.background_spawn(async move {
|
||||
match terminal_task {
|
||||
Ok(pid_task) => match pid_task.await {
|
||||
Ok(Some(pid)) => sender.send(Ok(pid.as_u32())).await?,
|
||||
Ok(None) => {
|
||||
match terminal_task.await {
|
||||
Ok(pid_task) => match pid_task {
|
||||
Some(pid) => sender.send(Ok(pid.as_u32())).await?,
|
||||
None => {
|
||||
sender
|
||||
.send(Err(anyhow!(
|
||||
"Terminal was spawned but PID was not available"
|
||||
)))
|
||||
.await?
|
||||
}
|
||||
Err(error) => sender.send(Err(anyhow!(error))).await?,
|
||||
},
|
||||
Err(error) => sender.send(Err(anyhow!(error))).await?,
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use dap::debugger_settings::DebuggerSettings;
|
||||
use debugger_panel::{DebugPanel, ToggleFocus};
|
||||
use editor::Editor;
|
||||
use feature_flags::{Debugger, FeatureFlagViewExt};
|
||||
use feature_flags::{DebuggerFeatureFlag, FeatureFlagViewExt};
|
||||
use gpui::{App, EntityInputHandler, actions};
|
||||
use new_session_modal::NewSessionModal;
|
||||
use project::debugger::{self, breakpoint_store::SourceBreakpoint};
|
||||
@@ -47,7 +47,7 @@ pub fn init(cx: &mut App) {
|
||||
return;
|
||||
};
|
||||
|
||||
cx.when_flag_enabled::<Debugger>(window, |workspace, _, _| {
|
||||
cx.when_flag_enabled::<DebuggerFeatureFlag>(window, |workspace, _, _| {
|
||||
workspace
|
||||
.register_action(|workspace, _: &ToggleFocus, window, cx| {
|
||||
workspace.toggle_panel_focus::<DebugPanel>(window, cx);
|
||||
|
||||
@@ -9,7 +9,7 @@ use util::ResultExt;
|
||||
use workspace::{Member, Pane, PaneAxis, Workspace};
|
||||
|
||||
use crate::session::running::{
|
||||
self, RunningState, SubView, breakpoint_list::BreakpointList, console::Console,
|
||||
self, DebugTerminal, RunningState, SubView, breakpoint_list::BreakpointList, console::Console,
|
||||
loaded_source_list::LoadedSourceList, module_list::ModuleList,
|
||||
stack_frame_list::StackFrameList, variable_list::VariableList,
|
||||
};
|
||||
@@ -22,6 +22,7 @@ pub(crate) enum DebuggerPaneItem {
|
||||
Frames,
|
||||
Modules,
|
||||
LoadedSources,
|
||||
Terminal,
|
||||
}
|
||||
|
||||
impl DebuggerPaneItem {
|
||||
@@ -33,6 +34,7 @@ impl DebuggerPaneItem {
|
||||
DebuggerPaneItem::Frames,
|
||||
DebuggerPaneItem::Modules,
|
||||
DebuggerPaneItem::LoadedSources,
|
||||
DebuggerPaneItem::Terminal,
|
||||
];
|
||||
VARIANTS
|
||||
}
|
||||
@@ -55,6 +57,7 @@ impl DebuggerPaneItem {
|
||||
DebuggerPaneItem::Frames => SharedString::new_static("Frames"),
|
||||
DebuggerPaneItem::Modules => SharedString::new_static("Modules"),
|
||||
DebuggerPaneItem::LoadedSources => SharedString::new_static("Sources"),
|
||||
DebuggerPaneItem::Terminal => SharedString::new_static("Terminal"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -169,6 +172,7 @@ pub(crate) fn deserialize_pane_layout(
|
||||
console: &Entity<Console>,
|
||||
breakpoint_list: &Entity<BreakpointList>,
|
||||
loaded_sources: &Entity<LoadedSourceList>,
|
||||
terminal: &Entity<DebugTerminal>,
|
||||
subscriptions: &mut HashMap<EntityId, Subscription>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<RunningState>,
|
||||
@@ -191,6 +195,7 @@ pub(crate) fn deserialize_pane_layout(
|
||||
console,
|
||||
breakpoint_list,
|
||||
loaded_sources,
|
||||
terminal,
|
||||
subscriptions,
|
||||
window,
|
||||
cx,
|
||||
@@ -273,6 +278,13 @@ pub(crate) fn deserialize_pane_layout(
|
||||
})),
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Terminal => Box::new(SubView::new(
|
||||
pane.focus_handle(cx),
|
||||
terminal.clone().into(),
|
||||
DebuggerPaneItem::Terminal,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -104,6 +104,12 @@ impl DebugSession {
|
||||
&self.mode
|
||||
}
|
||||
|
||||
pub(crate) fn running_state(&self) -> Entity<RunningState> {
|
||||
match &self.mode {
|
||||
DebugSessionState::Running(running_state) => running_state.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn label(&self, cx: &App) -> String {
|
||||
if let Some(label) = self.label.get() {
|
||||
return label.to_owned();
|
||||
|
||||
@@ -27,6 +27,7 @@ use project::{
|
||||
use rpc::proto::ViewId;
|
||||
use settings::Settings;
|
||||
use stack_frame_list::StackFrameList;
|
||||
use terminal_view::TerminalView;
|
||||
use ui::{
|
||||
ActiveTheme, AnyElement, App, Context, ContextMenu, DropdownMenu, FluentBuilder,
|
||||
InteractiveElement, IntoElement, Label, LabelCommon as _, ParentElement, Render, SharedString,
|
||||
@@ -35,7 +36,7 @@ use ui::{
|
||||
use util::ResultExt;
|
||||
use variable_list::VariableList;
|
||||
use workspace::{
|
||||
ActivePaneDecorator, DraggedTab, Item, Member, Pane, PaneGroup, Workspace,
|
||||
ActivePaneDecorator, DraggedTab, Item, ItemHandle, Member, Pane, PaneGroup, Workspace,
|
||||
item::TabContentParams, move_item, pane::Event,
|
||||
};
|
||||
|
||||
@@ -50,6 +51,7 @@ pub struct RunningState {
|
||||
_subscriptions: Vec<Subscription>,
|
||||
stack_frame_list: Entity<stack_frame_list::StackFrameList>,
|
||||
loaded_sources_list: Entity<LoadedSourceList>,
|
||||
pub debug_terminal: Entity<DebugTerminal>,
|
||||
module_list: Entity<module_list::ModuleList>,
|
||||
_console: Entity<Console>,
|
||||
breakpoint_list: Entity<BreakpointList>,
|
||||
@@ -364,6 +366,40 @@ pub(crate) fn new_debugger_pane(
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
pub struct DebugTerminal {
|
||||
pub terminal: Option<Entity<TerminalView>>,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
|
||||
impl DebugTerminal {
|
||||
fn empty(cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
terminal: None,
|
||||
focus_handle: cx.focus_handle(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl gpui::Render for DebugTerminal {
|
||||
fn render(&mut self, _window: &mut Window, _: &mut Context<Self>) -> impl IntoElement {
|
||||
if let Some(terminal) = self.terminal.clone() {
|
||||
terminal.into_any_element()
|
||||
} else {
|
||||
div().track_focus(&self.focus_handle).into_any_element()
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Focusable for DebugTerminal {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
if let Some(terminal) = self.terminal.as_ref() {
|
||||
return terminal.focus_handle(cx);
|
||||
} else {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RunningState {
|
||||
pub fn new(
|
||||
session: Entity<Session>,
|
||||
@@ -380,6 +416,8 @@ impl RunningState {
|
||||
StackFrameList::new(workspace.clone(), session.clone(), weak_state, window, cx)
|
||||
});
|
||||
|
||||
let debug_terminal = cx.new(DebugTerminal::empty);
|
||||
|
||||
let variable_list =
|
||||
cx.new(|cx| VariableList::new(session.clone(), stack_frame_list.clone(), window, cx));
|
||||
|
||||
@@ -452,6 +490,7 @@ impl RunningState {
|
||||
&console,
|
||||
&breakpoint_list,
|
||||
&loaded_source_list,
|
||||
&debug_terminal,
|
||||
&mut pane_close_subscriptions,
|
||||
window,
|
||||
cx,
|
||||
@@ -494,6 +533,7 @@ impl RunningState {
|
||||
breakpoint_list,
|
||||
loaded_sources_list: loaded_source_list,
|
||||
pane_close_subscriptions,
|
||||
debug_terminal,
|
||||
_schedule_serialize: None,
|
||||
}
|
||||
}
|
||||
@@ -525,6 +565,90 @@ impl RunningState {
|
||||
self.panes.pane_at_pixel_position(position).is_some()
|
||||
}
|
||||
|
||||
fn create_sub_view(
|
||||
&self,
|
||||
item_kind: DebuggerPaneItem,
|
||||
pane: &Entity<Pane>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Box<dyn ItemHandle> {
|
||||
match item_kind {
|
||||
DebuggerPaneItem::Console => {
|
||||
let weak_console = self._console.clone().downgrade();
|
||||
|
||||
Box::new(SubView::new(
|
||||
pane.focus_handle(cx),
|
||||
self._console.clone().into(),
|
||||
item_kind,
|
||||
Some(Box::new(move |cx| {
|
||||
weak_console
|
||||
.read_with(cx, |console, cx| console.show_indicator(cx))
|
||||
.unwrap_or_default()
|
||||
})),
|
||||
cx,
|
||||
))
|
||||
}
|
||||
DebuggerPaneItem::Variables => Box::new(SubView::new(
|
||||
self.variable_list.focus_handle(cx),
|
||||
self.variable_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::BreakpointList => Box::new(SubView::new(
|
||||
self.breakpoint_list.focus_handle(cx),
|
||||
self.breakpoint_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Frames => Box::new(SubView::new(
|
||||
self.stack_frame_list.focus_handle(cx),
|
||||
self.stack_frame_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Modules => Box::new(SubView::new(
|
||||
self.module_list.focus_handle(cx),
|
||||
self.module_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
|
||||
self.loaded_sources_list.focus_handle(cx),
|
||||
self.loaded_sources_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Terminal => Box::new(SubView::new(
|
||||
self.debug_terminal.focus_handle(cx),
|
||||
self.debug_terminal.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ensure_pane_item(
|
||||
&mut self,
|
||||
item_kind: DebuggerPaneItem,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.pane_items_status(cx).get(&item_kind) == Some(&true) {
|
||||
return;
|
||||
};
|
||||
let pane = self.panes.last_pane();
|
||||
let sub_view = self.create_sub_view(item_kind, &pane, cx);
|
||||
|
||||
pane.update(cx, |pane, cx| {
|
||||
pane.add_item_inner(sub_view, false, false, false, None, window, cx);
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn add_pane_item(
|
||||
&mut self,
|
||||
item_kind: DebuggerPaneItem,
|
||||
@@ -538,58 +662,7 @@ impl RunningState {
|
||||
);
|
||||
|
||||
if let Some(pane) = self.panes.pane_at_pixel_position(position) {
|
||||
let sub_view = match item_kind {
|
||||
DebuggerPaneItem::Console => {
|
||||
let weak_console = self._console.clone().downgrade();
|
||||
|
||||
Box::new(SubView::new(
|
||||
pane.focus_handle(cx),
|
||||
self._console.clone().into(),
|
||||
item_kind,
|
||||
Some(Box::new(move |cx| {
|
||||
weak_console
|
||||
.read_with(cx, |console, cx| console.show_indicator(cx))
|
||||
.unwrap_or_default()
|
||||
})),
|
||||
cx,
|
||||
))
|
||||
}
|
||||
DebuggerPaneItem::Variables => Box::new(SubView::new(
|
||||
self.variable_list.focus_handle(cx),
|
||||
self.variable_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::BreakpointList => Box::new(SubView::new(
|
||||
self.breakpoint_list.focus_handle(cx),
|
||||
self.breakpoint_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Frames => Box::new(SubView::new(
|
||||
self.stack_frame_list.focus_handle(cx),
|
||||
self.stack_frame_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::Modules => Box::new(SubView::new(
|
||||
self.module_list.focus_handle(cx),
|
||||
self.module_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
DebuggerPaneItem::LoadedSources => Box::new(SubView::new(
|
||||
self.loaded_sources_list.focus_handle(cx),
|
||||
self.loaded_sources_list.clone().into(),
|
||||
item_kind,
|
||||
None,
|
||||
cx,
|
||||
)),
|
||||
};
|
||||
let sub_view = self.create_sub_view(item_kind, pane, cx);
|
||||
|
||||
pane.update(cx, |pane, cx| {
|
||||
pane.add_item(sub_view, false, false, None, window, cx);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{tests::start_debug_session, *};
|
||||
use crate::{persistence::DebuggerPaneItem, tests::start_debug_session, *};
|
||||
use dap::{
|
||||
ErrorResponse, Message, RunInTerminalRequestArguments, SourceBreakpoint,
|
||||
StartDebuggingRequestArguments, StartDebuggingRequestArgumentsRequest,
|
||||
@@ -25,7 +25,7 @@ use std::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
},
|
||||
};
|
||||
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
|
||||
use terminal_view::terminal_panel::TerminalPanel;
|
||||
use tests::{active_debug_session_panel, init_test, init_test_workspace};
|
||||
use util::path;
|
||||
use workspace::{Item, dock::Panel};
|
||||
@@ -385,22 +385,17 @@ async fn test_handle_successful_run_in_terminal_reverse_request(
|
||||
|
||||
workspace
|
||||
.update(cx, |workspace, _window, cx| {
|
||||
let terminal_panel = workspace.panel::<TerminalPanel>(cx).unwrap();
|
||||
|
||||
let panel = terminal_panel.read(cx).pane().unwrap().read(cx);
|
||||
|
||||
assert_eq!(1, panel.items_len());
|
||||
assert!(
|
||||
panel
|
||||
.active_item()
|
||||
.unwrap()
|
||||
.downcast::<TerminalView>()
|
||||
.unwrap()
|
||||
let debug_panel = workspace.panel::<DebugPanel>(cx).unwrap();
|
||||
let session = debug_panel.read(cx).active_session().unwrap();
|
||||
let running = session.read(cx).running_state();
|
||||
assert_eq!(
|
||||
running
|
||||
.read(cx)
|
||||
.terminal()
|
||||
.read(cx)
|
||||
.debug_terminal()
|
||||
.pane_items_status(cx)
|
||||
.get(&DebuggerPaneItem::Terminal),
|
||||
Some(&true)
|
||||
);
|
||||
assert!(running.read(cx).debug_terminal.read(cx).terminal.is_some());
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use feature_flags::{Debugger, FeatureFlagAppExt as _};
|
||||
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt as _};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate};
|
||||
use gpui::{
|
||||
AnyElement, BackgroundExecutor, Entity, Focusable, FontWeight, ListSizingBehavior,
|
||||
@@ -812,7 +812,7 @@ impl CodeActionContents {
|
||||
actions: Option<Rc<[AvailableCodeAction]>>,
|
||||
cx: &App,
|
||||
) -> Self {
|
||||
if !cx.has_flag::<Debugger>() {
|
||||
if !cx.has_flag::<DebuggerFeatureFlag>() {
|
||||
if let Some(tasks) = &mut tasks {
|
||||
tasks
|
||||
.templates
|
||||
|
||||
@@ -71,25 +71,26 @@ use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layo
|
||||
pub use element::{
|
||||
CursorLayout, EditorElement, HighlightedRange, HighlightedRangeLine, PointForPosition,
|
||||
};
|
||||
use feature_flags::{Debugger, FeatureFlagAppExt};
|
||||
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt};
|
||||
use futures::{
|
||||
FutureExt,
|
||||
future::{self, Shared, join},
|
||||
};
|
||||
use fuzzy::StringMatchCandidate;
|
||||
|
||||
use ::git::Restore;
|
||||
use ::git::blame::BlameEntry;
|
||||
use ::git::{Restore, blame::ParsedCommitMessage};
|
||||
use code_context_menus::{
|
||||
AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu,
|
||||
CompletionsMenu, ContextMenuOrigin,
|
||||
};
|
||||
use git::blame::{GitBlame, GlobalBlameRenderer};
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt, AnyElement, AnyWeakEntity, App, AppContext,
|
||||
AsyncWindowContext, AvailableSpace, Background, Bounds, ClickEvent, ClipboardEntry,
|
||||
ClipboardItem, Context, DispatchPhase, Edges, Entity, EntityInputHandler, EventEmitter,
|
||||
FocusHandle, FocusOutEvent, Focusable, FontId, FontWeight, Global, HighlightStyle, Hsla,
|
||||
KeyContext, Modifiers, MouseButton, MouseDownEvent, PaintQuad, ParentElement, Pixels, Render,
|
||||
Action, Animation, AnimationExt, AnyElement, App, AppContext, AsyncWindowContext,
|
||||
AvailableSpace, Background, Bounds, ClickEvent, ClipboardEntry, ClipboardItem, Context,
|
||||
DispatchPhase, Edges, Entity, EntityInputHandler, EventEmitter, FocusHandle, FocusOutEvent,
|
||||
Focusable, FontId, FontWeight, Global, HighlightStyle, Hsla, KeyContext, Modifiers,
|
||||
MouseButton, MouseDownEvent, PaintQuad, ParentElement, Pixels, Render, ScrollHandle,
|
||||
SharedString, Size, Stateful, Styled, Subscription, Task, TextStyle, TextStyleRefinement,
|
||||
UTF16Selection, UnderlineStyle, UniformListScrollHandle, WeakEntity, WeakFocusHandle, Window,
|
||||
div, impl_actions, point, prelude::*, pulsating_between, px, relative, size,
|
||||
@@ -117,6 +118,7 @@ use language::{
|
||||
};
|
||||
use language::{BufferRow, CharClassifier, Runnable, RunnableRange, point_to_lsp};
|
||||
use linked_editing_ranges::refresh_linked_ranges;
|
||||
use markdown::Markdown;
|
||||
use mouse_context_menu::MouseContextMenu;
|
||||
use persistence::DB;
|
||||
use project::{
|
||||
@@ -798,6 +800,21 @@ impl ChangeList {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct InlineBlamePopoverState {
|
||||
scroll_handle: ScrollHandle,
|
||||
commit_message: Option<ParsedCommitMessage>,
|
||||
markdown: Entity<Markdown>,
|
||||
}
|
||||
|
||||
struct InlineBlamePopover {
|
||||
position: gpui::Point<Pixels>,
|
||||
show_task: Option<Task<()>>,
|
||||
hide_task: Option<Task<()>>,
|
||||
popover_bounds: Option<Bounds<Pixels>>,
|
||||
popover_state: InlineBlamePopoverState,
|
||||
}
|
||||
|
||||
/// Zed's primary implementation of text input, allowing users to edit a [`MultiBuffer`].
|
||||
///
|
||||
/// See the [module level documentation](self) for more information.
|
||||
@@ -866,6 +883,7 @@ pub struct Editor {
|
||||
context_menu_options: Option<ContextMenuOptions>,
|
||||
mouse_context_menu: Option<MouseContextMenu>,
|
||||
completion_tasks: Vec<(CompletionId, Task<Option<()>>)>,
|
||||
inline_blame_popover: Option<InlineBlamePopover>,
|
||||
signature_help_state: SignatureHelpState,
|
||||
auto_signature_help: Option<bool>,
|
||||
find_all_references_task_sources: Vec<Anchor>,
|
||||
@@ -922,7 +940,6 @@ pub struct Editor {
|
||||
show_git_blame_gutter: bool,
|
||||
show_git_blame_inline: bool,
|
||||
show_git_blame_inline_delay_task: Option<Task<()>>,
|
||||
pub git_blame_inline_tooltip: Option<AnyWeakEntity>,
|
||||
git_blame_inline_enabled: bool,
|
||||
render_diff_hunk_controls: RenderDiffHunkControlsFn,
|
||||
serialize_dirty_buffers: bool,
|
||||
@@ -1665,6 +1682,7 @@ impl Editor {
|
||||
context_menu_options: None,
|
||||
mouse_context_menu: None,
|
||||
completion_tasks: Default::default(),
|
||||
inline_blame_popover: Default::default(),
|
||||
signature_help_state: SignatureHelpState::default(),
|
||||
auto_signature_help: None,
|
||||
find_all_references_task_sources: Vec::new(),
|
||||
@@ -1730,7 +1748,6 @@ impl Editor {
|
||||
show_git_blame_inline: false,
|
||||
show_selection_menu: None,
|
||||
show_git_blame_inline_delay_task: None,
|
||||
git_blame_inline_tooltip: None,
|
||||
git_blame_inline_enabled: ProjectSettings::get_global(cx).git.inline_blame_enabled(),
|
||||
render_diff_hunk_controls: Arc::new(render_diff_hunk_controls),
|
||||
serialize_dirty_buffers: ProjectSettings::get_global(cx)
|
||||
@@ -1806,6 +1823,7 @@ impl Editor {
|
||||
);
|
||||
});
|
||||
editor.hide_signature_help(cx, SignatureHelpHiddenBy::Escape);
|
||||
editor.inline_blame_popover.take();
|
||||
}
|
||||
}
|
||||
EditorEvent::Edited { .. } => {
|
||||
@@ -2603,6 +2621,7 @@ impl Editor {
|
||||
self.update_visible_inline_completion(window, cx);
|
||||
self.edit_prediction_requires_modifier_in_indent_conflict = true;
|
||||
linked_editing_ranges::refresh_linked_ranges(self, window, cx);
|
||||
self.inline_blame_popover.take();
|
||||
if self.git_blame_inline_enabled {
|
||||
self.start_inline_blame_timer(window, cx);
|
||||
}
|
||||
@@ -5140,7 +5159,7 @@ impl Editor {
|
||||
Self::build_tasks_context(&project, &buffer, buffer_row, tasks, cx)
|
||||
});
|
||||
|
||||
let debugger_flag = cx.has_flag::<Debugger>();
|
||||
let debugger_flag = cx.has_flag::<DebuggerFeatureFlag>();
|
||||
|
||||
Some(cx.spawn_in(window, async move |editor, cx| {
|
||||
let task_context = match task_context {
|
||||
@@ -5483,6 +5502,82 @@ impl Editor {
|
||||
}
|
||||
}
|
||||
|
||||
fn show_blame_popover(
|
||||
&mut self,
|
||||
blame_entry: &BlameEntry,
|
||||
position: gpui::Point<Pixels>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(state) = &mut self.inline_blame_popover {
|
||||
state.hide_task.take();
|
||||
cx.notify();
|
||||
} else {
|
||||
let delay = EditorSettings::get_global(cx).hover_popover_delay;
|
||||
let show_task = cx.spawn(async move |editor, cx| {
|
||||
cx.background_executor()
|
||||
.timer(std::time::Duration::from_millis(delay))
|
||||
.await;
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
if let Some(state) = &mut editor.inline_blame_popover {
|
||||
state.show_task = None;
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
let Some(blame) = self.blame.as_ref() else {
|
||||
return;
|
||||
};
|
||||
let blame = blame.read(cx);
|
||||
let details = blame.details_for_entry(&blame_entry);
|
||||
let markdown = cx.new(|cx| {
|
||||
Markdown::new(
|
||||
details
|
||||
.as_ref()
|
||||
.map(|message| message.message.clone())
|
||||
.unwrap_or_default(),
|
||||
None,
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
self.inline_blame_popover = Some(InlineBlamePopover {
|
||||
position,
|
||||
show_task: Some(show_task),
|
||||
hide_task: None,
|
||||
popover_bounds: None,
|
||||
popover_state: InlineBlamePopoverState {
|
||||
scroll_handle: ScrollHandle::new(),
|
||||
commit_message: details,
|
||||
markdown,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn hide_blame_popover(&mut self, cx: &mut Context<Self>) {
|
||||
if let Some(state) = &mut self.inline_blame_popover {
|
||||
if state.show_task.is_some() {
|
||||
self.inline_blame_popover.take();
|
||||
cx.notify();
|
||||
} else {
|
||||
let hide_task = cx.spawn(async move |editor, cx| {
|
||||
cx.background_executor()
|
||||
.timer(std::time::Duration::from_millis(100))
|
||||
.await;
|
||||
editor
|
||||
.update(cx, |editor, cx| {
|
||||
editor.inline_blame_popover.take();
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
state.hide_task = Some(hide_task);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn refresh_document_highlights(&mut self, cx: &mut Context<Self>) -> Option<()> {
|
||||
if self.pending_rename.is_some() {
|
||||
return None;
|
||||
@@ -9055,7 +9150,7 @@ impl Editor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if !cx.has_flag::<Debugger>() {
|
||||
if !cx.has_flag::<DebuggerFeatureFlag>() {
|
||||
return;
|
||||
}
|
||||
let source = self
|
||||
@@ -16657,12 +16752,7 @@ impl Editor {
|
||||
|
||||
pub fn render_git_blame_inline(&self, window: &Window, cx: &App) -> bool {
|
||||
self.show_git_blame_inline
|
||||
&& (self.focus_handle.is_focused(window)
|
||||
|| self
|
||||
.git_blame_inline_tooltip
|
||||
.as_ref()
|
||||
.and_then(|t| t.upgrade())
|
||||
.is_some())
|
||||
&& (self.focus_handle.is_focused(window) || self.inline_blame_popover.is_some())
|
||||
&& !self.newest_selection_head_on_empty_line(cx)
|
||||
&& self.has_blame_entries(cx)
|
||||
}
|
||||
|
||||
@@ -30,17 +30,21 @@ use crate::{
|
||||
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
|
||||
use client::ParticipantIndex;
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use feature_flags::{Debugger, FeatureFlagAppExt};
|
||||
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt};
|
||||
use file_icons::FileIcons;
|
||||
use git::{Oid, blame::BlameEntry, status::FileStatus};
|
||||
use git::{
|
||||
Oid,
|
||||
blame::{BlameEntry, ParsedCommitMessage},
|
||||
status::FileStatus,
|
||||
};
|
||||
use gpui::{
|
||||
Action, Along, AnyElement, App, AvailableSpace, Axis as ScrollbarAxis, BorderStyle, Bounds,
|
||||
ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges, Element,
|
||||
ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, Hsla,
|
||||
Action, Along, AnyElement, App, AppContext, AvailableSpace, Axis as ScrollbarAxis, BorderStyle,
|
||||
Bounds, ClickEvent, ContentMask, Context, Corner, Corners, CursorStyle, DispatchPhase, Edges,
|
||||
Element, ElementInputHandler, Entity, Focusable as _, FontId, GlobalElementId, Hitbox, Hsla,
|
||||
InteractiveElement, IntoElement, Keystroke, Length, ModifiersChangedEvent, MouseButton,
|
||||
MouseDownEvent, MouseMoveEvent, MouseUpEvent, PaintQuad, ParentElement, Pixels, ScrollDelta,
|
||||
ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement, Style, Styled,
|
||||
TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
|
||||
ScrollHandle, ScrollWheelEvent, ShapedLine, SharedString, Size, StatefulInteractiveElement,
|
||||
Style, Styled, TextRun, TextStyleRefinement, WeakEntity, Window, anchored, deferred, div, fill,
|
||||
linear_color_stop, linear_gradient, outline, point, px, quad, relative, size, solid_background,
|
||||
transparent_black,
|
||||
};
|
||||
@@ -49,6 +53,7 @@ use language::language_settings::{
|
||||
IndentGuideBackgroundColoring, IndentGuideColoring, IndentGuideSettings, ShowWhitespaceSetting,
|
||||
};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use markdown::Markdown;
|
||||
use multi_buffer::{
|
||||
Anchor, ExcerptId, ExcerptInfo, ExpandExcerptDirection, ExpandInfo, MultiBufferPoint,
|
||||
MultiBufferRow, RowInfo,
|
||||
@@ -542,7 +547,7 @@ impl EditorElement {
|
||||
register_action(editor, window, Editor::insert_uuid_v4);
|
||||
register_action(editor, window, Editor::insert_uuid_v7);
|
||||
register_action(editor, window, Editor::open_selections_in_multibuffer);
|
||||
if cx.has_flag::<Debugger>() {
|
||||
if cx.has_flag::<DebuggerFeatureFlag>() {
|
||||
register_action(editor, window, Editor::toggle_breakpoint);
|
||||
register_action(editor, window, Editor::edit_log_breakpoint);
|
||||
register_action(editor, window, Editor::enable_breakpoint);
|
||||
@@ -1749,6 +1754,7 @@ impl EditorElement {
|
||||
content_origin: gpui::Point<Pixels>,
|
||||
scroll_pixel_position: gpui::Point<Pixels>,
|
||||
line_height: Pixels,
|
||||
text_hitbox: &Hitbox,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
@@ -1780,21 +1786,13 @@ impl EditorElement {
|
||||
padding * em_width
|
||||
};
|
||||
|
||||
let workspace = editor.workspace()?.downgrade();
|
||||
let blame_entry = blame
|
||||
.update(cx, |blame, cx| {
|
||||
blame.blame_for_rows(&[*row_info], cx).next()
|
||||
})
|
||||
.flatten()?;
|
||||
|
||||
let mut element = render_inline_blame_entry(
|
||||
self.editor.clone(),
|
||||
workspace,
|
||||
&blame,
|
||||
blame_entry,
|
||||
&self.style,
|
||||
cx,
|
||||
)?;
|
||||
let mut element = render_inline_blame_entry(blame_entry.clone(), &self.style, cx)?;
|
||||
|
||||
let start_y = content_origin.y
|
||||
+ line_height * (display_row.as_f32() - scroll_pixel_position.y / line_height);
|
||||
@@ -1820,11 +1818,122 @@ impl EditorElement {
|
||||
};
|
||||
|
||||
let absolute_offset = point(start_x, start_y);
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
let bounds = Bounds::new(absolute_offset, size);
|
||||
|
||||
self.layout_blame_entry_popover(
|
||||
bounds,
|
||||
blame_entry,
|
||||
blame,
|
||||
line_height,
|
||||
text_hitbox,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
element.prepaint_as_root(absolute_offset, AvailableSpace::min_size(), window, cx);
|
||||
|
||||
Some(element)
|
||||
}
|
||||
|
||||
fn layout_blame_entry_popover(
|
||||
&self,
|
||||
parent_bounds: Bounds<Pixels>,
|
||||
blame_entry: BlameEntry,
|
||||
blame: Entity<GitBlame>,
|
||||
line_height: Pixels,
|
||||
text_hitbox: &Hitbox,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let mouse_position = window.mouse_position();
|
||||
let mouse_over_inline_blame = parent_bounds.contains(&mouse_position);
|
||||
let mouse_over_popover = self.editor.update(cx, |editor, _| {
|
||||
editor
|
||||
.inline_blame_popover
|
||||
.as_ref()
|
||||
.and_then(|state| state.popover_bounds)
|
||||
.map_or(false, |bounds| bounds.contains(&mouse_position))
|
||||
});
|
||||
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
if mouse_over_inline_blame || mouse_over_popover {
|
||||
editor.show_blame_popover(&blame_entry, mouse_position, cx);
|
||||
} else {
|
||||
editor.hide_blame_popover(cx);
|
||||
}
|
||||
});
|
||||
|
||||
let should_draw = self.editor.update(cx, |editor, _| {
|
||||
editor
|
||||
.inline_blame_popover
|
||||
.as_ref()
|
||||
.map_or(false, |state| state.show_task.is_none())
|
||||
});
|
||||
|
||||
if should_draw {
|
||||
let maybe_element = self.editor.update(cx, |editor, cx| {
|
||||
editor
|
||||
.workspace()
|
||||
.map(|workspace| workspace.downgrade())
|
||||
.zip(
|
||||
editor
|
||||
.inline_blame_popover
|
||||
.as_ref()
|
||||
.map(|p| p.popover_state.clone()),
|
||||
)
|
||||
.and_then(|(workspace, popover_state)| {
|
||||
render_blame_entry_popover(
|
||||
blame_entry,
|
||||
popover_state.scroll_handle,
|
||||
popover_state.commit_message,
|
||||
popover_state.markdown,
|
||||
workspace,
|
||||
&blame,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(mut element) = maybe_element {
|
||||
let size = element.layout_as_root(AvailableSpace::min_size(), window, cx);
|
||||
let origin = self.editor.update(cx, |editor, _| {
|
||||
let target_point = editor
|
||||
.inline_blame_popover
|
||||
.as_ref()
|
||||
.map_or(mouse_position, |state| state.position);
|
||||
|
||||
let overall_height = size.height + HOVER_POPOVER_GAP;
|
||||
let popover_origin = if target_point.y > overall_height {
|
||||
point(target_point.x, target_point.y - size.height)
|
||||
} else {
|
||||
point(
|
||||
target_point.x,
|
||||
target_point.y + line_height + HOVER_POPOVER_GAP,
|
||||
)
|
||||
};
|
||||
|
||||
let horizontal_offset = (text_hitbox.top_right().x
|
||||
- POPOVER_RIGHT_OFFSET
|
||||
- (popover_origin.x + size.width))
|
||||
.min(Pixels::ZERO);
|
||||
|
||||
point(popover_origin.x + horizontal_offset, popover_origin.y)
|
||||
});
|
||||
|
||||
let popover_bounds = Bounds::new(origin, size);
|
||||
self.editor.update(cx, |editor, _| {
|
||||
if let Some(state) = &mut editor.inline_blame_popover {
|
||||
state.popover_bounds = Some(popover_bounds);
|
||||
}
|
||||
});
|
||||
|
||||
window.defer_draw(element, origin, 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn layout_blame_entries(
|
||||
&self,
|
||||
buffer_rows: &[RowInfo],
|
||||
@@ -5851,24 +5960,35 @@ fn prepaint_gutter_button(
|
||||
}
|
||||
|
||||
fn render_inline_blame_entry(
|
||||
editor: Entity<Editor>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
blame: &Entity<GitBlame>,
|
||||
blame_entry: BlameEntry,
|
||||
style: &EditorStyle,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
let renderer = cx.global::<GlobalBlameRenderer>().0.clone();
|
||||
renderer.render_inline_blame_entry(&style.text, blame_entry, cx)
|
||||
}
|
||||
|
||||
fn render_blame_entry_popover(
|
||||
blame_entry: BlameEntry,
|
||||
scroll_handle: ScrollHandle,
|
||||
commit_message: Option<ParsedCommitMessage>,
|
||||
markdown: Entity<Markdown>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
blame: &Entity<GitBlame>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
let renderer = cx.global::<GlobalBlameRenderer>().0.clone();
|
||||
let blame = blame.read(cx);
|
||||
let details = blame.details_for_entry(&blame_entry);
|
||||
let repository = blame.repository(cx)?.clone();
|
||||
renderer.render_inline_blame_entry(
|
||||
&style.text,
|
||||
renderer.render_blame_entry_popover(
|
||||
blame_entry,
|
||||
details,
|
||||
scroll_handle,
|
||||
commit_message,
|
||||
markdown,
|
||||
repository,
|
||||
workspace,
|
||||
editor,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
@@ -6917,7 +7037,7 @@ impl Element for EditorElement {
|
||||
let mut breakpoint_rows = self.editor.update(cx, |editor, cx| {
|
||||
editor.active_breakpoints(start_row..end_row, window, cx)
|
||||
});
|
||||
if cx.has_flag::<Debugger>() {
|
||||
if cx.has_flag::<DebuggerFeatureFlag>() {
|
||||
for display_row in breakpoint_rows.keys() {
|
||||
active_rows.entry(*display_row).or_default().breakpoint = true;
|
||||
}
|
||||
@@ -6940,7 +7060,7 @@ impl Element for EditorElement {
|
||||
// We add the gutter breakpoint indicator to breakpoint_rows after painting
|
||||
// line numbers so we don't paint a line number debug accent color if a user
|
||||
// has their mouse over that line when a breakpoint isn't there
|
||||
if cx.has_flag::<Debugger>() {
|
||||
if cx.has_flag::<DebuggerFeatureFlag>() {
|
||||
let gutter_breakpoint_indicator =
|
||||
self.editor.read(cx).gutter_breakpoint_indicator.0;
|
||||
if let Some((gutter_breakpoint_point, _)) =
|
||||
@@ -7046,14 +7166,7 @@ impl Element for EditorElement {
|
||||
blame.blame_for_rows(&[row_infos], cx).next()
|
||||
})
|
||||
.flatten()?;
|
||||
let mut element = render_inline_blame_entry(
|
||||
self.editor.clone(),
|
||||
editor.workspace()?.downgrade(),
|
||||
blame,
|
||||
blame_entry,
|
||||
&style,
|
||||
cx,
|
||||
)?;
|
||||
let mut element = render_inline_blame_entry(blame_entry, &style, cx)?;
|
||||
let inline_blame_padding = INLINE_BLAME_PADDING_EM_WIDTHS * em_advance;
|
||||
Some(
|
||||
element
|
||||
@@ -7262,6 +7375,7 @@ impl Element for EditorElement {
|
||||
content_origin,
|
||||
scroll_pixel_position,
|
||||
line_height,
|
||||
&text_hitbox,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -7462,7 +7576,7 @@ impl Element for EditorElement {
|
||||
let show_breakpoints = snapshot
|
||||
.show_breakpoints
|
||||
.unwrap_or(gutter_settings.breakpoints);
|
||||
let breakpoints = if cx.has_flag::<Debugger>() && show_breakpoints {
|
||||
let breakpoints = if cx.has_flag::<DebuggerFeatureFlag>() && show_breakpoints {
|
||||
self.layout_breakpoints(
|
||||
line_height,
|
||||
start_row..end_row,
|
||||
|
||||
@@ -7,10 +7,11 @@ use git::{
|
||||
parse_git_remote_url,
|
||||
};
|
||||
use gpui::{
|
||||
AnyElement, App, AppContext as _, Context, Entity, Hsla, Subscription, Task, TextStyle,
|
||||
WeakEntity, Window,
|
||||
AnyElement, App, AppContext as _, Context, Entity, Hsla, ScrollHandle, Subscription, Task,
|
||||
TextStyle, WeakEntity, Window,
|
||||
};
|
||||
use language::{Bias, Buffer, BufferSnapshot, Edit};
|
||||
use markdown::Markdown;
|
||||
use multi_buffer::RowInfo;
|
||||
use project::{
|
||||
Project, ProjectItem,
|
||||
@@ -98,10 +99,18 @@ pub trait BlameRenderer {
|
||||
&self,
|
||||
_: &TextStyle,
|
||||
_: BlameEntry,
|
||||
_: &mut App,
|
||||
) -> Option<AnyElement>;
|
||||
|
||||
fn render_blame_entry_popover(
|
||||
&self,
|
||||
_: BlameEntry,
|
||||
_: ScrollHandle,
|
||||
_: Option<ParsedCommitMessage>,
|
||||
_: Entity<Markdown>,
|
||||
_: Entity<Repository>,
|
||||
_: WeakEntity<Workspace>,
|
||||
_: Entity<Editor>,
|
||||
_: &mut Window,
|
||||
_: &mut App,
|
||||
) -> Option<AnyElement>;
|
||||
|
||||
@@ -139,10 +148,20 @@ impl BlameRenderer for () {
|
||||
&self,
|
||||
_: &TextStyle,
|
||||
_: BlameEntry,
|
||||
_: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
None
|
||||
}
|
||||
|
||||
fn render_blame_entry_popover(
|
||||
&self,
|
||||
_: BlameEntry,
|
||||
_: ScrollHandle,
|
||||
_: Option<ParsedCommitMessage>,
|
||||
_: Entity<Markdown>,
|
||||
_: Entity<Repository>,
|
||||
_: WeakEntity<Workspace>,
|
||||
_: Entity<Editor>,
|
||||
_: &mut Window,
|
||||
_: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
None
|
||||
|
||||
@@ -957,6 +957,7 @@ impl Item for Editor {
|
||||
cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| {
|
||||
if matches!(event, workspace::Event::ModalOpened) {
|
||||
editor.mouse_context_menu.take();
|
||||
editor.inline_blame_popover.take();
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
@@ -10,7 +10,7 @@ use crate::{
|
||||
ToolMetrics,
|
||||
assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
|
||||
};
|
||||
use agent::ThreadEvent;
|
||||
use agent::{ContextLoadResult, ThreadEvent};
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
@@ -115,7 +115,12 @@ impl ExampleContext {
|
||||
pub fn push_user_message(&mut self, text: impl ToString) {
|
||||
self.app
|
||||
.update_entity(&self.agent_thread, |thread, cx| {
|
||||
thread.insert_user_message(text.to_string(), vec![], None, cx);
|
||||
thread.insert_user_message(
|
||||
text.to_string(),
|
||||
ContextLoadResult::default(),
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
@@ -253,6 +258,9 @@ impl ExampleContext {
|
||||
}
|
||||
});
|
||||
}
|
||||
ThreadEvent::InvalidToolInput { .. } => {
|
||||
println!("{log_prefix} invalid tool input");
|
||||
}
|
||||
ThreadEvent::ToolConfirmationNeeded => {
|
||||
panic!(
|
||||
"{}Bug: Tool confirmation should not be required in eval",
|
||||
|
||||
@@ -13,13 +13,11 @@ use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
|
||||
|
||||
mod add_arg_to_trait_method;
|
||||
mod file_search;
|
||||
mod no_empty_messages;
|
||||
|
||||
pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
|
||||
let mut threads: Vec<Rc<dyn Example>> = vec![
|
||||
Rc::new(file_search::FileSearchExample),
|
||||
Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
|
||||
Rc::new(no_empty_messages::NoEmptyMessagesExample),
|
||||
];
|
||||
|
||||
for example_path in list_declarative_examples(examples_dir).unwrap() {
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use assistant_tools::{ListDirectoryToolInput, ReadFileToolInput};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::example::{Example, ExampleContext, ExampleMetadata};
|
||||
|
||||
pub struct NoEmptyMessagesExample;
|
||||
|
||||
#[async_trait(?Send)]
|
||||
impl Example for NoEmptyMessagesExample {
|
||||
fn meta(&self) -> ExampleMetadata {
|
||||
ExampleMetadata {
|
||||
name: "no_empty_messages".to_string(),
|
||||
url: "https://github.com/zed-industries/zed.git".to_string(),
|
||||
revision: "fcfeea4825c563715bcd1a1af809d88a37d12ccb".to_string(),
|
||||
language_server: None,
|
||||
max_assertions: Some(3),
|
||||
}
|
||||
}
|
||||
|
||||
async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
|
||||
cx.push_user_message(format!(
|
||||
r#"
|
||||
Summarize all the files in crates/assistant/src.
|
||||
Do tool calls only and do NOT include a description of what you are doing.
|
||||
Just give me an explanation of what you did after you finished all tool calls
|
||||
"#
|
||||
));
|
||||
|
||||
let response = cx.run_turn().await?;
|
||||
dbg!(response);
|
||||
// let tool_use = response.expect_tool("list_directory", cx)?;
|
||||
// let input = tool_use.parse_input::<ListDirectoryToolInput>()?;
|
||||
// cx.assert(
|
||||
// &input.path == "zed/crates/assistant/src",
|
||||
// "Path matches directory",
|
||||
// )?;
|
||||
|
||||
for file in [
|
||||
"assistant.rs",
|
||||
"assistant_configuration.rs",
|
||||
"assistant_panel.rs",
|
||||
"inline_assistant.rs",
|
||||
"slash_command_settings.rs",
|
||||
"terminal_inline_assistant.rs",
|
||||
] {
|
||||
let response = cx.run_turn().await?;
|
||||
dbg!(response);
|
||||
// let tool_use = response.expect_tool("read_file", cx)?;
|
||||
// let input = tool_use.parse_input::<ReadFileToolInput>()?;
|
||||
// dbg!(&input);
|
||||
// cx.assert(
|
||||
// input.path == format!("zed/crates/assistant/src/{file}"),
|
||||
// format!("Path {} matches file {file}", input.path),
|
||||
// )?;
|
||||
}
|
||||
|
||||
cx.run_to_end().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -218,8 +218,14 @@ impl ExampleInstance {
|
||||
});
|
||||
|
||||
let tools = cx.new(|_| ToolWorkingSet::default());
|
||||
let thread_store =
|
||||
ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
|
||||
let prompt_store = None;
|
||||
let thread_store = ThreadStore::load(
|
||||
project.clone(),
|
||||
tools,
|
||||
prompt_store,
|
||||
app_state.prompt_builder.clone(),
|
||||
cx,
|
||||
);
|
||||
let meta = self.thread.meta();
|
||||
let this = self.clone();
|
||||
|
||||
|
||||
@@ -64,23 +64,18 @@ impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag {
|
||||
const NAME: &'static str = "predict-edits-rate-completions";
|
||||
}
|
||||
|
||||
pub struct Remoting {}
|
||||
impl FeatureFlag for Remoting {
|
||||
const NAME: &'static str = "remoting";
|
||||
}
|
||||
|
||||
pub struct LanguageModels {}
|
||||
impl FeatureFlag for LanguageModels {
|
||||
pub struct LanguageModelsFeatureFlag {}
|
||||
impl FeatureFlag for LanguageModelsFeatureFlag {
|
||||
const NAME: &'static str = "language-models";
|
||||
}
|
||||
|
||||
pub struct LlmClosedBeta {}
|
||||
impl FeatureFlag for LlmClosedBeta {
|
||||
pub struct LlmClosedBetaFeatureFlag {}
|
||||
impl FeatureFlag for LlmClosedBetaFeatureFlag {
|
||||
const NAME: &'static str = "llm-closed-beta";
|
||||
}
|
||||
|
||||
pub struct ZedPro {}
|
||||
impl FeatureFlag for ZedPro {
|
||||
pub struct ZedProFeatureFlag {}
|
||||
impl FeatureFlag for ZedProFeatureFlag {
|
||||
const NAME: &'static str = "zed-pro";
|
||||
}
|
||||
|
||||
@@ -90,13 +85,13 @@ impl FeatureFlag for NotebookFeatureFlag {
|
||||
const NAME: &'static str = "notebooks";
|
||||
}
|
||||
|
||||
pub struct Debugger {}
|
||||
impl FeatureFlag for Debugger {
|
||||
pub struct DebuggerFeatureFlag {}
|
||||
impl FeatureFlag for DebuggerFeatureFlag {
|
||||
const NAME: &'static str = "debugger";
|
||||
}
|
||||
|
||||
pub struct ThreadAutoCapture {}
|
||||
impl FeatureFlag for ThreadAutoCapture {
|
||||
pub struct ThreadAutoCaptureFeatureFlag {}
|
||||
impl FeatureFlag for ThreadAutoCaptureFeatureFlag {
|
||||
const NAME: &'static str = "thread-auto-capture";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
use crate::{commit_tooltip::CommitTooltip, commit_view::CommitView};
|
||||
use editor::{BlameRenderer, Editor};
|
||||
use crate::{
|
||||
commit_tooltip::{CommitAvatar, CommitDetails, CommitTooltip},
|
||||
commit_view::CommitView,
|
||||
};
|
||||
use editor::{BlameRenderer, Editor, hover_markdown_style};
|
||||
use git::{
|
||||
blame::{BlameEntry, ParsedCommitMessage},
|
||||
repository::CommitSummary,
|
||||
};
|
||||
use gpui::{
|
||||
AnyElement, App, AppContext as _, ClipboardItem, Element as _, Entity, Hsla,
|
||||
InteractiveElement as _, MouseButton, Pixels, StatefulInteractiveElement as _, Styled as _,
|
||||
Subscription, TextStyle, WeakEntity, Window, div,
|
||||
ClipboardItem, Entity, Hsla, MouseButton, ScrollHandle, Subscription, TextStyle, WeakEntity,
|
||||
prelude::*,
|
||||
};
|
||||
use markdown::{Markdown, MarkdownElement};
|
||||
use project::{git_store::Repository, project_settings::ProjectSettings};
|
||||
use settings::Settings as _;
|
||||
use ui::{
|
||||
ActiveTheme, Color, ContextMenu, FluentBuilder as _, Icon, IconName, ParentElement as _, h_flex,
|
||||
};
|
||||
use theme::ThemeSettings;
|
||||
use time::OffsetDateTime;
|
||||
use time_format::format_local_timestamp;
|
||||
use ui::{ContextMenu, Divider, IconButtonShape, prelude::*};
|
||||
use workspace::Workspace;
|
||||
|
||||
const GIT_BLAME_MAX_AUTHOR_CHARS_DISPLAYED: usize = 20;
|
||||
@@ -115,10 +119,6 @@ impl BlameRenderer for GitBlameRenderer {
|
||||
&self,
|
||||
style: &TextStyle,
|
||||
blame_entry: BlameEntry,
|
||||
details: Option<ParsedCommitMessage>,
|
||||
repository: Entity<Repository>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
editor: Entity<Editor>,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
let relative_timestamp = blame_entry_relative_timestamp(&blame_entry);
|
||||
@@ -144,25 +144,223 @@ impl BlameRenderer for GitBlameRenderer {
|
||||
.child(Icon::new(IconName::FileGit).color(Color::Hint))
|
||||
.child(text)
|
||||
.gap_2()
|
||||
.hoverable_tooltip(move |_window, cx| {
|
||||
let tooltip = cx.new(|cx| {
|
||||
CommitTooltip::blame_entry(
|
||||
&blame_entry,
|
||||
details.clone(),
|
||||
repository.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
editor.update(cx, |editor, _| {
|
||||
editor.git_blame_inline_tooltip = Some(tooltip.downgrade().into())
|
||||
});
|
||||
tooltip.into()
|
||||
})
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_blame_entry_popover(
|
||||
&self,
|
||||
blame: BlameEntry,
|
||||
scroll_handle: ScrollHandle,
|
||||
details: Option<ParsedCommitMessage>,
|
||||
markdown: Entity<Markdown>,
|
||||
repository: Entity<Repository>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Option<AnyElement> {
|
||||
let commit_time = blame
|
||||
.committer_time
|
||||
.and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok())
|
||||
.unwrap_or(OffsetDateTime::now_utc());
|
||||
|
||||
let commit_details = CommitDetails {
|
||||
sha: blame.sha.to_string().into(),
|
||||
commit_time,
|
||||
author_name: blame
|
||||
.author
|
||||
.clone()
|
||||
.unwrap_or("<no name>".to_string())
|
||||
.into(),
|
||||
author_email: blame.author_mail.clone().unwrap_or("".to_string()).into(),
|
||||
message: details,
|
||||
};
|
||||
|
||||
let avatar = CommitAvatar::new(&commit_details).render(window, cx);
|
||||
|
||||
let author = commit_details.author_name.clone();
|
||||
let author_email = commit_details.author_email.clone();
|
||||
|
||||
let short_commit_id = commit_details
|
||||
.sha
|
||||
.get(0..8)
|
||||
.map(|sha| sha.to_string().into())
|
||||
.unwrap_or_else(|| commit_details.sha.clone());
|
||||
let full_sha = commit_details.sha.to_string().clone();
|
||||
let absolute_timestamp = format_local_timestamp(
|
||||
commit_details.commit_time,
|
||||
OffsetDateTime::now_utc(),
|
||||
time_format::TimestampFormat::MediumAbsolute,
|
||||
);
|
||||
let markdown_style = {
|
||||
let mut style = hover_markdown_style(window, cx);
|
||||
if let Some(code_block) = &style.code_block.text {
|
||||
style.base_text_style.refine(code_block);
|
||||
}
|
||||
style
|
||||
};
|
||||
|
||||
let message = commit_details
|
||||
.message
|
||||
.as_ref()
|
||||
.map(|_| MarkdownElement::new(markdown.clone(), markdown_style).into_any())
|
||||
.unwrap_or("<no commit message>".into_any());
|
||||
|
||||
let pull_request = commit_details
|
||||
.message
|
||||
.as_ref()
|
||||
.and_then(|details| details.pull_request.clone());
|
||||
|
||||
let ui_font_size = ThemeSettings::get_global(cx).ui_font_size(cx);
|
||||
let message_max_height = window.line_height() * 12 + (ui_font_size / 0.4);
|
||||
let commit_summary = CommitSummary {
|
||||
sha: commit_details.sha.clone(),
|
||||
subject: commit_details
|
||||
.message
|
||||
.as_ref()
|
||||
.map_or(Default::default(), |message| {
|
||||
message
|
||||
.message
|
||||
.split('\n')
|
||||
.next()
|
||||
.unwrap()
|
||||
.trim_end()
|
||||
.to_string()
|
||||
.into()
|
||||
}),
|
||||
commit_timestamp: commit_details.commit_time.unix_timestamp(),
|
||||
has_parent: false,
|
||||
};
|
||||
|
||||
let ui_font = ThemeSettings::get_global(cx).ui_font.clone();
|
||||
|
||||
// padding to avoid tooltip appearing right below the mouse cursor
|
||||
// TODO: use tooltip_container here
|
||||
Some(
|
||||
div()
|
||||
.pl_2()
|
||||
.pt_2p5()
|
||||
.child(
|
||||
v_flex()
|
||||
.elevation_2(cx)
|
||||
.font(ui_font)
|
||||
.text_ui(cx)
|
||||
.text_color(cx.theme().colors().text)
|
||||
.py_1()
|
||||
.px_2()
|
||||
.map(|el| {
|
||||
el.occlude()
|
||||
.on_mouse_move(|_, _, cx| cx.stop_propagation())
|
||||
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
|
||||
.child(
|
||||
v_flex()
|
||||
.w(gpui::rems(30.))
|
||||
.gap_4()
|
||||
.child(
|
||||
h_flex()
|
||||
.pb_1p5()
|
||||
.gap_x_2()
|
||||
.overflow_x_hidden()
|
||||
.flex_wrap()
|
||||
.children(avatar)
|
||||
.child(author)
|
||||
.when(!author_email.is_empty(), |this| {
|
||||
this.child(
|
||||
div()
|
||||
.text_color(
|
||||
cx.theme().colors().text_muted,
|
||||
)
|
||||
.child(author_email),
|
||||
)
|
||||
})
|
||||
.border_b_1()
|
||||
.border_color(cx.theme().colors().border_variant),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.id("inline-blame-commit-message")
|
||||
.child(message)
|
||||
.max_h(message_max_height)
|
||||
.overflow_y_scroll()
|
||||
.track_scroll(&scroll_handle),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.text_color(cx.theme().colors().text_muted)
|
||||
.w_full()
|
||||
.justify_between()
|
||||
.pt_1p5()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.child(absolute_timestamp)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.when_some(pull_request, |this, pr| {
|
||||
this.child(
|
||||
Button::new(
|
||||
"pull-request-button",
|
||||
format!("#{}", pr.number),
|
||||
)
|
||||
.color(Color::Muted)
|
||||
.icon(IconName::PullRequest)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.style(ButtonStyle::Subtle)
|
||||
.on_click(move |_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
cx.open_url(pr.url.as_str())
|
||||
}),
|
||||
)
|
||||
})
|
||||
.child(Divider::vertical())
|
||||
.child(
|
||||
Button::new(
|
||||
"commit-sha-button",
|
||||
short_commit_id.clone(),
|
||||
)
|
||||
.style(ButtonStyle::Subtle)
|
||||
.color(Color::Muted)
|
||||
.icon(IconName::FileGit)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(move |_, window, cx| {
|
||||
CommitView::open(
|
||||
commit_summary.clone(),
|
||||
repository.downgrade(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
cx.stop_propagation();
|
||||
}),
|
||||
)
|
||||
.child(
|
||||
IconButton::new(
|
||||
"copy-sha-button",
|
||||
IconName::Copy,
|
||||
)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(move |_, _, cx| {
|
||||
cx.stop_propagation();
|
||||
cx.write_to_clipboard(
|
||||
ClipboardItem::new_string(
|
||||
full_sha.clone(),
|
||||
),
|
||||
)
|
||||
}),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
}),
|
||||
)
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
|
||||
fn open_blame_commit(
|
||||
&self,
|
||||
blame_entry: BlameEntry,
|
||||
|
||||
@@ -27,22 +27,18 @@ pub struct CommitDetails {
|
||||
pub message: Option<ParsedCommitMessage>,
|
||||
}
|
||||
|
||||
struct CommitAvatar<'a> {
|
||||
pub struct CommitAvatar<'a> {
|
||||
commit: &'a CommitDetails,
|
||||
}
|
||||
|
||||
impl<'a> CommitAvatar<'a> {
|
||||
fn new(details: &'a CommitDetails) -> Self {
|
||||
pub fn new(details: &'a CommitDetails) -> Self {
|
||||
Self { commit: details }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> CommitAvatar<'a> {
|
||||
fn render(
|
||||
&'a self,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<CommitTooltip>,
|
||||
) -> Option<impl IntoElement + use<>> {
|
||||
pub fn render(&'a self, window: &mut Window, cx: &mut App) -> Option<impl IntoElement + use<>> {
|
||||
let remote = self
|
||||
.commit
|
||||
.message
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
mod supported_countries;
|
||||
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
@@ -338,7 +334,6 @@ pub struct CountTokensResponse {
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct FunctionCall {
|
||||
pub name: String,
|
||||
pub raw_args: String,
|
||||
pub args: serde_json::Value,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Returns whether the given country code is supported by Google Gemini.
|
||||
///
|
||||
/// <https://ai.google.dev/gemini-api/docs/available-regions>
|
||||
pub fn is_supported_country(country_code: &str) -> bool {
|
||||
SUPPORTED_COUNTRIES.contains(&country_code)
|
||||
}
|
||||
|
||||
/// The list of country codes supported by Google Gemini.
|
||||
///
|
||||
/// https://ai.google.dev/gemini-api/docs/available-regions
|
||||
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
vec![
|
||||
"DZ", // Algeria
|
||||
"AS", // American Samoa
|
||||
"AO", // Angola
|
||||
"AI", // Anguilla
|
||||
"AQ", // Antarctica
|
||||
"AG", // Antigua and Barbuda
|
||||
"AR", // Argentina
|
||||
"AM", // Armenia
|
||||
"AW", // Aruba
|
||||
"AU", // Australia
|
||||
"AT", // Austria
|
||||
"AZ", // Azerbaijan
|
||||
"BS", // The Bahamas
|
||||
"BH", // Bahrain
|
||||
"BD", // Bangladesh
|
||||
"BB", // Barbados
|
||||
"BE", // Belgium
|
||||
"BZ", // Belize
|
||||
"BJ", // Benin
|
||||
"BM", // Bermuda
|
||||
"BT", // Bhutan
|
||||
"BO", // Bolivia
|
||||
"BW", // Botswana
|
||||
"BR", // Brazil
|
||||
"IO", // British Indian Ocean Territory
|
||||
"VG", // British Virgin Islands
|
||||
"BN", // Brunei
|
||||
"BG", // Bulgaria
|
||||
"BF", // Burkina Faso
|
||||
"BI", // Burundi
|
||||
"CV", // Cabo Verde
|
||||
"KH", // Cambodia
|
||||
"CM", // Cameroon
|
||||
"CA", // Canada
|
||||
"BQ", // Caribbean Netherlands
|
||||
"KY", // Cayman Islands
|
||||
"CF", // Central African Republic
|
||||
"TD", // Chad
|
||||
"CL", // Chile
|
||||
"CX", // Christmas Island
|
||||
"CC", // Cocos (Keeling) Islands
|
||||
"CO", // Colombia
|
||||
"KM", // Comoros
|
||||
"CK", // Cook Islands
|
||||
"CI", // Côte d'Ivoire
|
||||
"CR", // Costa Rica
|
||||
"HR", // Croatia
|
||||
"CW", // Curaçao
|
||||
"CZ", // Czech Republic
|
||||
"CD", // Democratic Republic of the Congo
|
||||
"DK", // Denmark
|
||||
"DJ", // Djibouti
|
||||
"DM", // Dominica
|
||||
"DO", // Dominican Republic
|
||||
"EC", // Ecuador
|
||||
"EG", // Egypt
|
||||
"SV", // El Salvador
|
||||
"GQ", // Equatorial Guinea
|
||||
"ER", // Eritrea
|
||||
"EE", // Estonia
|
||||
"SZ", // Eswatini
|
||||
"ET", // Ethiopia
|
||||
"FK", // Falkland Islands (Islas Malvinas)
|
||||
"FJ", // Fiji
|
||||
"FI", // Finland
|
||||
"FR", // France
|
||||
"GA", // Gabon
|
||||
"GM", // The Gambia
|
||||
"GE", // Georgia
|
||||
"DE", // Germany
|
||||
"GH", // Ghana
|
||||
"GI", // Gibraltar
|
||||
"GR", // Greece
|
||||
"GD", // Grenada
|
||||
"GU", // Guam
|
||||
"GT", // Guatemala
|
||||
"GG", // Guernsey
|
||||
"GN", // Guinea
|
||||
"GW", // Guinea-Bissau
|
||||
"GY", // Guyana
|
||||
"HT", // Haiti
|
||||
"HM", // Heard Island and McDonald Islands
|
||||
"HN", // Honduras
|
||||
"HU", // Hungary
|
||||
"IS", // Iceland
|
||||
"IN", // India
|
||||
"ID", // Indonesia
|
||||
"IQ", // Iraq
|
||||
"IE", // Ireland
|
||||
"IM", // Isle of Man
|
||||
"IL", // Israel
|
||||
"IT", // Italy
|
||||
"JM", // Jamaica
|
||||
"JP", // Japan
|
||||
"JE", // Jersey
|
||||
"JO", // Jordan
|
||||
"KZ", // Kazakhstan
|
||||
"KE", // Kenya
|
||||
"KI", // Kiribati
|
||||
"KG", // Kyrgyzstan
|
||||
"KW", // Kuwait
|
||||
"LA", // Laos
|
||||
"LV", // Latvia
|
||||
"LB", // Lebanon
|
||||
"LS", // Lesotho
|
||||
"LR", // Liberia
|
||||
"LY", // Libya
|
||||
"LI", // Liechtenstein
|
||||
"LT", // Lithuania
|
||||
"LU", // Luxembourg
|
||||
"MG", // Madagascar
|
||||
"MW", // Malawi
|
||||
"MY", // Malaysia
|
||||
"MV", // Maldives
|
||||
"ML", // Mali
|
||||
"MT", // Malta
|
||||
"MH", // Marshall Islands
|
||||
"MR", // Mauritania
|
||||
"MU", // Mauritius
|
||||
"MX", // Mexico
|
||||
"FM", // Micronesia
|
||||
"MN", // Mongolia
|
||||
"MS", // Montserrat
|
||||
"MA", // Morocco
|
||||
"MZ", // Mozambique
|
||||
"NA", // Namibia
|
||||
"NR", // Nauru
|
||||
"NP", // Nepal
|
||||
"NL", // Netherlands
|
||||
"NC", // New Caledonia
|
||||
"NZ", // New Zealand
|
||||
"NI", // Nicaragua
|
||||
"NE", // Niger
|
||||
"NG", // Nigeria
|
||||
"NU", // Niue
|
||||
"NF", // Norfolk Island
|
||||
"MP", // Northern Mariana Islands
|
||||
"NO", // Norway
|
||||
"OM", // Oman
|
||||
"PK", // Pakistan
|
||||
"PW", // Palau
|
||||
"PS", // Palestine
|
||||
"PA", // Panama
|
||||
"PG", // Papua New Guinea
|
||||
"PY", // Paraguay
|
||||
"PE", // Peru
|
||||
"PH", // Philippines
|
||||
"PN", // Pitcairn Islands
|
||||
"PL", // Poland
|
||||
"PT", // Portugal
|
||||
"PR", // Puerto Rico
|
||||
"QA", // Qatar
|
||||
"CY", // Republic of Cyprus
|
||||
"CG", // Republic of the Congo
|
||||
"RO", // Romania
|
||||
"RW", // Rwanda
|
||||
"BL", // Saint Barthélemy
|
||||
"KN", // Saint Kitts and Nevis
|
||||
"LC", // Saint Lucia
|
||||
"PM", // Saint Pierre and Miquelon
|
||||
"VC", // Saint Vincent and the Grenadines
|
||||
"SH", // Saint Helena, Ascension and Tristan da Cunha
|
||||
"WS", // Samoa
|
||||
"ST", // São Tomé and Príncipe
|
||||
"SA", // Saudi Arabia
|
||||
"SN", // Senegal
|
||||
"SC", // Seychelles
|
||||
"SL", // Sierra Leone
|
||||
"SG", // Singapore
|
||||
"SK", // Slovakia
|
||||
"SI", // Slovenia
|
||||
"SB", // Solomon Islands
|
||||
"SO", // Somalia
|
||||
"ZA", // South Africa
|
||||
"GS", // South Georgia and the South Sandwich Islands
|
||||
"KR", // South Korea
|
||||
"SS", // South Sudan
|
||||
"ES", // Spain
|
||||
"LK", // Sri Lanka
|
||||
"SD", // Sudan
|
||||
"SR", // Suriname
|
||||
"SE", // Sweden
|
||||
"CH", // Switzerland
|
||||
"TW", // Taiwan
|
||||
"TJ", // Tajikistan
|
||||
"TZ", // Tanzania
|
||||
"TH", // Thailand
|
||||
"TL", // Timor-Leste
|
||||
"TG", // Togo
|
||||
"TK", // Tokelau
|
||||
"TO", // Tonga
|
||||
"TT", // Trinidad and Tobago
|
||||
"TN", // Tunisia
|
||||
"TR", // Türkiye
|
||||
"TM", // Turkmenistan
|
||||
"TC", // Turks and Caicos Islands
|
||||
"TV", // Tuvalu
|
||||
"UG", // Uganda
|
||||
"GB", // United Kingdom
|
||||
"AE", // United Arab Emirates
|
||||
"US", // United States
|
||||
"UM", // United States Minor Outlying Islands
|
||||
"VI", // U.S. Virgin Islands
|
||||
"UY", // Uruguay
|
||||
"UZ", // Uzbekistan
|
||||
"VU", // Vanuatu
|
||||
"VE", // Venezuela
|
||||
"VN", // Vietnam
|
||||
"WF", // Wallis and Futuna
|
||||
"EH", // Western Sahara
|
||||
"YE", // Yemen
|
||||
"ZM", // Zambia
|
||||
"ZW", // Zimbabwe
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
});
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
};
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
|
||||
@@ -168,7 +168,12 @@ impl LanguageModel for FakeLanguageModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
self.current_completion_txs.lock().push((request, tx));
|
||||
async move {
|
||||
|
||||
@@ -76,6 +76,19 @@ pub enum LanguageModelCompletionEvent {
|
||||
UsageUpdate(TokenUsage),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LanguageModelCompletionError {
|
||||
#[error("received bad input JSON")]
|
||||
BadInputJson {
|
||||
id: LanguageModelToolUseId,
|
||||
tool_name: Arc<str>,
|
||||
raw_input: Arc<str>,
|
||||
json_parse_error: String,
|
||||
},
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Indicates the format used to define the input schema for a language model tool.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
|
||||
pub enum LanguageModelToolSchemaFormat {
|
||||
@@ -193,7 +206,7 @@ pub struct LanguageModelToolUse {
|
||||
|
||||
pub struct LanguageModelTextStream {
|
||||
pub message_id: Option<String>,
|
||||
pub stream: BoxStream<'static, Result<String>>,
|
||||
pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
|
||||
// Has complete token usage after the stream has finished
|
||||
pub last_token_usage: Arc<Mutex<TokenUsage>>,
|
||||
}
|
||||
@@ -246,7 +259,12 @@ pub trait LanguageModel: Send + Sync {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
>;
|
||||
|
||||
fn stream_completion_with_usage(
|
||||
&self,
|
||||
@@ -255,7 +273,7 @@ pub trait LanguageModel: Send + Sync {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<(
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
Option<RequestUsage>,
|
||||
)>,
|
||||
> {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use collections::{HashSet, IndexMap};
|
||||
use feature_flags::{Assistant2FeatureFlag, ZedPro};
|
||||
use feature_flags::{Assistant2FeatureFlag, ZedProFeatureFlag};
|
||||
use gpui::{
|
||||
Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
|
||||
Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
|
||||
@@ -584,7 +584,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
.p_1()
|
||||
.gap_4()
|
||||
.justify_between()
|
||||
.when(cx.has_flag::<ZedPro>(), |this| {
|
||||
.when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
|
||||
this.child(match plan {
|
||||
Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
|
||||
.icon(IconName::ZedAssistant)
|
||||
|
||||
@@ -71,7 +71,7 @@ fn register_language_model_providers(
|
||||
);
|
||||
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
|
||||
|
||||
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
|
||||
cx.observe_flag::<feature_flags::LanguageModelsFeatureFlag, _>(move |enabled, cx| {
|
||||
let user_store = user_store.clone();
|
||||
let client = client.clone();
|
||||
LanguageModelRegistry::global(cx).update(cx, move |registry, cx| {
|
||||
|
||||
@@ -12,10 +12,10 @@ use gpui::{
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
|
||||
RateLimiter, Role,
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
|
||||
};
|
||||
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
|
||||
use schemars::JsonSchema;
|
||||
@@ -27,7 +27,7 @@ use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::{ResultExt, maybe};
|
||||
use util::ResultExt;
|
||||
|
||||
const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
|
||||
const PROVIDER_NAME: &str = "Anthropic";
|
||||
@@ -448,7 +448,12 @@ impl LanguageModel for AnthropicModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = into_anthropic(
|
||||
request,
|
||||
self.model.request_id().into(),
|
||||
@@ -626,7 +631,7 @@ pub fn into_anthropic(
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
@@ -740,30 +745,32 @@ pub fn map_to_language_model_completion_events(
|
||||
Event::ContentBlockStop { index } => {
|
||||
if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
|
||||
let input_json = tool_use.input_json.trim();
|
||||
let input_value = if input_json.is_empty() {
|
||||
Ok(serde_json::Value::Object(serde_json::Map::default()))
|
||||
} else {
|
||||
serde_json::Value::from_str(input_json)
|
||||
};
|
||||
let event_result = match input_value {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
},
|
||||
)),
|
||||
Err(json_parse_err) => {
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_use.id.into(),
|
||||
tool_name: tool_use.name.into(),
|
||||
raw_input: input_json.into(),
|
||||
json_parse_error: json_parse_err.to_string(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
return Some((
|
||||
vec![maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_use.id.into(),
|
||||
name: tool_use.name.into(),
|
||||
is_input_complete: true,
|
||||
input: if input_json.is_empty() {
|
||||
serde_json::Value::Object(
|
||||
serde_json::Map::default(),
|
||||
)
|
||||
} else {
|
||||
serde_json::Value::from_str(
|
||||
input_json
|
||||
)
|
||||
.map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))?
|
||||
},
|
||||
raw_input: tool_use.input_json.clone(),
|
||||
},
|
||||
))
|
||||
})],
|
||||
state,
|
||||
));
|
||||
return Some((vec![event_result], state));
|
||||
}
|
||||
}
|
||||
Event::MessageStart { message } => {
|
||||
@@ -810,14 +817,19 @@ pub fn map_to_language_model_completion_events(
|
||||
}
|
||||
Event::Error { error } => {
|
||||
return Some((
|
||||
vec![Err(anyhow!(AnthropicError::ApiError(error)))],
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
AnthropicError::ApiError(error)
|
||||
)))],
|
||||
state,
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Err(err) => {
|
||||
return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
|
||||
return Some((
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,9 +32,10 @@ use gpui_tokio::Tokio;
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelRequest, LanguageModelToolUse, MessageContent, RateLimiter, Role, TokenUsage,
|
||||
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, TokenUsage,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -542,7 +543,12 @@ impl LanguageModel for BedrockModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
|
||||
// Get region - from credentials or directly from settings
|
||||
let region = state
|
||||
@@ -780,7 +786,7 @@ pub fn get_bedrock_tokens(
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
|
||||
handle: Handle,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
struct RawToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
@@ -971,7 +977,7 @@ pub fn map_to_language_model_completion_events(
|
||||
_ => {}
|
||||
},
|
||||
|
||||
Err(err) => return Some((Some(Err(anyhow!(err))), state)),
|
||||
Err(err) => return Some((Some(Err(anyhow!(err).into())), state)),
|
||||
}
|
||||
}
|
||||
None
|
||||
|
||||
@@ -2,7 +2,7 @@ use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::{Client, UserStore, zed_urls};
|
||||
use collections::BTreeMap;
|
||||
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
|
||||
use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
|
||||
use futures::{
|
||||
AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
|
||||
stream::BoxStream,
|
||||
@@ -10,11 +10,11 @@ use futures::{
|
||||
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
|
||||
use language_model::{
|
||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
|
||||
LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
|
||||
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
|
||||
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
|
||||
ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
|
||||
@@ -324,7 +324,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
|
||||
);
|
||||
}
|
||||
|
||||
let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
|
||||
let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
|
||||
zed_cloud_provider_additional_models()
|
||||
} else {
|
||||
&[]
|
||||
@@ -531,7 +531,7 @@ impl CloudLanguageModel {
|
||||
{
|
||||
request_builder.uri(completions_url)
|
||||
} else {
|
||||
request_builder.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
|
||||
request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
|
||||
};
|
||||
let request = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
@@ -745,7 +745,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
self.stream_completion_with_usage(request, cx)
|
||||
.map(|result| result.map(|(stream, _)| stream))
|
||||
.boxed()
|
||||
@@ -758,7 +763,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<(
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent>>,
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
Option<RequestUsage>,
|
||||
)>,
|
||||
> {
|
||||
@@ -940,7 +945,7 @@ impl Render for ConfigurationView {
|
||||
),
|
||||
),
|
||||
)
|
||||
} else if cx.has_flag::<ZedPro>() {
|
||||
} else if cx.has_flag::<ZedProFeatureFlag>() {
|
||||
Some(
|
||||
h_flex()
|
||||
.gap_2()
|
||||
|
||||
@@ -17,16 +17,16 @@ use gpui::{
|
||||
Transformation, percentage, svg,
|
||||
};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelToolUse, MessageContent, RateLimiter, Role,
|
||||
StopReason,
|
||||
};
|
||||
use settings::SettingsStore;
|
||||
use std::time::Duration;
|
||||
use strum::IntoEnumIterator;
|
||||
use ui::prelude::*;
|
||||
use util::maybe;
|
||||
|
||||
use super::anthropic::count_anthropic_tokens;
|
||||
use super::google::count_google_tokens;
|
||||
@@ -242,7 +242,12 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
if let Some(message) = request.messages.last() {
|
||||
if message.contents_empty() {
|
||||
const EMPTY_PROMPT_MSG: &str =
|
||||
@@ -285,7 +290,7 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
||||
is_streaming: bool,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
id: String,
|
||||
@@ -309,7 +314,7 @@ pub fn map_to_language_model_completion_events(
|
||||
Ok(event) => {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no choices"))],
|
||||
vec![Err(anyhow!("Response contained no choices").into())],
|
||||
state,
|
||||
));
|
||||
};
|
||||
@@ -322,7 +327,7 @@ pub fn map_to_language_model_completion_events(
|
||||
|
||||
let Some(delta) = delta else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no delta"))],
|
||||
vec![Err(anyhow!("Response contained no delta").into())],
|
||||
state,
|
||||
));
|
||||
};
|
||||
@@ -361,20 +366,26 @@ pub fn map_to_language_model_completion_events(
|
||||
}
|
||||
Some("tool_calls") => {
|
||||
events.extend(state.tool_calls_by_index.drain().map(
|
||||
|(_, tool_call)| {
|
||||
maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
)?,
|
||||
},
|
||||
))
|
||||
})
|
||||
|(_, tool_call)| match serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
) {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.clone().into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => {
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
})
|
||||
}
|
||||
},
|
||||
));
|
||||
|
||||
@@ -393,7 +404,7 @@ pub fn map_to_language_model_completion_events(
|
||||
|
||||
return Some((events, state));
|
||||
}
|
||||
Err(err) => return Some((vec![Err(err)], state)),
|
||||
Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,9 +464,11 @@ impl CopilotChatLanguageModel {
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(ChatMessage::User {
|
||||
content: text_content,
|
||||
});
|
||||
if !text_content.is_empty() {
|
||||
messages.push(ChatMessage::User {
|
||||
content: text_content,
|
||||
});
|
||||
}
|
||||
}
|
||||
Role::Assistant => {
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
@@ -9,9 +9,9 @@ use gpui::{
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -324,7 +324,12 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = into_deepseek(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
@@ -336,20 +341,22 @@ impl LanguageModel for DeepSeekLanguageModel {
|
||||
let stream = stream.await?;
|
||||
Ok(stream
|
||||
.map(|result| {
|
||||
result.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
result
|
||||
.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
|
||||
@@ -11,8 +11,9 @@ use gpui::{
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
|
||||
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
|
||||
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, MessageContent,
|
||||
StopReason,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
@@ -355,12 +356,19 @@ impl LanguageModel for GoogleLanguageModel {
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
>,
|
||||
> {
|
||||
let request = into_google(request, self.model.id().to_string());
|
||||
let request = self.stream_completion(request, cx);
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = request.await.map_err(|err| anyhow!(err))?;
|
||||
let response = request
|
||||
.await
|
||||
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
|
||||
Ok(map_to_language_model_completion_events(response))
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
@@ -396,7 +404,6 @@ pub fn into_google(
|
||||
Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
|
||||
function_call: google_ai::FunctionCall {
|
||||
name: tool_use.name.to_string(),
|
||||
raw_args: tool_use.raw_input,
|
||||
args: tool_use.input,
|
||||
},
|
||||
}))
|
||||
@@ -472,7 +479,7 @@ pub fn into_google(
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
@@ -493,7 +500,7 @@ pub fn map_to_language_model_completion_events(
|
||||
if let Some(event) = state.events.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
let mut events: Vec<Result<LanguageModelCompletionEvent>> = Vec::new();
|
||||
let mut events: Vec<_> = Vec::new();
|
||||
let mut wants_to_use_tool = false;
|
||||
if let Some(usage_metadata) = event.usage_metadata {
|
||||
update_usage(&mut state.usage, &usage_metadata);
|
||||
@@ -540,8 +547,8 @@ pub fn map_to_language_model_completion_events(
|
||||
is_input_complete: true,
|
||||
raw_input: function_call_part
|
||||
.function_call
|
||||
.raw_args
|
||||
.clone(),
|
||||
.args
|
||||
.to_string(),
|
||||
input: function_call_part.function_call.args,
|
||||
},
|
||||
)));
|
||||
@@ -560,7 +567,10 @@ pub fn map_to_language_model_completion_events(
|
||||
return Some((events, state));
|
||||
}
|
||||
Err(err) => {
|
||||
return Some((vec![Err(anyhow!(err))], state));
|
||||
return Some((
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
|
||||
state,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
@@ -310,7 +312,12 @@ impl LanguageModel for LmStudioLanguageModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = self.to_lmstudio_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
@@ -364,7 +371,11 @@ impl LanguageModel for LmStudioLanguageModel {
|
||||
async move {
|
||||
Ok(future
|
||||
.await?
|
||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
||||
.map(|result| {
|
||||
result
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
|
||||
@@ -8,9 +8,9 @@ use gpui::{
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
|
||||
};
|
||||
|
||||
use futures::stream::BoxStream;
|
||||
@@ -344,7 +344,12 @@ impl LanguageModel for MistralLanguageModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = into_mistral(
|
||||
request,
|
||||
self.model.id().to_string(),
|
||||
@@ -356,20 +361,22 @@ impl LanguageModel for MistralLanguageModel {
|
||||
let stream = stream.await?;
|
||||
Ok(stream
|
||||
.map(|result| {
|
||||
result.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
result
|
||||
.and_then(|response| {
|
||||
response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| anyhow!("Empty response"))
|
||||
.map(|choice| {
|
||||
choice
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
})
|
||||
})
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ use anyhow::{Result, anyhow};
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||
@@ -322,7 +324,12 @@ impl LanguageModel for OllamaLanguageModel {
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
>,
|
||||
> {
|
||||
let request = self.to_ollama_request(request);
|
||||
|
||||
let http_client = self.http_client.clone();
|
||||
@@ -356,7 +363,11 @@ impl LanguageModel for OllamaLanguageModel {
|
||||
async move {
|
||||
Ok(future
|
||||
.await?
|
||||
.map(|result| result.map(LanguageModelCompletionEvent::Text))
|
||||
.map(|result| {
|
||||
result
|
||||
.map(LanguageModelCompletionEvent::Text)
|
||||
.map_err(LanguageModelCompletionError::Other)
|
||||
})
|
||||
.boxed())
|
||||
}
|
||||
.boxed()
|
||||
|
||||
@@ -9,10 +9,10 @@ use gpui::{
|
||||
};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
|
||||
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason,
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
|
||||
};
|
||||
use open_ai::{Model, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
@@ -24,7 +24,7 @@ use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{Icon, IconName, List, Tooltip, prelude::*};
|
||||
use util::{ResultExt, maybe};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
@@ -321,7 +321,12 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
>,
|
||||
> {
|
||||
let request = into_open_ai(request, &self.model, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
@@ -419,7 +424,7 @@ pub fn into_open_ai(
|
||||
|
||||
pub fn map_to_language_model_completion_events(
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
id: String,
|
||||
@@ -443,7 +448,9 @@ pub fn map_to_language_model_completion_events(
|
||||
Ok(event) => {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return Some((
|
||||
vec![Err(anyhow!("Response contained no choices"))],
|
||||
vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"Response contained no choices"
|
||||
)))],
|
||||
state,
|
||||
));
|
||||
};
|
||||
@@ -484,20 +491,26 @@ pub fn map_to_language_model_completion_events(
|
||||
}
|
||||
Some("tool_calls") => {
|
||||
events.extend(state.tool_calls_by_index.drain().map(
|
||||
|(_, tool_call)| {
|
||||
maybe!({
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
input: serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
)?,
|
||||
},
|
||||
))
|
||||
})
|
||||
|(_, tool_call)| match serde_json::Value::from_str(
|
||||
&tool_call.arguments,
|
||||
) {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.clone().into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => {
|
||||
Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
})
|
||||
}
|
||||
},
|
||||
));
|
||||
|
||||
@@ -516,7 +529,9 @@ pub fn map_to_language_model_completion_events(
|
||||
|
||||
return Some((events, state));
|
||||
}
|
||||
Err(err) => return Some((vec![Err(err)], state)),
|
||||
Err(err) => {
|
||||
return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ fn get_max_tokens(name: &str) -> usize {
|
||||
"mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "qwen2.5-coder"
|
||||
| "dolphin-mixtral" => 32768,
|
||||
"llama3.1" | "llama3.2" | "llama3.3" | "phi3" | "phi3.5" | "phi4" | "command-r"
|
||||
| "deepseek-coder-v2" | "deepseek-r1" | "yi-coder" => 128000,
|
||||
| "deepseek-coder-v2" | "deepseek-v3" | "deepseek-r1" | "yi-coder" => 128000,
|
||||
_ => DEFAULT_TOKENS,
|
||||
}
|
||||
.clamp(1, MAXIMUM_TOKENS)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
mod supported_countries;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use futures::{
|
||||
AsyncBufReadExt, AsyncReadExt, StreamExt,
|
||||
@@ -15,8 +13,6 @@ use std::{
|
||||
};
|
||||
use strum::EnumIter;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Returns whether the given country code is supported by OpenAI.
|
||||
///
|
||||
/// <https://platform.openai.com/docs/supported-countries>
|
||||
pub fn is_supported_country(country_code: &str) -> bool {
|
||||
SUPPORTED_COUNTRIES.contains(&country_code)
|
||||
}
|
||||
|
||||
/// The list of country codes supported by OpenAI.
|
||||
///
|
||||
/// https://platform.openai.com/docs/supported-countries
|
||||
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
vec![
|
||||
"AL", // Albania
|
||||
"DZ", // Algeria
|
||||
"AS", // American Samoa (US)
|
||||
"AI", // Anguilla (UK)
|
||||
"AF", // Afghanistan
|
||||
"AD", // Andorra
|
||||
"AO", // Angola
|
||||
"AG", // Antigua and Barbuda
|
||||
"AR", // Argentina
|
||||
"AM", // Armenia
|
||||
"AU", // Australia
|
||||
"AT", // Austria
|
||||
"AZ", // Azerbaijan
|
||||
"BS", // Bahamas
|
||||
"BH", // Bahrain
|
||||
"BD", // Bangladesh
|
||||
"BB", // Barbados
|
||||
"BE", // Belgium
|
||||
"BZ", // Belize
|
||||
"BJ", // Benin
|
||||
"BM", // Bermuda (UK)
|
||||
"BT", // Bhutan
|
||||
"BO", // Bolivia
|
||||
"BA", // Bosnia and Herzegovina
|
||||
"BW", // Botswana
|
||||
"BR", // Brazil
|
||||
"IO", // British Indian Ocean Territory (UK)
|
||||
"BN", // Brunei
|
||||
"BG", // Bulgaria
|
||||
"BF", // Burkina Faso
|
||||
"BI", // Burundi
|
||||
"CV", // Cabo Verde
|
||||
"KH", // Cambodia
|
||||
"CM", // Cameroon
|
||||
"CA", // Canada
|
||||
"KY", // Cayman Islands (UK)
|
||||
"CF", // Central African Republic
|
||||
"TD", // Chad
|
||||
"CL", // Chile
|
||||
"CX", // Christmas Island (AU)
|
||||
"CC", // Cocos (Keeling) Islands (AU)
|
||||
"CO", // Colombia
|
||||
"KM", // Comoros
|
||||
"CG", // Congo (Brazzaville)
|
||||
"CD", // Congo (DRC)
|
||||
"CK", // Cook Islands (NZ)
|
||||
"CR", // Costa Rica
|
||||
"CI", // Côte d'Ivoire
|
||||
"HR", // Croatia
|
||||
"CY", // Cyprus
|
||||
"CZ", // Czechia (Czech Republic)
|
||||
"DK", // Denmark
|
||||
"DJ", // Djibouti
|
||||
"DM", // Dominica
|
||||
"DO", // Dominican Republic
|
||||
"EC", // Ecuador
|
||||
"EG", // Egypt
|
||||
"SV", // El Salvador
|
||||
"GQ", // Equatorial Guinea
|
||||
"ER", // Eritrea
|
||||
"EE", // Estonia
|
||||
"SZ", // Eswatini (Swaziland)
|
||||
"ET", // Ethiopia
|
||||
"FK", // Falkland Islands (UK)
|
||||
"FJ", // Fiji
|
||||
"FI", // Finland
|
||||
"FR", // France
|
||||
"GF", // French Guiana (FR)
|
||||
"PF", // French Polynesia (FR)
|
||||
"TF", // French Southern Territories
|
||||
"GA", // Gabon
|
||||
"GM", // Gambia
|
||||
"GE", // Georgia
|
||||
"DE", // Germany
|
||||
"GH", // Ghana
|
||||
"GI", // Gibraltar (UK)
|
||||
"GR", // Greece
|
||||
"GD", // Grenada
|
||||
"GT", // Guatemala
|
||||
"GU", // Guam (US)
|
||||
"GN", // Guinea
|
||||
"GW", // Guinea-Bissau
|
||||
"GY", // Guyana
|
||||
"HT", // Haiti
|
||||
"HM", // Heard Island and McDonald Islands (AU)
|
||||
"VA", // Holy See (Vatican City)
|
||||
"HN", // Honduras
|
||||
"HU", // Hungary
|
||||
"IS", // Iceland
|
||||
"IN", // India
|
||||
"ID", // Indonesia
|
||||
"IQ", // Iraq
|
||||
"IE", // Ireland
|
||||
"IL", // Israel
|
||||
"IT", // Italy
|
||||
"JM", // Jamaica
|
||||
"JP", // Japan
|
||||
"JO", // Jordan
|
||||
"KZ", // Kazakhstan
|
||||
"KE", // Kenya
|
||||
"KI", // Kiribati
|
||||
"KW", // Kuwait
|
||||
"KG", // Kyrgyzstan
|
||||
"LA", // Laos
|
||||
"LV", // Latvia
|
||||
"LB", // Lebanon
|
||||
"LS", // Lesotho
|
||||
"LR", // Liberia
|
||||
"LY", // Libya
|
||||
"LI", // Liechtenstein
|
||||
"LT", // Lithuania
|
||||
"LU", // Luxembourg
|
||||
"MG", // Madagascar
|
||||
"MW", // Malawi
|
||||
"MY", // Malaysia
|
||||
"MV", // Maldives
|
||||
"ML", // Mali
|
||||
"MT", // Malta
|
||||
"MH", // Marshall Islands
|
||||
"MR", // Mauritania
|
||||
"MU", // Mauritius
|
||||
"MX", // Mexico
|
||||
"FM", // Micronesia
|
||||
"MD", // Moldova
|
||||
"MC", // Monaco
|
||||
"MN", // Mongolia
|
||||
"MS", // Montserrat (UK)
|
||||
"ME", // Montenegro
|
||||
"MA", // Morocco
|
||||
"MZ", // Mozambique
|
||||
"MM", // Myanmar
|
||||
"NA", // Namibia
|
||||
"NR", // Nauru
|
||||
"NP", // Nepal
|
||||
"NL", // Netherlands
|
||||
"NZ", // New Zealand
|
||||
"NI", // Nicaragua
|
||||
"NE", // Niger
|
||||
"NG", // Nigeria
|
||||
"NF", // Norfolk Island (AU)
|
||||
"MK", // North Macedonia
|
||||
"MI", // Northern Mariana Islands (UK)
|
||||
"NO", // Norway
|
||||
"NU", // Niue (NZ)
|
||||
"OM", // Oman
|
||||
"PK", // Pakistan
|
||||
"PW", // Palau
|
||||
"PS", // Palestine
|
||||
"PA", // Panama
|
||||
"PG", // Papua New Guinea
|
||||
"PY", // Paraguay
|
||||
"PE", // Peru
|
||||
"PH", // Philippines
|
||||
"PN", // Pitcairn (UK)
|
||||
"PL", // Poland
|
||||
"PT", // Portugal
|
||||
"PR", // Puerto Rico (US)
|
||||
"QA", // Qatar
|
||||
"RO", // Romania
|
||||
"RW", // Rwanda
|
||||
"BL", // Saint Barthélemy (FR)
|
||||
"KN", // Saint Kitts and Nevis
|
||||
"LC", // Saint Lucia
|
||||
"MF", // Saint Martin (FR)
|
||||
"PM", // Saint Pierre and Miquelon (FR)
|
||||
"VC", // Saint Vincent and the Grenadines
|
||||
"WS", // Samoa
|
||||
"SM", // San Marino
|
||||
"ST", // Sao Tome and Principe
|
||||
"SA", // Saudi Arabia
|
||||
"SN", // Senegal
|
||||
"RS", // Serbia
|
||||
"SC", // Seychelles
|
||||
"SH", // Saint Helena, Ascension and Tristan da Cunha (UK)
|
||||
"SL", // Sierra Leone
|
||||
"SG", // Singapore
|
||||
"SK", // Slovakia
|
||||
"SI", // Slovenia
|
||||
"SB", // Solomon Islands
|
||||
"SO", // Somalia
|
||||
"ZA", // South Africa
|
||||
"KR", // South Korea
|
||||
"SS", // South Sudan
|
||||
"ES", // Spain
|
||||
"LK", // Sri Lanka
|
||||
"SR", // Suriname
|
||||
"SE", // Sweden
|
||||
"CH", // Switzerland
|
||||
"SD", // Sudan
|
||||
"TW", // Taiwan
|
||||
"TJ", // Tajikistan
|
||||
"TZ", // Tanzania
|
||||
"TH", // Thailand
|
||||
"TL", // Timor-Leste (East Timor)
|
||||
"TG", // Togo
|
||||
"TK", // Tokelau (NZ)
|
||||
"TO", // Tonga
|
||||
"TT", // Trinidad and Tobago
|
||||
"TN", // Tunisia
|
||||
"TR", // Turkey
|
||||
"TM", // Turkmenistan
|
||||
"TC", // Turks and Caicos Islands (UK)
|
||||
"TV", // Tuvalu
|
||||
"UG", // Uganda
|
||||
"UA", // Ukraine (with certain exceptions)
|
||||
"AE", // United Arab Emirates
|
||||
"GB", // United Kingdom
|
||||
"UM", // United States Minor Outlying Islands (US)
|
||||
"US", // United States of America
|
||||
"UY", // Uruguay
|
||||
"UZ", // Uzbekistan
|
||||
"VU", // Vanuatu
|
||||
"VN", // Vietnam
|
||||
"VI", // Virgin Islands (US)
|
||||
"VG", // Virgin Islands (UK)
|
||||
"WF", // Wallis and Futuna (FR)
|
||||
"YE", // Yemen
|
||||
"ZM", // Zambia
|
||||
"ZW", // Zimbabwe
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
});
|
||||
@@ -31,14 +31,6 @@ pub enum TerminalKind {
|
||||
Shell(Option<PathBuf>),
|
||||
/// Run a task.
|
||||
Task(SpawnInTerminal),
|
||||
/// Run a debug terminal.
|
||||
Debug {
|
||||
command: Option<String>,
|
||||
args: Vec<String>,
|
||||
envs: HashMap<String, String>,
|
||||
cwd: Option<Arc<Path>>,
|
||||
title: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// SshCommand describes how to connect to a remote server
|
||||
@@ -105,7 +97,6 @@ impl Project {
|
||||
self.active_project_directory(cx)
|
||||
}
|
||||
}
|
||||
TerminalKind::Debug { cwd, .. } => cwd.clone(),
|
||||
};
|
||||
|
||||
let mut settings_location = None;
|
||||
@@ -209,7 +200,6 @@ impl Project {
|
||||
this.active_project_directory(cx)
|
||||
}
|
||||
}
|
||||
TerminalKind::Debug { cwd, .. } => cwd.clone(),
|
||||
};
|
||||
let ssh_details = this.ssh_details(cx);
|
||||
|
||||
@@ -243,7 +233,6 @@ impl Project {
|
||||
};
|
||||
|
||||
let mut python_venv_activate_command = None;
|
||||
let debug_terminal = matches!(kind, TerminalKind::Debug { .. });
|
||||
|
||||
let (spawn_task, shell) = match kind {
|
||||
TerminalKind::Shell(_) => {
|
||||
@@ -339,27 +328,6 @@ impl Project {
|
||||
}
|
||||
}
|
||||
}
|
||||
TerminalKind::Debug {
|
||||
command,
|
||||
args,
|
||||
envs,
|
||||
title,
|
||||
..
|
||||
} => {
|
||||
env.extend(envs);
|
||||
|
||||
let shell = if let Some(program) = command {
|
||||
Shell::WithArguments {
|
||||
program,
|
||||
args,
|
||||
title_override: Some(title.unwrap_or("Debug Terminal".into()).into()),
|
||||
}
|
||||
} else {
|
||||
settings.shell.clone()
|
||||
};
|
||||
|
||||
(None, shell)
|
||||
}
|
||||
};
|
||||
TerminalBuilder::new(
|
||||
local_path.map(|path| path.to_path_buf()),
|
||||
@@ -373,7 +341,6 @@ impl Project {
|
||||
ssh_details.is_some(),
|
||||
window,
|
||||
completion_tx,
|
||||
debug_terminal,
|
||||
cx,
|
||||
)
|
||||
.map(|builder| {
|
||||
|
||||
@@ -60,9 +60,7 @@ pub enum PromptId {
|
||||
|
||||
impl PromptId {
|
||||
pub fn new() -> PromptId {
|
||||
PromptId::User {
|
||||
uuid: UserPromptId::new(),
|
||||
}
|
||||
UserPromptId::new().into()
|
||||
}
|
||||
|
||||
pub fn is_built_in(&self) -> bool {
|
||||
@@ -70,6 +68,12 @@ impl PromptId {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UserPromptId> for PromptId {
|
||||
fn from(uuid: UserPromptId) -> Self {
|
||||
PromptId::User { uuid }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct UserPromptId(pub Uuid);
|
||||
@@ -227,9 +231,7 @@ impl PromptStore {
|
||||
.collect::<heed::Result<HashMap<_, _>>>()?;
|
||||
|
||||
for (prompt_id_v1, metadata_v1) in metadata_v1 {
|
||||
let prompt_id_v2 = PromptId::User {
|
||||
uuid: UserPromptId(prompt_id_v1.0),
|
||||
};
|
||||
let prompt_id_v2 = UserPromptId(prompt_id_v1.0).into();
|
||||
let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "prompt_library"
|
||||
name = "rules_library"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/prompt_library.rs"
|
||||
path = "src/rules_library.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
File diff suppressed because it is too large
Load Diff
@@ -352,7 +352,6 @@ impl TerminalBuilder {
|
||||
is_ssh_terminal: bool,
|
||||
window: AnyWindowHandle,
|
||||
completion_tx: Sender<Option<ExitStatus>>,
|
||||
debug_terminal: bool,
|
||||
cx: &App,
|
||||
) -> Result<TerminalBuilder> {
|
||||
// If the parent environment doesn't have a locale set
|
||||
@@ -502,7 +501,6 @@ impl TerminalBuilder {
|
||||
word_regex: RegexSearch::new(WORD_REGEX).unwrap(),
|
||||
python_file_line_regex: RegexSearch::new(PYTHON_FILE_LINE_REGEX).unwrap(),
|
||||
vi_mode_enabled: false,
|
||||
debug_terminal,
|
||||
is_ssh_terminal,
|
||||
python_venv_directory,
|
||||
};
|
||||
@@ -660,7 +658,6 @@ pub struct Terminal {
|
||||
python_file_line_regex: RegexSearch,
|
||||
task: Option<TaskState>,
|
||||
vi_mode_enabled: bool,
|
||||
debug_terminal: bool,
|
||||
is_ssh_terminal: bool,
|
||||
}
|
||||
|
||||
@@ -1855,10 +1852,6 @@ impl Terminal {
|
||||
self.task.as_ref()
|
||||
}
|
||||
|
||||
pub fn debug_terminal(&self) -> bool {
|
||||
self.debug_terminal
|
||||
}
|
||||
|
||||
pub fn wait_for_completed_task(&self, cx: &App) -> Task<Option<ExitStatus>> {
|
||||
if let Some(task) = self.task() {
|
||||
if task.status == TaskStatus::Running {
|
||||
|
||||
@@ -1420,9 +1420,6 @@ impl Item for TerminalView {
|
||||
}
|
||||
}
|
||||
},
|
||||
None if self.terminal.read(cx).debug_terminal() => {
|
||||
(IconName::Debug, Color::Muted, None)
|
||||
}
|
||||
None => (IconName::Terminal, Color::Muted, None),
|
||||
};
|
||||
|
||||
@@ -1583,7 +1580,7 @@ impl SerializableItem for TerminalView {
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<Task<gpui::Result<()>>> {
|
||||
let terminal = self.terminal().read(cx);
|
||||
if terminal.task().is_some() || terminal.debug_terminal() {
|
||||
if terminal.task().is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::platforms::{platform_linux, platform_mac, platform_windows};
|
||||
use auto_update::AutoUpdateStatus;
|
||||
use call::ActiveCall;
|
||||
use client::{Client, UserStore};
|
||||
use feature_flags::{FeatureFlagAppExt, ZedPro};
|
||||
use feature_flags::{FeatureFlagAppExt, ZedProFeatureFlag};
|
||||
use gpui::{
|
||||
Action, AnyElement, App, Context, Corner, Decorations, Element, Entity, InteractiveElement,
|
||||
Interactivity, IntoElement, MouseButton, ParentElement, Render, Stateful,
|
||||
@@ -663,7 +663,7 @@ impl TitleBar {
|
||||
.anchor(Corner::TopRight)
|
||||
.menu(move |window, cx| {
|
||||
ContextMenu::build(window, cx, |menu, _, cx| {
|
||||
menu.when(cx.has_flag::<ZedPro>(), |menu| {
|
||||
menu.when(cx.has_flag::<ZedProFeatureFlag>(), |menu| {
|
||||
menu.action(
|
||||
format!(
|
||||
"Current Plan: {}",
|
||||
|
||||
@@ -142,6 +142,10 @@ impl PaneGroup {
|
||||
self.root.first_pane()
|
||||
}
|
||||
|
||||
pub fn last_pane(&self) -> Entity<Pane> {
|
||||
self.root.last_pane()
|
||||
}
|
||||
|
||||
pub fn find_pane_in_direction(
|
||||
&mut self,
|
||||
active_pane: &Entity<Pane>,
|
||||
@@ -360,6 +364,13 @@ impl Member {
|
||||
}
|
||||
}
|
||||
|
||||
fn last_pane(&self) -> Entity<Pane> {
|
||||
match self {
|
||||
Member::Axis(axis) => axis.members.last().unwrap().last_pane(),
|
||||
Member::Pane(pane) => pane.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render(
|
||||
&self,
|
||||
basis: usize,
|
||||
|
||||
@@ -20,7 +20,7 @@ use command_palette_hooks::CommandPaletteFilter;
|
||||
use debugger_ui::debugger_panel::DebugPanel;
|
||||
use editor::ProposedChangesEditorToolbar;
|
||||
use editor::{Editor, MultiBuffer, scroll::Autoscroll};
|
||||
use feature_flags::{Debugger, FeatureFlagAppExt, FeatureFlagViewExt};
|
||||
use feature_flags::{DebuggerFeatureFlag, FeatureFlagAppExt, FeatureFlagViewExt};
|
||||
use futures::{StreamExt, channel::mpsc, select_biased};
|
||||
use git_ui::git_panel::GitPanel;
|
||||
use git_ui::project_diff::ProjectDiffToolbar;
|
||||
@@ -279,7 +279,7 @@ fn feature_gate_zed_pro_actions(cx: &mut App) {
|
||||
filter.hide_action_types(&zed_pro_actions);
|
||||
});
|
||||
|
||||
cx.observe_flag::<feature_flags::ZedPro, _>({
|
||||
cx.observe_flag::<feature_flags::ZedProFeatureFlag, _>({
|
||||
move |is_enabled, cx| {
|
||||
CommandPaletteFilter::update_global(cx, |filter, _cx| {
|
||||
if is_enabled {
|
||||
@@ -439,7 +439,7 @@ fn initialize_panels(
|
||||
workspace.add_panel(channels_panel, window, cx);
|
||||
workspace.add_panel(chat_panel, window, cx);
|
||||
workspace.add_panel(notification_panel, window, cx);
|
||||
cx.when_flag_enabled::<Debugger>(window, |_, window, cx| {
|
||||
cx.when_flag_enabled::<DebuggerFeatureFlag>(window, |_, window, cx| {
|
||||
cx.spawn_in(
|
||||
window,
|
||||
async move |workspace: gpui::WeakEntity<Workspace>,
|
||||
|
||||
@@ -199,14 +199,14 @@ pub mod assistant {
|
||||
|
||||
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct OpenPromptLibrary {
|
||||
pub struct OpenRulesLibrary {
|
||||
#[serde(skip)]
|
||||
pub prompt_to_select: Option<Uuid>,
|
||||
}
|
||||
|
||||
impl_action_with_deprecated_aliases!(
|
||||
assistant,
|
||||
OpenPromptLibrary,
|
||||
OpenRulesLibrary,
|
||||
["assistant::DeployPromptLibrary"]
|
||||
);
|
||||
|
||||
|
||||
@@ -2178,7 +2178,7 @@ Examples:
|
||||
|
||||
## Show Whitespaces
|
||||
|
||||
- Description: Whether or not to show render whitespace characters in the editor.
|
||||
- Description: Whether or not to render whitespace characters in the editor.
|
||||
- Setting: `show_whitespaces`
|
||||
- Default: `selection`
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user