Compare commits
2 Commits
streaming-
...
thomas/add
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcb71cda2a | ||
|
|
5145a8dfca |
8
.github/workflows/eval.yml
vendored
8
.github/workflows/eval.yml
vendored
@@ -2,7 +2,13 @@ name: Run Agent Eval
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
- cron: "0 * * * *"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- "v[0-9]+.[0-9]+.x"
|
||||
tags:
|
||||
- "v*"
|
||||
|
||||
pull_request:
|
||||
branches:
|
||||
|
||||
515
Cargo.lock
generated
515
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -109,7 +109,7 @@ members = [
|
||||
"crates/project",
|
||||
"crates/project_panel",
|
||||
"crates/project_symbols",
|
||||
"crates/rules_library",
|
||||
"crates/prompt_library",
|
||||
"crates/prompt_store",
|
||||
"crates/proto",
|
||||
"crates/recent_projects",
|
||||
@@ -296,7 +296,6 @@ livekit_api = { path = "crates/livekit_api" }
|
||||
livekit_client = { path = "crates/livekit_client" }
|
||||
lmstudio = { path = "crates/lmstudio" }
|
||||
lsp = { path = "crates/lsp" }
|
||||
lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "c9c189f1c5dd53c624a419ce35bc77ad6a908d18" }
|
||||
markdown = { path = "crates/markdown" }
|
||||
markdown_preview = { path = "crates/markdown_preview" }
|
||||
media = { path = "crates/media" }
|
||||
@@ -319,7 +318,7 @@ prettier = { path = "crates/prettier" }
|
||||
project = { path = "crates/project" }
|
||||
project_panel = { path = "crates/project_panel" }
|
||||
project_symbols = { path = "crates/project_symbols" }
|
||||
rules_library = { path = "crates/rules_library" }
|
||||
prompt_library = { path = "crates/prompt_library" }
|
||||
prompt_store = { path = "crates/prompt_store" }
|
||||
proto = { path = "crates/proto" }
|
||||
recent_projects = { path = "crates/recent_projects" }
|
||||
@@ -500,7 +499,6 @@ prost-types = "0.9"
|
||||
pulldown-cmark = { version = "0.12.0", default-features = false }
|
||||
quote = "1.0.9"
|
||||
rand = "0.8.5"
|
||||
ref-cast = "1.0.24"
|
||||
rayon = "1.8"
|
||||
regex = "1.5"
|
||||
repair_json = "0.1.0"
|
||||
@@ -607,7 +605,7 @@ wasmtime-wasi = "29"
|
||||
which = "6.0.0"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "0.7.1"
|
||||
zed_llm_client = "0.6.1"
|
||||
zstd = "0.11"
|
||||
metal = "0.29"
|
||||
|
||||
|
||||
@@ -212,7 +212,7 @@
|
||||
"ctrl-shift-g": "search::SelectPreviousMatch",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-k h": "assistant::DeployHistory",
|
||||
"ctrl-k l": "assistant::OpenRulesLibrary",
|
||||
"ctrl-k l": "assistant::OpenPromptLibrary",
|
||||
"new": "assistant::NewChat",
|
||||
"ctrl-t": "assistant::NewChat",
|
||||
"ctrl-n": "assistant::NewChat"
|
||||
@@ -241,7 +241,7 @@
|
||||
"ctrl-alt-n": "agent::NewTextThread",
|
||||
"ctrl-shift-h": "agent::OpenHistory",
|
||||
"ctrl-alt-c": "agent::OpenConfiguration",
|
||||
"ctrl-alt-p": "assistant::OpenRulesLibrary",
|
||||
"ctrl-alt-p": "assistant::OpenPromptLibrary",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "assistant::ToggleModelSelector",
|
||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -308,9 +308,9 @@
|
||||
{
|
||||
"context": "PromptLibrary",
|
||||
"bindings": {
|
||||
"new": "rules_library::NewRule",
|
||||
"ctrl-n": "rules_library::NewRule",
|
||||
"ctrl-shift-s": "rules_library::ToggleDefaultRule"
|
||||
"new": "prompt_library::NewPrompt",
|
||||
"ctrl-n": "prompt_library::NewPrompt",
|
||||
"ctrl-shift-s": "prompt_library::ToggleDefaultPrompt"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -675,7 +675,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"context": "Editor && mode == full",
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
"shift-enter": "editor::ExpandExcerpts",
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
"context": "PromptLibrary",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"cmd-n": "rules_library::NewRule",
|
||||
"cmd-shift-s": "rules_library::ToggleDefaultRule",
|
||||
"cmd-n": "prompt_library::NewPrompt",
|
||||
"cmd-shift-s": "prompt_library::ToggleDefaultPrompt",
|
||||
"cmd-w": "workspace::CloseWindow"
|
||||
}
|
||||
},
|
||||
@@ -257,7 +257,7 @@
|
||||
"cmd-shift-g": "search::SelectPreviousMatch",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-k h": "assistant::DeployHistory",
|
||||
"cmd-k l": "assistant::OpenRulesLibrary",
|
||||
"cmd-k l": "assistant::OpenPromptLibrary",
|
||||
"cmd-t": "assistant::NewChat",
|
||||
"cmd-n": "assistant::NewChat"
|
||||
}
|
||||
@@ -286,7 +286,7 @@
|
||||
"cmd-alt-n": "agent::NewTextThread",
|
||||
"cmd-shift-h": "agent::OpenHistory",
|
||||
"cmd-alt-c": "agent::OpenConfiguration",
|
||||
"cmd-alt-p": "assistant::OpenRulesLibrary",
|
||||
"cmd-alt-p": "assistant::OpenPromptLibrary",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "assistant::ToggleModelSelector",
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
@@ -738,7 +738,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"context": "!ContextEditor > Editor && mode == full",
|
||||
"context": "Editor && mode == full",
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"alt-enter": "editor::OpenExcerpts",
|
||||
|
||||
@@ -31,9 +31,6 @@ If appropriate, use tool calls to explore the current project, which contains th
|
||||
- When looking for symbols in the project, prefer the `grep` tool.
|
||||
- As you learn about the structure of the project, use that information to scope `grep` searches to targeted subtrees of the project.
|
||||
- Bias towards not asking the user for help if you can find the answer yourself.
|
||||
{{! TODO: Only mention tools if they are enabled }}
|
||||
- The user might specify a partial file path. If you don't know the full path, use `find_path` (not `grep`) before you read the file.
|
||||
- Before you read or edit a file, you must first find the full path. DO NOT ever guess a file path!
|
||||
|
||||
## Fixing Diagnostics
|
||||
|
||||
|
||||
@@ -646,7 +646,7 @@
|
||||
"fetch": true,
|
||||
"list_directory": false,
|
||||
"now": true,
|
||||
"find_path": true,
|
||||
"path_search": true,
|
||||
"read_file": true,
|
||||
"grep": true,
|
||||
"thinking": true,
|
||||
@@ -670,7 +670,7 @@
|
||||
"list_directory": true,
|
||||
"move_path": false,
|
||||
"now": false,
|
||||
"find_path": true,
|
||||
"path_search": true,
|
||||
"read_file": true,
|
||||
"grep": true,
|
||||
"rename": false,
|
||||
|
||||
@@ -1,7 +1,2 @@
|
||||
allow-private-module-inception = true
|
||||
avoid-breaking-exported-api = false
|
||||
ignore-interior-mutability = [
|
||||
# Suppresses clippy::mutable_key_type, which is a false positive as the Eq
|
||||
# and Hash impls do not use fields with interior mutability.
|
||||
"agent::context::AgentContextKey"
|
||||
]
|
||||
|
||||
@@ -28,6 +28,7 @@ async-watch.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
@@ -61,10 +62,9 @@ parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
picker.workspace = true
|
||||
project.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
ref-cast.workspace = true
|
||||
release_channel.workspace = true
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent, ui::AnimatedLabel};
|
||||
use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
|
||||
use anyhow::Result;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use collections::{HashMap, HashSet};
|
||||
@@ -8,8 +8,8 @@ use editor::{
|
||||
scroll::Autoscroll,
|
||||
};
|
||||
use gpui::{
|
||||
Action, AnyElement, AnyView, App, Empty, Entity, EventEmitter, FocusHandle, Focusable,
|
||||
SharedString, Subscription, Task, WeakEntity, Window, prelude::*,
|
||||
Action, AnyElement, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, SharedString,
|
||||
Subscription, Task, WeakEntity, Window, prelude::*,
|
||||
};
|
||||
use language::{Capability, DiskState, OffsetRangeExt, Point};
|
||||
use multi_buffer::PathKey;
|
||||
@@ -307,10 +307,6 @@ impl AgentDiff {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.thread.read(cx).is_generating() {
|
||||
return;
|
||||
}
|
||||
|
||||
let snapshot = self.multibuffer.read(cx).snapshot(cx);
|
||||
let diff_hunks_in_ranges = self
|
||||
.editor
|
||||
@@ -343,10 +339,6 @@ impl AgentDiff {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if self.thread.read(cx).is_generating() {
|
||||
return;
|
||||
}
|
||||
|
||||
let snapshot = self.multibuffer.read(cx).snapshot(cx);
|
||||
let diff_hunks_in_ranges = self
|
||||
.editor
|
||||
@@ -658,11 +650,6 @@ fn render_diff_hunk_controls(
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
let editor = editor.clone();
|
||||
|
||||
if agent_diff.read(cx).thread.read(cx).is_generating() {
|
||||
return Empty.into_any();
|
||||
}
|
||||
|
||||
h_flex()
|
||||
.h(line_height)
|
||||
.mr_0p5()
|
||||
@@ -870,14 +857,8 @@ impl Render for AgentDiffToolbar {
|
||||
None => return div(),
|
||||
};
|
||||
|
||||
let is_generating = agent_diff.read(cx).thread.read(cx).is_generating();
|
||||
if is_generating {
|
||||
return div()
|
||||
.w(rems(6.5625)) // Arbitrary 105px size—so the label doesn't dance around
|
||||
.child(AnimatedLabel::new("Generating"));
|
||||
}
|
||||
|
||||
let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
|
||||
|
||||
if is_empty {
|
||||
return div();
|
||||
}
|
||||
@@ -962,13 +943,11 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let prompt_store = None;
|
||||
let thread_store = cx
|
||||
.update(|cx| {
|
||||
ThreadStore::load(
|
||||
project.clone(),
|
||||
cx.new(|_| ToolWorkingSet::default()),
|
||||
prompt_store,
|
||||
Arc::new(PromptBuilder::new(None).unwrap()),
|
||||
cx,
|
||||
)
|
||||
@@ -990,7 +969,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|_, cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit(
|
||||
|
||||
@@ -39,7 +39,6 @@ use thread::ThreadId;
|
||||
pub use crate::active_thread::ActiveThread;
|
||||
use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal};
|
||||
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
|
||||
pub use crate::context::{ContextLoadResult, LoadedContext};
|
||||
pub use crate::inline_assistant::InlineAssistant;
|
||||
pub use crate::thread::{Message, Thread, ThreadEvent};
|
||||
pub use crate::thread_store::ThreadStore;
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use assistant_settings::{
|
||||
AgentProfile, AgentProfileContent, AgentProfileId, AssistantSettings, AssistantSettingsContent,
|
||||
ContextServerPresetContent,
|
||||
ContextServerPresetContent, VersionedAssistantSettingsContent,
|
||||
};
|
||||
use assistant_tool::{ToolSource, ToolWorkingSet};
|
||||
use fs::Fs;
|
||||
@@ -201,10 +201,10 @@ impl PickerDelegate for ToolPickerDelegate {
|
||||
let profile_id = self.profile_id.clone();
|
||||
let default_profile = self.profile.clone();
|
||||
let tool = tool.clone();
|
||||
move |settings: &mut AssistantSettingsContent, _cx| {
|
||||
settings
|
||||
.v2_setting(|v2_settings| {
|
||||
let profiles = v2_settings.profiles.get_or_insert_default();
|
||||
move |settings, _cx| match settings {
|
||||
AssistantSettingsContent::Versioned(boxed) => {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
let profile =
|
||||
profiles
|
||||
.entry(profile_id)
|
||||
@@ -240,10 +240,9 @@ impl PickerDelegate for ToolPickerDelegate {
|
||||
*preset.tools.entry(tool.name.clone()).or_default() = is_enabled;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -12,21 +12,21 @@ use assistant_settings::{AssistantDockPosition, AssistantSettings};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
|
||||
use client::{UserStore, zed_urls};
|
||||
use client::zed_urls;
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, ClipboardItem,
|
||||
Corner, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels,
|
||||
Subscription, Task, UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
|
||||
Action, Animation, AnimationExt as _, AnyElement, App, AsyncWindowContext, Corner, Entity,
|
||||
EventEmitter, FocusHandle, Focusable, FontWeight, KeyContext, Pixels, Subscription, Task,
|
||||
UpdateGlobal, WeakEntity, prelude::*, pulsating_between,
|
||||
};
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use project::Project;
|
||||
use prompt_store::{PromptBuilder, PromptStore, UserPromptId};
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
|
||||
use proto::Plan;
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
use settings::{Settings, update_settings_file};
|
||||
use time::UtcOffset;
|
||||
use ui::{
|
||||
@@ -36,7 +36,7 @@ use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
use workspace::dock::{DockPosition, Panel, PanelEvent};
|
||||
use zed_actions::agent::OpenConfiguration;
|
||||
use zed_actions::assistant::{OpenRulesLibrary, ToggleFocus};
|
||||
use zed_actions::assistant::{OpenPromptLibrary, ToggleFocus};
|
||||
|
||||
use crate::active_thread::{ActiveThread, ActiveThreadEvent};
|
||||
use crate::assistant_configuration::{AssistantConfiguration, AssistantConfigurationEvent};
|
||||
@@ -79,11 +79,11 @@ pub fn init(cx: &mut App) {
|
||||
panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
panel.deploy_prompt_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
})
|
||||
@@ -180,7 +180,6 @@ impl ActiveView {
|
||||
|
||||
pub struct AssistantPanel {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
user_store: Entity<UserStore>,
|
||||
project: Entity<Project>,
|
||||
fs: Arc<dyn Fs>,
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
@@ -189,7 +188,6 @@ pub struct AssistantPanel {
|
||||
message_editor: Entity<MessageEditor>,
|
||||
_active_thread_subscriptions: Vec<Subscription>,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
configuration: Option<Entity<AssistantConfiguration>>,
|
||||
configuration_subscription: Option<Subscription>,
|
||||
local_timezone: UtcOffset,
|
||||
@@ -206,25 +204,14 @@ impl AssistantPanel {
|
||||
pub fn load(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
mut cx: AsyncWindowContext,
|
||||
cx: AsyncWindowContext,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = cx.update(|_window, cx| PromptStore::global(cx));
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = match prompt_store {
|
||||
Ok(prompt_store) => prompt_store.await.ok(),
|
||||
Err(_) => None,
|
||||
};
|
||||
let tools = cx.new(|_| ToolWorkingSet::default())?;
|
||||
let thread_store = workspace
|
||||
.update(cx, |workspace, cx| {
|
||||
let project = workspace.project().clone();
|
||||
ThreadStore::load(
|
||||
project,
|
||||
tools.clone(),
|
||||
prompt_store.clone(),
|
||||
prompt_builder.clone(),
|
||||
cx,
|
||||
)
|
||||
ThreadStore::load(project, tools.clone(), prompt_builder.clone(), cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
@@ -242,16 +229,7 @@ impl AssistantPanel {
|
||||
.await?;
|
||||
|
||||
workspace.update_in(cx, |workspace, window, cx| {
|
||||
cx.new(|cx| {
|
||||
Self::new(
|
||||
workspace,
|
||||
thread_store,
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -260,13 +238,11 @@ impl AssistantPanel {
|
||||
workspace: &Workspace,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
context_store: Entity<assistant_context_editor::ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
|
||||
let fs = workspace.app_state().fs.clone();
|
||||
let user_store = workspace.app_state().user_store.clone();
|
||||
let project = workspace.project();
|
||||
let language_registry = project.read(cx).languages().clone();
|
||||
let workspace = workspace.weak_handle();
|
||||
@@ -284,7 +260,6 @@ impl AssistantPanel {
|
||||
fs.clone(),
|
||||
workspace.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
prompt_store.clone(),
|
||||
thread_store.downgrade(),
|
||||
thread.clone(),
|
||||
window,
|
||||
@@ -316,6 +291,7 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
thread_store.clone(),
|
||||
language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -331,7 +307,6 @@ impl AssistantPanel {
|
||||
Self {
|
||||
active_view,
|
||||
workspace,
|
||||
user_store,
|
||||
project: project.clone(),
|
||||
fs: fs.clone(),
|
||||
language_registry,
|
||||
@@ -344,7 +319,6 @@ impl AssistantPanel {
|
||||
message_editor_subscription,
|
||||
],
|
||||
context_store,
|
||||
prompt_store,
|
||||
configuration: None,
|
||||
configuration_subscription: None,
|
||||
local_timezone: UtcOffset::from_whole_seconds(
|
||||
@@ -378,17 +352,18 @@ impl AssistantPanel {
|
||||
self.local_timezone
|
||||
}
|
||||
|
||||
pub(crate) fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
}
|
||||
|
||||
pub(crate) fn thread_store(&self) -> &Entity<ThreadStore> {
|
||||
&self.thread_store
|
||||
}
|
||||
|
||||
fn cancel(&mut self, _: &editor::actions::Cancel, window: &mut Window, cx: &mut Context<Self>) {
|
||||
fn cancel(
|
||||
&mut self,
|
||||
_: &editor::actions::Cancel,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.thread
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(window, cx));
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
|
||||
}
|
||||
|
||||
fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context<Self>) {
|
||||
@@ -438,6 +413,7 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
self.thread_store.clone(),
|
||||
self.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
self.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -456,7 +432,6 @@ impl AssistantPanel {
|
||||
self.fs.clone(),
|
||||
self.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
self.prompt_store.clone(),
|
||||
self.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -511,13 +486,13 @@ impl AssistantPanel {
|
||||
context_editor.focus_handle(cx).focus(window);
|
||||
}
|
||||
|
||||
fn deploy_rules_library(
|
||||
fn deploy_prompt_library(
|
||||
&mut self,
|
||||
action: &OpenRulesLibrary,
|
||||
action: &OpenPromptLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_rules_library(
|
||||
open_prompt_library(
|
||||
self.language_registry.clone(),
|
||||
Box::new(PromptLibraryInlineAssist::new(self.workspace.clone())),
|
||||
Arc::new(|| {
|
||||
@@ -527,9 +502,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
action.prompt_to_select.map(|uuid| PromptId::User {
|
||||
uuid: UserPromptId(uuid),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -625,6 +600,7 @@ impl AssistantPanel {
|
||||
thread.clone(),
|
||||
this.thread_store.clone(),
|
||||
this.language_registry.clone(),
|
||||
message_editor_context_store.clone(),
|
||||
this.workspace.clone(),
|
||||
window,
|
||||
cx,
|
||||
@@ -643,7 +619,6 @@ impl AssistantPanel {
|
||||
this.fs.clone(),
|
||||
this.workspace.clone(),
|
||||
message_editor_context_store,
|
||||
this.prompt_store.clone(),
|
||||
this.thread_store.downgrade(),
|
||||
thread,
|
||||
window,
|
||||
@@ -1147,7 +1122,7 @@ impl AssistantPanel {
|
||||
"New Text Thread",
|
||||
NewTextThread.boxed_clone(),
|
||||
)
|
||||
.action("Rules Library", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Prompt Library", Box::new(OpenPromptLibrary::default()))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.separator()
|
||||
.header("MCPs")
|
||||
@@ -1573,19 +1548,9 @@ impl AssistantPanel {
|
||||
}
|
||||
|
||||
fn render_usage_banner(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||
let plan = self
|
||||
.user_store
|
||||
.read(cx)
|
||||
.current_plan()
|
||||
.map(|plan| match plan {
|
||||
Plan::Free => zed_llm_client::Plan::Free,
|
||||
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
})
|
||||
.unwrap_or(zed_llm_client::Plan::Free);
|
||||
let usage = self.thread.read(cx).last_usage()?;
|
||||
|
||||
Some(UsageBanner::new(plan, usage).into_any_element())
|
||||
Some(UsageBanner::new(zed_llm_client::Plan::ZedProTrial, usage.amount).into_any_element())
|
||||
}
|
||||
|
||||
fn render_last_error(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
|
||||
@@ -1640,8 +1605,6 @@ impl AssistantPanel {
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(ERROR_MESSAGE))
|
||||
.child(Button::new("subscribe", "Subscribe").on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
@@ -1688,8 +1651,6 @@ impl AssistantPanel {
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(ERROR_MESSAGE))
|
||||
.child(
|
||||
Button::new("subscribe", "Update Monthly Spend Limit").on_click(
|
||||
cx.listener(|this, _, _, cx| {
|
||||
@@ -1755,8 +1716,6 @@ impl AssistantPanel {
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(error_message))
|
||||
.child(
|
||||
Button::new("subscribe", call_to_action).on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
@@ -1788,7 +1747,6 @@ impl AssistantPanel {
|
||||
message: SharedString,
|
||||
cx: &mut Context<Self>,
|
||||
) -> AnyElement {
|
||||
let message_with_header = format!("{}\n{}", header, message);
|
||||
v_flex()
|
||||
.gap_0p5()
|
||||
.child(
|
||||
@@ -1803,14 +1761,12 @@ impl AssistantPanel {
|
||||
.id("error-message")
|
||||
.max_h_32()
|
||||
.overflow_y_scroll()
|
||||
.child(Label::new(message.clone())),
|
||||
.child(Label::new(message)),
|
||||
)
|
||||
.child(
|
||||
h_flex()
|
||||
.justify_end()
|
||||
.mt_1()
|
||||
.gap_1()
|
||||
.child(self.create_copy_button(message_with_header))
|
||||
.child(Button::new("dismiss", "Dismiss").on_click(cx.listener(
|
||||
|this, _, _, cx| {
|
||||
this.thread.update(cx, |this, _cx| {
|
||||
@@ -1824,15 +1780,6 @@ impl AssistantPanel {
|
||||
.into_any()
|
||||
}
|
||||
|
||||
fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
|
||||
let message = message.into();
|
||||
IconButton::new("copy", IconName::Copy)
|
||||
.on_click(move |_, _, cx| {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
|
||||
})
|
||||
.tooltip(Tooltip::text("Copy Error Message"))
|
||||
}
|
||||
|
||||
fn key_context(&self) -> KeyContext {
|
||||
let mut key_context = KeyContext::new_with_defaults();
|
||||
key_context.add("AgentPanel");
|
||||
@@ -1860,7 +1807,7 @@ impl Render for AssistantPanel {
|
||||
this.open_configuration(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(Self::open_active_thread_as_markdown))
|
||||
.on_action(cx.listener(Self::deploy_rules_library))
|
||||
.on_action(cx.listener(Self::deploy_prompt_library))
|
||||
.on_action(cx.listener(Self::open_agent_diff))
|
||||
.on_action(cx.listener(Self::go_back))
|
||||
.child(self.render_toolbar(window, cx))
|
||||
@@ -1887,13 +1834,13 @@ impl PromptLibraryInlineAssist {
|
||||
}
|
||||
}
|
||||
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
_initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
let Some(project) = self
|
||||
@@ -1903,14 +1850,11 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let prompt_store = None;
|
||||
let thread_store = None;
|
||||
assistant.assist(
|
||||
&prompt_editor,
|
||||
self.workspace.clone(),
|
||||
project,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -1989,8 +1933,8 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
// being updated.
|
||||
cx.defer_in(window, move |panel, window, cx| {
|
||||
if panel.has_active_thread() {
|
||||
panel.message_editor.update(cx, |message_editor, cx| {
|
||||
message_editor.context_store().update(cx, |store, cx| {
|
||||
panel.thread.update(cx, |thread, cx| {
|
||||
thread.context_store().update(cx, |store, cx| {
|
||||
let buffer = buffer.read(cx);
|
||||
let selection_ranges = selection_ranges
|
||||
.into_iter()
|
||||
@@ -2007,7 +1951,9 @@ impl AssistantPanelDelegate for ConcreteAssistantPanelDelegate {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (buffer, range) in selection_ranges {
|
||||
store.add_selection(buffer, range, cx);
|
||||
store
|
||||
.add_selection(buffer, range, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
use crate::context::ContextLoadResult;
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::inline_prompt_editor::CodegenStatus;
|
||||
use crate::{context::load_context, context_store::ContextStore};
|
||||
use anyhow::Result;
|
||||
use client::telemetry::Telemetry;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
|
||||
};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
|
||||
use futures::{SinkExt, Stream, StreamExt, channel::mpsc, future::LocalBoxFuture, join};
|
||||
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
@@ -16,9 +14,7 @@ use language_model::{
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use parking_lot::Mutex;
|
||||
use project::Project;
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use rope::Rope;
|
||||
use smol::future::FutureExt;
|
||||
use std::{
|
||||
@@ -43,8 +39,6 @@ pub struct BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
pub is_insertion: bool,
|
||||
@@ -56,8 +50,6 @@ impl BufferCodegen {
|
||||
range: Range<Anchor>,
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
context_store: Entity<ContextStore>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -68,8 +60,6 @@ impl BufferCodegen {
|
||||
range.clone(),
|
||||
false,
|
||||
Some(context_store.clone()),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
Some(telemetry.clone()),
|
||||
builder.clone(),
|
||||
cx,
|
||||
@@ -85,8 +75,6 @@ impl BufferCodegen {
|
||||
range,
|
||||
initial_transaction_id,
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
builder,
|
||||
};
|
||||
@@ -165,8 +153,6 @@ impl BufferCodegen {
|
||||
self.range.clone(),
|
||||
false,
|
||||
Some(self.context_store.clone()),
|
||||
self.project.clone(),
|
||||
self.prompt_store.clone(),
|
||||
Some(self.telemetry.clone()),
|
||||
self.builder.clone(),
|
||||
cx,
|
||||
@@ -243,14 +229,13 @@ pub struct CodegenAlternative {
|
||||
generation: Task<()>,
|
||||
diff: Diff,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
_subscription: gpui::Subscription,
|
||||
builder: Arc<PromptBuilder>,
|
||||
active: bool,
|
||||
edits: Vec<(Range<Anchor>, String)>,
|
||||
line_operations: Vec<LineOperation>,
|
||||
request: Option<LanguageModelRequest>,
|
||||
elapsed_time: Option<f64>,
|
||||
completion: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
@@ -264,8 +249,6 @@ impl CodegenAlternative {
|
||||
range: Range<Anchor>,
|
||||
active: bool,
|
||||
context_store: Option<Entity<ContextStore>>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
telemetry: Option<Arc<Telemetry>>,
|
||||
builder: Arc<PromptBuilder>,
|
||||
cx: &mut Context<Self>,
|
||||
@@ -307,8 +290,6 @@ impl CodegenAlternative {
|
||||
generation: Task::ready(()),
|
||||
diff: Diff::default(),
|
||||
context_store,
|
||||
project,
|
||||
prompt_store,
|
||||
telemetry,
|
||||
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
|
||||
builder,
|
||||
@@ -316,6 +297,7 @@ impl CodegenAlternative {
|
||||
edits: Vec::new(),
|
||||
line_operations: Vec::new(),
|
||||
range,
|
||||
request: None,
|
||||
elapsed_time: None,
|
||||
completion: None,
|
||||
}
|
||||
@@ -384,18 +366,16 @@ impl CodegenAlternative {
|
||||
async { Ok(LanguageModelTextStream::default()) }.boxed_local()
|
||||
} else {
|
||||
let request = self.build_request(user_prompt, cx)?;
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
|
||||
self.request = Some(request.clone());
|
||||
|
||||
cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
|
||||
.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_request(
|
||||
&self,
|
||||
user_prompt: String,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
fn build_request(&self, user_prompt: String, cx: &mut App) -> Result<LanguageModelRequest> {
|
||||
let buffer = self.buffer.read(cx).snapshot(cx);
|
||||
let language = buffer.language_at(self.range.start);
|
||||
let language_name = if let Some(language) = language.as_ref() {
|
||||
@@ -428,44 +408,30 @@ impl CodegenAlternative {
|
||||
.generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
|
||||
|
||||
let context_task = self.context_store.as_ref().map(|context_store| {
|
||||
if let Some(project) = self.project.upgrade() {
|
||||
let context = context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
load_context(context, &project, &self.prompt_store, cx)
|
||||
} else {
|
||||
Task::ready(ContextLoadResult::default())
|
||||
}
|
||||
});
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
Ok(cx.spawn(async move |_cx| {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
if let Some(context_store) = &self.context_store {
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(context_task) = context_task {
|
||||
context_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
}
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
}
|
||||
}))
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
messages: vec![request_message],
|
||||
})
|
||||
}
|
||||
|
||||
pub fn handle_stream(
|
||||
@@ -542,9 +508,7 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks = StripInvalidSpans::new(
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
@@ -1070,7 +1034,6 @@ impl Diff {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use fs::FakeFs;
|
||||
use futures::{
|
||||
Stream,
|
||||
stream::{self},
|
||||
@@ -1113,16 +1076,12 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1181,16 +1140,12 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1252,16 +1207,12 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1323,16 +1274,12 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
true,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
@@ -1382,16 +1329,12 @@ mod tests {
|
||||
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
|
||||
});
|
||||
let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs, vec![], cx).await;
|
||||
let codegen = cx.new(|cx| {
|
||||
CodegenAlternative::new(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
false,
|
||||
None,
|
||||
project.downgrade(),
|
||||
None,
|
||||
None,
|
||||
prompt_builder,
|
||||
cx,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,11 +10,8 @@ use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
pub use completion_provider::ContextPickerCompletionProvider;
|
||||
use editor::display_map::{Crease, FoldId};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
|
||||
use fetch_context_picker::FetchContextPicker;
|
||||
use file_context_picker::FileContextPicker;
|
||||
use file_context_picker::render_file_context_entry;
|
||||
use gpui::{
|
||||
App, DismissEvent, Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task,
|
||||
@@ -23,10 +20,10 @@ use gpui::{
|
||||
use language::Buffer;
|
||||
use multi_buffer::MultiBufferRow;
|
||||
use project::{Entry, ProjectPath};
|
||||
use prompt_store::{PromptStore, UserPromptId};
|
||||
use rules_context_picker::{RulesContextEntry, RulesContextPicker};
|
||||
use prompt_store::UserPromptId;
|
||||
use rules_context_picker::RulesContextEntry;
|
||||
use symbol_context_picker::SymbolContextPicker;
|
||||
use thread_context_picker::{ThreadContextEntry, ThreadContextPicker, render_thread_context_entry};
|
||||
use thread_context_picker::{ThreadContextEntry, render_thread_context_entry};
|
||||
use ui::{
|
||||
ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*,
|
||||
};
|
||||
@@ -35,6 +32,11 @@ use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
|
||||
use crate::AssistantPanel;
|
||||
use crate::context::RULES_ICON;
|
||||
pub use crate::context_picker::completion_provider::ContextPickerCompletionProvider;
|
||||
use crate::context_picker::fetch_context_picker::FetchContextPicker;
|
||||
use crate::context_picker::file_context_picker::FileContextPicker;
|
||||
use crate::context_picker::rules_context_picker::RulesContextPicker;
|
||||
use crate::context_picker::thread_context_picker::ThreadContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::ThreadId;
|
||||
use crate::thread_store::ThreadStore;
|
||||
@@ -164,7 +166,6 @@ pub(super) struct ContextPicker {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -192,13 +193,6 @@ impl ContextPicker {
|
||||
)
|
||||
.collect::<Vec<Subscription>>();
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
ContextPicker {
|
||||
mode: ContextPickerState::Default(ContextMenu::build(
|
||||
window,
|
||||
@@ -208,7 +202,6 @@ impl ContextPicker {
|
||||
workspace,
|
||||
context_store,
|
||||
thread_store,
|
||||
prompt_store,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
@@ -233,12 +226,7 @@ impl ContextPicker {
|
||||
.workspace
|
||||
.upgrade()
|
||||
.map(|workspace| {
|
||||
available_context_picker_entries(
|
||||
&self.prompt_store,
|
||||
&self.thread_store,
|
||||
&workspace,
|
||||
cx,
|
||||
)
|
||||
available_context_picker_entries(&self.thread_store, &workspace, cx)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
@@ -316,10 +304,10 @@ impl ContextPicker {
|
||||
}));
|
||||
}
|
||||
ContextPickerMode::Rules => {
|
||||
if let Some(prompt_store) = self.prompt_store.as_ref() {
|
||||
if let Some(thread_store) = self.thread_store.as_ref() {
|
||||
self.mode = ContextPickerState::Rules(cx.new(|cx| {
|
||||
RulesContextPicker::new(
|
||||
prompt_store.clone(),
|
||||
thread_store.clone(),
|
||||
context_picker.clone(),
|
||||
self.context_store.clone(),
|
||||
window,
|
||||
@@ -538,7 +526,6 @@ enum RecentEntry {
|
||||
}
|
||||
|
||||
fn available_context_picker_entries(
|
||||
prompt_store: &Option<Entity<PromptStore>>,
|
||||
thread_store: &Option<WeakEntity<ThreadStore>>,
|
||||
workspace: &Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
@@ -563,9 +550,6 @@ fn available_context_picker_entries(
|
||||
|
||||
if thread_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Thread));
|
||||
}
|
||||
|
||||
if prompt_store.is_some() {
|
||||
entries.push(ContextPickerEntry::Mode(ContextPickerMode::Rules));
|
||||
}
|
||||
|
||||
@@ -601,21 +585,22 @@ fn recent_context_picker_entries(
|
||||
}),
|
||||
);
|
||||
|
||||
let current_threads = context_store.read(cx).thread_ids();
|
||||
let mut current_threads = context_store.read(cx).thread_ids();
|
||||
|
||||
let active_thread_id = workspace
|
||||
if let Some(active_thread) = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|panel| panel.read(cx).active_thread(cx).read(cx).id());
|
||||
.map(|panel| panel.read(cx).active_thread(cx))
|
||||
{
|
||||
current_threads.insert(active_thread.read(cx).id().clone());
|
||||
}
|
||||
|
||||
if let Some(thread_store) = thread_store.and_then(|thread_store| thread_store.upgrade()) {
|
||||
recent.extend(
|
||||
thread_store
|
||||
.read(cx)
|
||||
.reverse_chronological_threads()
|
||||
.threads()
|
||||
.into_iter()
|
||||
.filter(|thread| {
|
||||
Some(&thread.id) != active_thread_id && !current_threads.contains(&thread.id)
|
||||
})
|
||||
.filter(|thread| !current_threads.contains(&thread.id))
|
||||
.take(2)
|
||||
.map(|thread| {
|
||||
RecentEntry::Thread(ThreadContextEntry {
|
||||
@@ -637,7 +622,9 @@ fn add_selections_as_context(
|
||||
let selection_ranges = selection_ranges(workspace, cx);
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in selection_ranges {
|
||||
context_store.add_selection(buffer, range, cx);
|
||||
context_store
|
||||
.add_selection(buffer, range, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,21 +15,22 @@ use itertools::Itertools;
|
||||
use language::{Buffer, CodeLabel, HighlightId};
|
||||
use lsp::CompletionContext;
|
||||
use project::{Completion, CompletionIntent, ProjectPath, Symbol, WorktreeId};
|
||||
use prompt_store::PromptStore;
|
||||
use prompt_store::PromptId;
|
||||
use rope::Point;
|
||||
use text::{Anchor, OffsetRangeExt, ToPoint};
|
||||
use ui::prelude::*;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_picker::file_context_picker::search_files;
|
||||
use crate::context_picker::symbol_context_picker::search_symbols;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
use super::fetch_context_picker::fetch_url_content;
|
||||
use super::file_context_picker::{FileMatch, search_files};
|
||||
use super::file_context_picker::FileMatch;
|
||||
use super::rules_context_picker::{RulesContextEntry, search_rules};
|
||||
use super::symbol_context_picker::SymbolMatch;
|
||||
use super::symbol_context_picker::search_symbols;
|
||||
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
|
||||
use super::{
|
||||
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
|
||||
@@ -37,8 +38,8 @@ use super::{
|
||||
};
|
||||
|
||||
pub(crate) enum Match {
|
||||
File(FileMatch),
|
||||
Symbol(SymbolMatch),
|
||||
File(FileMatch),
|
||||
Thread(ThreadMatch),
|
||||
Fetch(SharedString),
|
||||
Rules(RulesContextEntry),
|
||||
@@ -68,7 +69,6 @@ fn search(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
recent_entries: Vec<RecentEntry>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
workspace: Entity<Workspace>,
|
||||
cx: &mut App,
|
||||
@@ -85,7 +85,6 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Symbol) => {
|
||||
let search_symbols_task =
|
||||
search_symbols(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
@@ -97,7 +96,6 @@ fn search(
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Thread) => {
|
||||
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
|
||||
let search_threads_task =
|
||||
@@ -113,7 +111,6 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Fetch) => {
|
||||
if !query.is_empty() {
|
||||
Task::ready(vec![Match::Fetch(query.into())])
|
||||
@@ -121,11 +118,10 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
Some(ContextPickerMode::Rules) => {
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) {
|
||||
let search_rules_task =
|
||||
search_rules(query.clone(), cancellation_flag.clone(), prompt_store, cx);
|
||||
search_rules(query.clone(), cancellation_flag.clone(), thread_store, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_rules_task
|
||||
.await
|
||||
@@ -137,7 +133,6 @@ fn search(
|
||||
Task::ready(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
if query.is_empty() {
|
||||
let mut matches = recent_entries
|
||||
@@ -168,7 +163,7 @@ fn search(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
matches.extend(
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx)
|
||||
available_context_picker_entries(&thread_store, &workspace, cx)
|
||||
.into_iter()
|
||||
.map(|mode| {
|
||||
Match::Entry(EntryMatch {
|
||||
@@ -185,8 +180,7 @@ fn search(
|
||||
let search_files_task =
|
||||
search_files(query.clone(), cancellation_flag.clone(), &workspace, cx);
|
||||
|
||||
let entries =
|
||||
available_context_picker_entries(&prompt_store, &thread_store, &workspace, cx);
|
||||
let entries = available_context_picker_entries(&thread_store, &workspace, cx);
|
||||
let entry_candidates = entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -313,11 +307,9 @@ impl ContextPickerCompletionProvider {
|
||||
move |_, _: &mut Window, cx: &mut App| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
for (buffer, range) in &selections {
|
||||
context_store.add_selection(
|
||||
buffer.clone(),
|
||||
range.clone(),
|
||||
cx,
|
||||
);
|
||||
context_store
|
||||
.add_selection(buffer.clone(), range.clone(), cx)
|
||||
.detach_and_log_err(cx)
|
||||
}
|
||||
});
|
||||
|
||||
@@ -445,6 +437,7 @@ impl ContextPickerCompletionProvider {
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
context_store: Entity<ContextStore>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
) -> Completion {
|
||||
let new_text = MentionLink::for_rules(&rules);
|
||||
let new_text_len = new_text.len();
|
||||
@@ -464,10 +457,29 @@ impl ContextPickerCompletionProvider {
|
||||
new_text_len,
|
||||
editor.clone(),
|
||||
move |cx| {
|
||||
let user_prompt_id = rules.prompt_id;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(user_prompt_id, false, cx);
|
||||
});
|
||||
let prompt_uuid = rules.prompt_id;
|
||||
let prompt_id = PromptId::User { uuid: prompt_uuid };
|
||||
let context_store = context_store.clone();
|
||||
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
|
||||
log::error!("Can't add user rules as prompt store is missing.");
|
||||
return;
|
||||
};
|
||||
let prompt_store = prompt_store.read(cx);
|
||||
let Some(metadata) = prompt_store.metadata(prompt_id) else {
|
||||
return;
|
||||
};
|
||||
let Some(title) = metadata.title else {
|
||||
return;
|
||||
};
|
||||
let text_task = prompt_store.load(prompt_id, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let text = text_task.await?;
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(prompt_uuid, title, text, false, cx)
|
||||
})
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
},
|
||||
)),
|
||||
}
|
||||
@@ -504,7 +516,7 @@ impl ContextPickerCompletionProvider {
|
||||
let url_to_fetch = url_to_fetch.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
if context_store.update(cx, |context_store, _| {
|
||||
context_store.includes_url(&url_to_fetch)
|
||||
context_store.includes_url(&url_to_fetch).is_some()
|
||||
})? {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -580,7 +592,7 @@ impl ContextPickerCompletionProvider {
|
||||
move |cx| {
|
||||
context_store.update(cx, |context_store, cx| {
|
||||
let task = if is_directory {
|
||||
Task::ready(context_store.add_directory(&project_path, false, cx))
|
||||
context_store.add_directory(project_path.clone(), false, cx)
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path.clone(), false, cx)
|
||||
};
|
||||
@@ -720,19 +732,11 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
);
|
||||
|
||||
let prompt_store = thread_store.as_ref().and_then(|thread_store| {
|
||||
thread_store
|
||||
.read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
});
|
||||
|
||||
let search_task = search(
|
||||
mode,
|
||||
query,
|
||||
Arc::<AtomicBool>::default(),
|
||||
recent_entries,
|
||||
prompt_store,
|
||||
thread_store.clone(),
|
||||
workspace.clone(),
|
||||
cx,
|
||||
@@ -764,7 +768,6 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
cx,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Symbol(SymbolMatch { symbol, .. }) => Self::completion_for_symbol(
|
||||
symbol,
|
||||
excerpt_id,
|
||||
@@ -774,7 +777,6 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
workspace.clone(),
|
||||
cx,
|
||||
),
|
||||
|
||||
Match::Thread(ThreadMatch {
|
||||
thread, is_recent, ..
|
||||
}) => {
|
||||
@@ -789,15 +791,17 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
thread_store,
|
||||
))
|
||||
}
|
||||
|
||||
Match::Rules(user_rules) => Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
)),
|
||||
|
||||
Match::Rules(user_rules) => {
|
||||
let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?;
|
||||
Some(Self::completion_for_rules(
|
||||
user_rules,
|
||||
excerpt_id,
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
context_store.clone(),
|
||||
thread_store,
|
||||
))
|
||||
}
|
||||
Match::Fetch(url) => Some(Self::completion_for_fetch(
|
||||
source_range.clone(),
|
||||
url,
|
||||
@@ -806,7 +810,6 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
context_store.clone(),
|
||||
http_client.clone(),
|
||||
)),
|
||||
|
||||
Match::Entry(EntryMatch { entry, .. }) => Self::completion_for_entry(
|
||||
entry,
|
||||
excerpt_id,
|
||||
|
||||
@@ -227,7 +227,7 @@ impl PickerDelegate for FetchContextPickerDelegate {
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
let added = self.context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store.read(cx).includes_url(&self.url)
|
||||
context_store.read(cx).includes_url(&self.url).is_some()
|
||||
});
|
||||
|
||||
Some(
|
||||
|
||||
@@ -134,9 +134,9 @@ impl PickerDelegate for FileContextPickerDelegate {
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
if is_directory {
|
||||
Task::ready(context_store.add_directory(&project_path, true, cx))
|
||||
context_store.add_directory(project_path, true, cx)
|
||||
} else {
|
||||
context_store.add_file_from_path(project_path.clone(), true, cx)
|
||||
context_store.add_file_from_path(project_path, true, cx)
|
||||
}
|
||||
})
|
||||
.ok()
|
||||
@@ -325,11 +325,11 @@ pub fn render_file_context_entry(
|
||||
path: path.clone(),
|
||||
};
|
||||
if is_directory {
|
||||
context_store.read(cx).includes_directory(&project_path)
|
||||
} else {
|
||||
context_store
|
||||
.read(cx)
|
||||
.path_included_in_directory(&project_path, cx)
|
||||
} else {
|
||||
context_store.read(cx).file_path_included(&project_path, cx)
|
||||
.will_include_file_path(&project_path, cx)
|
||||
}
|
||||
});
|
||||
|
||||
@@ -357,7 +357,7 @@ pub fn render_file_context_entry(
|
||||
})),
|
||||
)
|
||||
.when_some(added, |el, added| match added {
|
||||
FileInclusion::Direct => el.child(
|
||||
FileInclusion::Direct(_) => el.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.justify_end()
|
||||
@@ -369,8 +369,9 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Added").size(LabelSize::Small)),
|
||||
),
|
||||
FileInclusion::InDirectory { full_path } => {
|
||||
let directory_full_path = full_path.to_string_lossy().into_owned();
|
||||
FileInclusion::InDirectory(directory_project_path) => {
|
||||
// TODO: Consider using worktree full_path to include worktree name.
|
||||
let directory_path = directory_project_path.path.to_string_lossy().into_owned();
|
||||
|
||||
el.child(
|
||||
h_flex()
|
||||
@@ -384,7 +385,7 @@ pub fn render_file_context_entry(
|
||||
)
|
||||
.child(Label::new("Included").size(LabelSize::Small)),
|
||||
)
|
||||
.tooltip(Tooltip::text(format!("in {directory_full_path}")))
|
||||
.tooltip(Tooltip::text(format!("in {directory_path}")))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity};
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use prompt_store::{PromptId, PromptStore, UserPromptId};
|
||||
use prompt_store::{PromptId, UserPromptId};
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::context::RULES_ICON;
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::{self, ContextStore};
|
||||
use crate::thread_store::ThreadStore;
|
||||
|
||||
pub struct RulesContextPicker {
|
||||
picker: Entity<Picker<RulesContextPickerDelegate>>,
|
||||
@@ -17,13 +18,13 @@ pub struct RulesContextPicker {
|
||||
|
||||
impl RulesContextPicker {
|
||||
pub fn new(
|
||||
prompt_store: Entity<PromptStore>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let delegate = RulesContextPickerDelegate::new(prompt_store, context_picker, context_store);
|
||||
let delegate = RulesContextPickerDelegate::new(thread_store, context_picker, context_store);
|
||||
let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx));
|
||||
|
||||
RulesContextPicker { picker }
|
||||
@@ -49,7 +50,7 @@ pub struct RulesContextEntry {
|
||||
}
|
||||
|
||||
pub struct RulesContextPickerDelegate {
|
||||
prompt_store: Entity<PromptStore>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
matches: Vec<RulesContextEntry>,
|
||||
@@ -58,12 +59,12 @@ pub struct RulesContextPickerDelegate {
|
||||
|
||||
impl RulesContextPickerDelegate {
|
||||
pub fn new(
|
||||
prompt_store: Entity<PromptStore>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
context_picker: WeakEntity<ContextPicker>,
|
||||
context_store: WeakEntity<context_store::ContextStore>,
|
||||
) -> Self {
|
||||
RulesContextPickerDelegate {
|
||||
prompt_store,
|
||||
thread_store,
|
||||
context_picker,
|
||||
context_store,
|
||||
matches: Vec::new(),
|
||||
@@ -102,12 +103,11 @@ impl PickerDelegate for RulesContextPickerDelegate {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
let search_task = search_rules(
|
||||
query,
|
||||
Arc::new(AtomicBool::default()),
|
||||
&self.prompt_store,
|
||||
cx,
|
||||
);
|
||||
let Some(thread_store) = self.thread_store.upgrade() else {
|
||||
return Task::ready(());
|
||||
};
|
||||
|
||||
let search_task = search_rules(query, Arc::new(AtomicBool::default()), thread_store, cx);
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let matches = search_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
@@ -124,11 +124,31 @@ impl PickerDelegate for RulesContextPickerDelegate {
|
||||
return;
|
||||
};
|
||||
|
||||
self.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(entry.prompt_id, true, cx)
|
||||
let Some(thread_store) = self.thread_store.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
let prompt_id = entry.prompt_id;
|
||||
|
||||
let load_rules_task = thread_store.update(cx, |thread_store, cx| {
|
||||
thread_store.load_rules(prompt_id, cx)
|
||||
});
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (metadata, text) = load_rules_task.await?;
|
||||
let Some(title) = metadata.title else {
|
||||
return Err(anyhow!("Encountered user rule with no title when attempting to add it to agent context."));
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.delegate
|
||||
.context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
context_store.add_rules(prompt_id, title, text, true, cx)
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
@@ -159,10 +179,11 @@ pub fn render_thread_context_entry(
|
||||
context_store: WeakEntity<ContextStore>,
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |context_store| {
|
||||
context_store
|
||||
let added = context_store.upgrade().map_or(false, |ctx_store| {
|
||||
ctx_store
|
||||
.read(cx)
|
||||
.includes_user_rules(user_rules.prompt_id)
|
||||
.includes_user_rules(&user_rules.prompt_id)
|
||||
.is_some()
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -197,9 +218,12 @@ pub fn render_thread_context_entry(
|
||||
pub(crate) fn search_rules(
|
||||
query: String,
|
||||
cancellation_flag: Arc<AtomicBool>,
|
||||
prompt_store: &Entity<PromptStore>,
|
||||
thread_store: Entity<ThreadStore>,
|
||||
cx: &mut App,
|
||||
) -> Task<Vec<RulesContextEntry>> {
|
||||
let Some(prompt_store) = thread_store.read(cx).prompt_store() else {
|
||||
return Task::ready(vec![]);
|
||||
};
|
||||
let search_task = prompt_store.read(cx).search(query, cancellation_flag, cx);
|
||||
cx.background_spawn(async move {
|
||||
search_task
|
||||
|
||||
@@ -10,6 +10,7 @@ use gpui::{
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use project::{DocumentSymbol, Symbol};
|
||||
use text::OffsetRangeExt;
|
||||
use ui::{ListItem, prelude::*};
|
||||
use util::ResultExt as _;
|
||||
use workspace::Workspace;
|
||||
@@ -227,16 +228,18 @@ pub(crate) fn add_symbol(
|
||||
)
|
||||
})?;
|
||||
|
||||
context_store.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
context_store
|
||||
.update(cx, move |context_store, cx| {
|
||||
context_store.add_symbol(
|
||||
buffer,
|
||||
name.into(),
|
||||
range,
|
||||
enclosing_range,
|
||||
remove_if_exists,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
||||
@@ -350,13 +353,38 @@ fn compute_symbol_entries(
|
||||
context_store: &ContextStore,
|
||||
cx: &App,
|
||||
) -> Vec<SymbolEntry> {
|
||||
symbols
|
||||
.into_iter()
|
||||
.map(|SymbolMatch { symbol, .. }| SymbolEntry {
|
||||
is_included: context_store.includes_symbol(&symbol, cx),
|
||||
let mut symbol_entries = Vec::with_capacity(symbols.len());
|
||||
for SymbolMatch { symbol, .. } in symbols {
|
||||
let symbols_for_path = context_store.included_symbols_by_path().get(&symbol.path);
|
||||
let is_included = if let Some(symbols_for_path) = symbols_for_path {
|
||||
let mut is_included = false;
|
||||
for included_symbol_id in symbols_for_path {
|
||||
if included_symbol_id.name.as_ref() == symbol.name.as_str() {
|
||||
if let Some(buffer) = context_store.buffer_for_symbol(included_symbol_id) {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let included_symbol_range =
|
||||
included_symbol_id.range.to_point_utf16(&snapshot);
|
||||
|
||||
if included_symbol_range.start == symbol.range.start.0
|
||||
&& included_symbol_range.end == symbol.range.end.0
|
||||
{
|
||||
is_included = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
is_included
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
symbol_entries.push(SymbolEntry {
|
||||
symbol,
|
||||
is_included,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
symbol_entries
|
||||
}
|
||||
|
||||
pub fn render_symbol_context_entry(id: ElementId, entry: &SymbolEntry) -> Stateful<Div> {
|
||||
|
||||
@@ -173,7 +173,7 @@ pub fn render_thread_context_entry(
|
||||
cx: &mut App,
|
||||
) -> Div {
|
||||
let added = context_store.upgrade().map_or(false, |ctx_store| {
|
||||
ctx_store.read(cx).includes_thread(&thread.id)
|
||||
ctx_store.read(cx).includes_thread(&thread.id).is_some()
|
||||
});
|
||||
|
||||
h_flex()
|
||||
@@ -219,7 +219,7 @@ pub(crate) fn search_threads(
|
||||
) -> Task<Vec<ThreadMatch>> {
|
||||
let threads = thread_store
|
||||
.read(cx)
|
||||
.reverse_chronological_threads()
|
||||
.threads()
|
||||
.into_iter()
|
||||
.map(|thread| ThreadContextEntry {
|
||||
id: thread.id,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,9 +12,9 @@ use itertools::Itertools;
|
||||
use language::Buffer;
|
||||
use project::ProjectItem;
|
||||
use ui::{KeyBinding, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*};
|
||||
use workspace::Workspace;
|
||||
use workspace::{Workspace, notifications::NotifyResultExt};
|
||||
|
||||
use crate::context::{AgentContext, ContextKind};
|
||||
use crate::context::{ContextId, ContextKind};
|
||||
use crate::context_picker::ContextPicker;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::thread::Thread;
|
||||
@@ -32,7 +32,6 @@ pub struct ContextStrip {
|
||||
focus_handle: FocusHandle,
|
||||
suggest_context_kind: SuggestContextKind,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
focused_index: Option<usize>,
|
||||
children_bounds: Option<Vec<Bounds<Pixels>>>,
|
||||
@@ -74,31 +73,12 @@ impl ContextStrip {
|
||||
focus_handle,
|
||||
suggest_context_kind,
|
||||
workspace,
|
||||
thread_store,
|
||||
_subscriptions: subscriptions,
|
||||
focused_index: None,
|
||||
children_bounds: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn added_contexts(&self, cx: &App) -> Vec<AddedContext> {
|
||||
if let Some(workspace) = self.workspace.upgrade() {
|
||||
let project = workspace.read(cx).project().read(cx);
|
||||
let prompt_store = self
|
||||
.thread_store
|
||||
.as_ref()
|
||||
.and_then(|thread_store| thread_store.upgrade())
|
||||
.and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref());
|
||||
self.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.flat_map(|context| AddedContext::new(context.clone(), prompt_store, project, cx))
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn suggested_context(&self, cx: &Context<Self>) -> Option<SuggestedContext> {
|
||||
match self.suggest_context_kind {
|
||||
SuggestContextKind::File => self.suggested_file(cx),
|
||||
@@ -113,19 +93,22 @@ impl ContextStrip {
|
||||
let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
|
||||
let active_buffer_entity = editor.buffer().read(cx).as_singleton()?;
|
||||
let active_buffer = active_buffer_entity.read(cx);
|
||||
|
||||
let project_path = active_buffer.project_path(cx)?;
|
||||
|
||||
if self
|
||||
.context_store
|
||||
.read(cx)
|
||||
.file_path_included(&project_path, cx)
|
||||
.will_include_buffer(active_buffer.remote_id(), &project_path)
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let file_name = active_buffer.file()?.file_name(cx);
|
||||
|
||||
let icon_path = FileIcons::get_icon(&Path::new(&file_name), cx);
|
||||
|
||||
Some(SuggestedContext::File {
|
||||
name: file_name.to_string_lossy().into_owned().into(),
|
||||
buffer: active_buffer_entity.downgrade(),
|
||||
@@ -152,6 +135,7 @@ impl ContextStrip {
|
||||
.context_store
|
||||
.read(cx)
|
||||
.includes_thread(active_thread.id())
|
||||
.is_some()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
@@ -288,12 +272,12 @@ impl ContextStrip {
|
||||
best.map(|(index, _, _)| index)
|
||||
}
|
||||
|
||||
fn open_context(&mut self, context: &AgentContext, window: &mut Window, cx: &mut App) {
|
||||
fn open_context(&mut self, id: ContextId, window: &mut Window, cx: &mut App) {
|
||||
let Some(workspace) = self.workspace.upgrade() else {
|
||||
return;
|
||||
};
|
||||
|
||||
crate::active_thread::open_context(context, workspace, window, cx);
|
||||
crate::active_thread::open_context(id, self.context_store.clone(), workspace, window, cx);
|
||||
}
|
||||
|
||||
fn remove_focused_context(
|
||||
@@ -303,17 +287,17 @@ impl ContextStrip {
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(index) = self.focused_index {
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let Some(context) = added_contexts.get(index) else {
|
||||
return;
|
||||
};
|
||||
let mut is_empty = false;
|
||||
|
||||
self.context_store.update(cx, |this, cx| {
|
||||
this.remove_context(&context.context, cx);
|
||||
if let Some(item) = this.context().get(index) {
|
||||
this.remove_context(item.id(), cx);
|
||||
}
|
||||
|
||||
is_empty = this.context().is_empty();
|
||||
});
|
||||
|
||||
let is_now_empty = added_contexts.len() == 1;
|
||||
if is_now_empty {
|
||||
if is_empty {
|
||||
cx.emit(ContextStripEvent::BlurredEmpty);
|
||||
} else {
|
||||
self.focused_index = Some(index.saturating_sub(1));
|
||||
@@ -322,28 +306,49 @@ impl ContextStrip {
|
||||
}
|
||||
}
|
||||
|
||||
fn is_suggested_focused(&self, added_contexts: &Vec<AddedContext>) -> bool {
|
||||
fn is_suggested_focused<T>(&self, context: &Vec<T>) -> bool {
|
||||
// We only suggest one item after the actual context
|
||||
self.focused_index == Some(added_contexts.len())
|
||||
self.focused_index == Some(context.len())
|
||||
}
|
||||
|
||||
fn accept_suggested_context(
|
||||
&mut self,
|
||||
_: &AcceptSuggestedContext,
|
||||
_window: &mut Window,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
if let Some(suggested) = self.suggested_context(cx) {
|
||||
if self.is_suggested_focused(&self.added_contexts(cx)) {
|
||||
self.add_suggested_context(&suggested, cx);
|
||||
let context_store = self.context_store.read(cx);
|
||||
|
||||
if self.is_suggested_focused(context_store.context()) {
|
||||
self.add_suggested_context(&suggested, window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_suggested_context(&mut self, suggested: &SuggestedContext, cx: &mut Context<Self>) {
|
||||
self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.add_suggested_context(&suggested, cx)
|
||||
fn add_suggested_context(
|
||||
&mut self,
|
||||
suggested: &SuggestedContext,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let task = self.context_store.update(cx, |context_store, cx| {
|
||||
context_store.accept_suggested_context(&suggested, cx)
|
||||
});
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
match task.await.notify_async_err(cx) {
|
||||
None => {}
|
||||
Some(()) => {
|
||||
if let Some(this) = this.upgrade() {
|
||||
this.update(cx, |_, cx| cx.notify())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
@@ -356,10 +361,17 @@ impl Focusable for ContextStrip {
|
||||
|
||||
impl Render for ContextStrip {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let context_store = self.context_store.read(cx);
|
||||
let context = context_store.context();
|
||||
let context_picker = self.context_picker.clone();
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
let added_contexts = self.added_contexts(cx);
|
||||
let suggested_context = self.suggested_context(cx);
|
||||
|
||||
let added_contexts = context
|
||||
.iter()
|
||||
.map(|c| AddedContext::new(c, cx))
|
||||
.collect::<Vec<_>>();
|
||||
let dupe_names = added_contexts
|
||||
.iter()
|
||||
.map(|c| c.name.clone())
|
||||
@@ -368,14 +380,6 @@ impl Render for ContextStrip {
|
||||
.filter(|(a, b)| a == b)
|
||||
.map(|(a, _)| a)
|
||||
.collect::<HashSet<SharedString>>();
|
||||
let no_added_context = added_contexts.is_empty();
|
||||
|
||||
let suggested_context = self.suggested_context(cx).map(|suggested_context| {
|
||||
(
|
||||
suggested_context,
|
||||
self.is_suggested_focused(&added_contexts),
|
||||
)
|
||||
});
|
||||
|
||||
h_flex()
|
||||
.flex_wrap()
|
||||
@@ -432,7 +436,7 @@ impl Render for ContextStrip {
|
||||
})
|
||||
.with_handle(self.context_picker_menu_handle.clone()),
|
||||
)
|
||||
.when(no_added_context && suggested_context.is_none(), {
|
||||
.when(context.is_empty() && suggested_context.is_none(), {
|
||||
|parent| {
|
||||
parent.child(
|
||||
h_flex()
|
||||
@@ -462,17 +466,16 @@ impl Render for ContextStrip {
|
||||
.enumerate()
|
||||
.map(|(i, added_context)| {
|
||||
let name = added_context.name.clone();
|
||||
let context = added_context.context.clone();
|
||||
let id = added_context.id;
|
||||
ContextPill::added(
|
||||
added_context,
|
||||
dupe_names.contains(&name),
|
||||
self.focused_index == Some(i),
|
||||
Some({
|
||||
let context = context.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
Rc::new(cx.listener(move |_this, _event, _window, cx| {
|
||||
context_store.update(cx, |this, cx| {
|
||||
this.remove_context(&context, cx);
|
||||
this.remove_context(id, cx);
|
||||
});
|
||||
cx.notify();
|
||||
}))
|
||||
@@ -481,7 +484,7 @@ impl Render for ContextStrip {
|
||||
.on_click({
|
||||
Rc::new(cx.listener(move |this, event: &ClickEvent, window, cx| {
|
||||
if event.down.click_count > 1 {
|
||||
this.open_context(&context, window, cx);
|
||||
this.open_context(id, window, cx);
|
||||
} else {
|
||||
this.focused_index = Some(i);
|
||||
}
|
||||
@@ -490,22 +493,22 @@ impl Render for ContextStrip {
|
||||
})
|
||||
}),
|
||||
)
|
||||
.when_some(suggested_context, |el, (suggested, focused)| {
|
||||
.when_some(suggested_context, |el, suggested| {
|
||||
el.child(
|
||||
ContextPill::suggested(
|
||||
suggested.name().clone(),
|
||||
suggested.icon_path(),
|
||||
suggested.kind(),
|
||||
focused,
|
||||
self.is_suggested_focused(&context),
|
||||
)
|
||||
.on_click(Rc::new(cx.listener(
|
||||
move |this, _event, _window, cx| {
|
||||
this.add_suggested_context(&suggested, cx);
|
||||
move |this, _event, window, cx| {
|
||||
this.add_suggested_context(&suggested, window, cx);
|
||||
},
|
||||
))),
|
||||
)
|
||||
})
|
||||
.when(!no_added_context, {
|
||||
.when(!context.is_empty(), {
|
||||
move |parent| {
|
||||
parent.child(
|
||||
IconButton::new("remove-all-context", IconName::Eraser)
|
||||
@@ -531,7 +534,6 @@ impl Render for ContextStrip {
|
||||
)
|
||||
}
|
||||
})
|
||||
.into_any()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,10 +51,7 @@ impl HistoryStore {
|
||||
return history_entries;
|
||||
}
|
||||
|
||||
for thread in self
|
||||
.thread_store
|
||||
.update(cx, |this, _cx| this.reverse_chronological_threads())
|
||||
{
|
||||
for thread in self.thread_store.update(cx, |this, _cx| this.threads()) {
|
||||
history_entries.push(HistoryEntry::Thread(thread));
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ use project::LspAction;
|
||||
use project::Project;
|
||||
use project::{CodeAction, ProjectTransaction};
|
||||
use prompt_store::PromptBuilder;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::{Settings, SettingsStore};
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::{TerminalView, terminal_panel::TerminalPanel};
|
||||
@@ -246,13 +245,9 @@ impl InlineAssistant {
|
||||
.map_or(false, |model| model.provider.is_authenticated(cx))
|
||||
};
|
||||
|
||||
let assistant_panel = workspace
|
||||
let thread_store = workspace
|
||||
.panel::<AssistantPanel>(cx)
|
||||
.map(|assistant_panel| assistant_panel.read(cx));
|
||||
let prompt_store = assistant_panel
|
||||
.and_then(|assistant_panel| assistant_panel.prompt_store().as_ref().cloned());
|
||||
let thread_store =
|
||||
assistant_panel.map(|assistant_panel| assistant_panel.thread_store().downgrade());
|
||||
.map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
|
||||
|
||||
let handle_assist =
|
||||
|window: &mut Window, cx: &mut Context<Workspace>| match inline_assist_target {
|
||||
@@ -262,7 +257,6 @@ impl InlineAssistant {
|
||||
&active_editor,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -275,7 +269,6 @@ impl InlineAssistant {
|
||||
&active_terminal,
|
||||
cx.entity().downgrade(),
|
||||
workspace.project().downgrade(),
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
@@ -330,7 +323,6 @@ impl InlineAssistant {
|
||||
editor: &Entity<Editor>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -445,8 +437,6 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
None,
|
||||
context_store.clone(),
|
||||
project.clone(),
|
||||
prompt_store.clone(),
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -535,7 +525,6 @@ impl InlineAssistant {
|
||||
initial_transaction_id: Option<TransactionId>,
|
||||
focus: bool,
|
||||
workspace: Entity<Workspace>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -554,7 +543,7 @@ impl InlineAssistant {
|
||||
}
|
||||
|
||||
let project = workspace.read(cx).project().downgrade();
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone()));
|
||||
let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
|
||||
|
||||
let codegen = cx.new(|cx| {
|
||||
BufferCodegen::new(
|
||||
@@ -562,8 +551,6 @@ impl InlineAssistant {
|
||||
range.clone(),
|
||||
initial_transaction_id,
|
||||
context_store.clone(),
|
||||
project,
|
||||
prompt_store,
|
||||
self.telemetry.clone(),
|
||||
self.prompt_builder.clone(),
|
||||
cx,
|
||||
@@ -1341,7 +1328,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<InlineAssist>(
|
||||
row_range,
|
||||
cx.theme().status().info_background,
|
||||
Default::default(),
|
||||
false,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
@@ -1406,7 +1393,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<DeletedLines>(
|
||||
Anchor::min()..Anchor::max(),
|
||||
cx.theme().status().deleted_background,
|
||||
Default::default(),
|
||||
false,
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
@@ -1802,7 +1789,6 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
let editor = self.editor.clone();
|
||||
let workspace = self.workspace.clone();
|
||||
let thread_store = self.thread_store.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
window.spawn(cx, async move |cx| {
|
||||
let workspace = workspace.upgrade().context("workspace was released")?;
|
||||
let editor = editor.upgrade().context("editor was released")?;
|
||||
@@ -1843,7 +1829,6 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
})?
|
||||
.context("invalid range")?;
|
||||
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
|
||||
let assist_id = assistant.suggest_assist(
|
||||
&editor,
|
||||
@@ -1852,7 +1837,6 @@ impl CodeActionProvider for AssistantCodeActionProvider {
|
||||
None,
|
||||
true,
|
||||
workspace,
|
||||
prompt_store,
|
||||
thread_store,
|
||||
window,
|
||||
cx,
|
||||
|
||||
@@ -13,7 +13,7 @@ use editor::{
|
||||
Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer,
|
||||
actions::{MoveDown, MoveUp},
|
||||
};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
|
||||
use feature_flags::{FeatureFlagAppExt as _, ZedPro};
|
||||
use fs::Fs;
|
||||
use gpui::{
|
||||
AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
|
||||
@@ -132,7 +132,7 @@ impl<T: 'static> Render for PromptEditor<T> {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -931,7 +931,7 @@ impl PromptEditor<BufferCodegen> {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
if cx.has_flag::<ZedPro>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::assistant_model_selector::ModelType;
|
||||
use crate::context::{ContextLoadResult, load_context};
|
||||
use crate::context::{AssistantContext, format_context_as_string};
|
||||
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
|
||||
use buffer_diff::BufferDiff;
|
||||
use collections::HashSet;
|
||||
@@ -13,8 +13,6 @@ use editor::{
|
||||
};
|
||||
use file_icons::FileIcons;
|
||||
use fs::Fs;
|
||||
use futures::future::Shared;
|
||||
use futures::{FutureExt as _, future};
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, ClipboardEntry, Entity, EventEmitter, Focusable, Subscription,
|
||||
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
|
||||
@@ -24,7 +22,6 @@ use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelReques
|
||||
use language_model_selector::ToggleModelSelector;
|
||||
use multi_buffer;
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use settings::Settings;
|
||||
use std::time::Duration;
|
||||
use theme::ThemeSettings;
|
||||
@@ -34,7 +31,7 @@ use workspace::Workspace;
|
||||
|
||||
use crate::assistant_model_selector::AssistantModelSelector;
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::context_store::{ContextStore, refresh_context_store_text};
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
use crate::profile_selector::ProfileSelector;
|
||||
use crate::thread::{Thread, TokenUsageRatio};
|
||||
@@ -52,12 +49,9 @@ pub struct MessageEditor {
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
context_strip: Entity<ContextStrip>,
|
||||
context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
|
||||
model_selector: Entity<AssistantModelSelector>,
|
||||
last_loaded_context: Option<ContextLoadResult>,
|
||||
context_load_task: Option<Shared<Task<()>>>,
|
||||
profile_selector: Entity<ProfileSelector>,
|
||||
edits_expanded: bool,
|
||||
editor_is_expanded: bool,
|
||||
@@ -74,7 +68,6 @@ impl MessageEditor {
|
||||
fs: Arc<dyn Fs>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: WeakEntity<ThreadStore>,
|
||||
thread: Entity<Thread>,
|
||||
window: &mut Window,
|
||||
@@ -142,11 +135,13 @@ impl MessageEditor {
|
||||
let subscriptions = vec![
|
||||
cx.subscribe_in(&context_strip, window, Self::handle_context_strip_event),
|
||||
cx.subscribe(&editor, |this, _, event, cx| match event {
|
||||
EditorEvent::BufferEdited => this.handle_message_changed(cx),
|
||||
EditorEvent::BufferEdited => {
|
||||
this.message_or_context_changed(true, cx);
|
||||
}
|
||||
_ => {}
|
||||
}),
|
||||
cx.observe(&context_store, |this, _, cx| {
|
||||
let _ = this.start_context_load(cx);
|
||||
this.message_or_context_changed(false, cx);
|
||||
}),
|
||||
];
|
||||
|
||||
@@ -157,11 +152,8 @@ impl MessageEditor {
|
||||
incompatible_tools_state: incompatible_tools.clone(),
|
||||
workspace,
|
||||
context_store,
|
||||
prompt_store,
|
||||
context_strip,
|
||||
context_picker_menu_handle,
|
||||
context_load_task: None,
|
||||
last_loaded_context: None,
|
||||
model_selector: cx.new(|cx| {
|
||||
AssistantModelSelector::new(
|
||||
fs.clone(),
|
||||
@@ -183,10 +175,6 @@ impl MessageEditor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context_store(&self) -> &Entity<ContextStore> {
|
||||
&self.context_store
|
||||
}
|
||||
|
||||
fn toggle_chat_mode(&mut self, _: &ChatMode, _window: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.notify();
|
||||
}
|
||||
@@ -207,7 +195,6 @@ impl MessageEditor {
|
||||
editor.set_mode(EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: false,
|
||||
})
|
||||
} else {
|
||||
editor.set_mode(EditorMode::AutoHeight {
|
||||
@@ -226,7 +213,6 @@ impl MessageEditor {
|
||||
) {
|
||||
self.context_picker_menu_handle.toggle(window, cx);
|
||||
}
|
||||
|
||||
pub fn remove_all_context(
|
||||
&mut self,
|
||||
_: &RemoveAllContext,
|
||||
@@ -283,44 +269,56 @@ impl MessageEditor {
|
||||
self.last_estimated_token_count.take();
|
||||
cx.emit(MessageEditorEvent::EstimatedTokenCount);
|
||||
|
||||
let refresh_task =
|
||||
refresh_context_store_text(self.context_store.clone(), &HashSet::default(), cx);
|
||||
let wait_for_images = self.context_store.read(cx).wait_for_images(cx);
|
||||
|
||||
let thread = self.thread.clone();
|
||||
let context_store = self.context_store.clone();
|
||||
let git_store = self.project.read(cx).git_store().clone();
|
||||
let checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
|
||||
let context_task = self.load_context(cx);
|
||||
let window_handle = window.window_handle();
|
||||
|
||||
cx.spawn(async move |_this, cx| {
|
||||
let (checkpoint, loaded_context) = future::join(checkpoint, context_task).await;
|
||||
let loaded_context = loaded_context.unwrap_or_default();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.insert_user_message(user_message, loaded_context, checkpoint.ok(), cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
thread.send_to_model(model, Some(window_handle), cx);
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn wait_for_summaries(&mut self, cx: &mut Context<Self>) -> Task<()> {
|
||||
let context_store = self.context_store.clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
let checkpoint = checkpoint.await.ok();
|
||||
refresh_task.await;
|
||||
wait_for_images.await;
|
||||
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
let context = context_store.read(cx).context().clone();
|
||||
thread.insert_user_message(user_message, context, checkpoint, cx);
|
||||
})
|
||||
.log_err();
|
||||
|
||||
context_store
|
||||
.update(cx, |context_store, cx| {
|
||||
let excerpt_ids = context_store
|
||||
.context()
|
||||
.iter()
|
||||
.filter(|ctx| {
|
||||
matches!(
|
||||
ctx,
|
||||
AssistantContext::Selection(_) | AssistantContext::Image(_)
|
||||
)
|
||||
})
|
||||
.map(|ctx| ctx.id())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for id in excerpt_ids {
|
||||
context_store.remove_context(id, cx);
|
||||
}
|
||||
})
|
||||
.log_err();
|
||||
|
||||
if let Some(wait_for_summaries) = context_store
|
||||
.update(cx, |context_store, cx| context_store.wait_for_summaries(cx))
|
||||
.ok()
|
||||
.log_err()
|
||||
{
|
||||
this.update(cx, |this, cx| {
|
||||
this.waiting_for_summaries_to_send = true;
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
.log_err();
|
||||
|
||||
wait_for_summaries.await;
|
||||
|
||||
@@ -328,15 +326,24 @@ impl MessageEditor {
|
||||
this.waiting_for_summaries_to_send = false;
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
.log_err();
|
||||
}
|
||||
|
||||
// Send to model after summaries are done
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.advance_prompt_id();
|
||||
thread.send_to_model(model, cx);
|
||||
})
|
||||
.log_err();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn stop_current_and_send_new_message(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let cancelled = self.thread.update(cx, |thread, cx| {
|
||||
thread.cancel_last_completion(Some(window.window_handle()), cx)
|
||||
});
|
||||
let cancelled = self
|
||||
.thread
|
||||
.update(cx, |thread, cx| thread.cancel_last_completion(cx));
|
||||
|
||||
if cancelled {
|
||||
self.set_editor_is_expanded(false, cx);
|
||||
@@ -398,10 +405,8 @@ impl MessageEditor {
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_review_click(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.edits_expanded = true;
|
||||
fn handle_review_click(&self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
AgentDiff::deploy(self.thread.clone(), self.workspace.clone(), window, cx).log_err();
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn handle_file_click(
|
||||
@@ -1006,48 +1011,6 @@ impl MessageEditor {
|
||||
self.update_token_count_task.is_some()
|
||||
}
|
||||
|
||||
fn handle_message_changed(&mut self, cx: &mut Context<Self>) {
|
||||
self.message_or_context_changed(true, cx);
|
||||
}
|
||||
|
||||
fn start_context_load(&mut self, cx: &mut Context<Self>) -> Shared<Task<()>> {
|
||||
let summaries_task = self.wait_for_summaries(cx);
|
||||
let load_task = cx.spawn(async move |this, cx| {
|
||||
// Waits for detailed summaries before `load_context`, as it directly reads these from
|
||||
// the thread. TODO: Would be cleaner to have context loading await on summarization.
|
||||
summaries_task.await;
|
||||
let Ok(load_task) = this.update(cx, |this, cx| {
|
||||
let new_context = this.context_store.read_with(cx, |context_store, cx| {
|
||||
context_store.new_context_for_thread(this.thread.read(cx))
|
||||
});
|
||||
load_context(new_context, &this.project, &this.prompt_store, cx)
|
||||
}) else {
|
||||
return;
|
||||
};
|
||||
let result = load_task.await;
|
||||
this.update(cx, |this, cx| {
|
||||
this.last_loaded_context = Some(result);
|
||||
this.context_load_task = None;
|
||||
this.message_or_context_changed(false, cx);
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
// Replace existing load task, if any, causing it to be cancelled.
|
||||
let load_task = load_task.shared();
|
||||
self.context_load_task = Some(load_task.clone());
|
||||
load_task
|
||||
}
|
||||
|
||||
fn load_context(&mut self, cx: &mut Context<Self>) -> Task<Option<ContextLoadResult>> {
|
||||
let context_load_task = self.start_context_load(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
context_load_task.await;
|
||||
this.read_with(cx, |this, _cx| this.last_loaded_context.clone())
|
||||
.ok()
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
|
||||
fn message_or_context_changed(&mut self, debounce: bool, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Changed);
|
||||
self.update_token_count_task.take();
|
||||
@@ -1057,7 +1020,9 @@ impl MessageEditor {
|
||||
return;
|
||||
};
|
||||
|
||||
let context_store = self.context_store.clone();
|
||||
let editor = self.editor.clone();
|
||||
let thread = self.thread.clone();
|
||||
|
||||
self.update_token_count_task = Some(cx.spawn(async move |this, cx| {
|
||||
if debounce {
|
||||
@@ -1066,33 +1031,27 @@ impl MessageEditor {
|
||||
.await;
|
||||
}
|
||||
|
||||
let token_count = if let Some(task) = this.update(cx, |this, cx| {
|
||||
let loaded_context = this
|
||||
.last_loaded_context
|
||||
.as_ref()
|
||||
.map(|context_load_result| &context_load_result.loaded_context);
|
||||
let token_count = if let Some(task) = cx.update(|cx| {
|
||||
let context = context_store.read(cx).context().iter();
|
||||
let new_context = thread.read(cx).filter_new_context(context);
|
||||
let context_text =
|
||||
format_context_as_string(new_context, cx).unwrap_or(String::new());
|
||||
let message_text = editor.read(cx).text(cx);
|
||||
|
||||
if message_text.is_empty()
|
||||
&& loaded_context.map_or(true, |loaded_context| loaded_context.is_empty())
|
||||
{
|
||||
let content = context_text + &message_text;
|
||||
|
||||
if content.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: Vec::new(),
|
||||
cache: false,
|
||||
};
|
||||
|
||||
if let Some(loaded_context) = loaded_context {
|
||||
loaded_context.add_to_request_message(&mut request_message);
|
||||
}
|
||||
|
||||
let request = language_model::LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: language_model::Role::User,
|
||||
content: vec![content.into()],
|
||||
cache: false,
|
||||
}],
|
||||
tools: vec![],
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
|
||||
@@ -32,7 +32,7 @@ impl TerminalCodegen {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
|
||||
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
|
||||
let Some(ConfiguredModel { model, .. }) =
|
||||
LanguageModelRegistry::read_global(cx).inline_assistant_model()
|
||||
else {
|
||||
@@ -45,7 +45,6 @@ impl TerminalCodegen {
|
||||
self.status = CodegenStatus::Pending;
|
||||
self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
|
||||
self.generation = cx.spawn(async move |this, cx| {
|
||||
let prompt = prompt_task.await;
|
||||
let model_telemetry_id = model.telemetry_id();
|
||||
let model_provider_id = model.provider_id();
|
||||
let response = model.stream_completion_text(prompt, &cx).await;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::context::load_context;
|
||||
use crate::context::attach_context_to_message;
|
||||
use crate::context_store::ContextStore;
|
||||
use crate::inline_prompt_editor::{
|
||||
CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
|
||||
@@ -10,14 +10,14 @@ use client::telemetry::Telemetry;
|
||||
use collections::{HashMap, VecDeque};
|
||||
use editor::{MultiBuffer, actions::SelectAll};
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
|
||||
use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
|
||||
use language::Buffer;
|
||||
use language_model::{
|
||||
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
Role, report_assistant_event,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::{PromptBuilder, PromptStore};
|
||||
use prompt_store::PromptBuilder;
|
||||
use std::sync::Arc;
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use terminal_view::TerminalView;
|
||||
@@ -69,7 +69,6 @@ impl TerminalInlineAssistant {
|
||||
terminal_view: &Entity<TerminalView>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: WeakEntity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
thread_store: Option<WeakEntity<ThreadStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
@@ -110,7 +109,6 @@ impl TerminalInlineAssistant {
|
||||
prompt_editor,
|
||||
workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -198,11 +196,11 @@ impl TerminalInlineAssistant {
|
||||
.log_err();
|
||||
|
||||
let codegen = assist.codegen.clone();
|
||||
let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
|
||||
return;
|
||||
};
|
||||
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
|
||||
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
|
||||
}
|
||||
|
||||
fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
|
||||
@@ -219,7 +217,7 @@ impl TerminalInlineAssistant {
|
||||
&self,
|
||||
assist_id: TerminalInlineAssistId,
|
||||
cx: &mut App,
|
||||
) -> Result<Task<LanguageModelRequest>> {
|
||||
) -> Result<LanguageModelRequest> {
|
||||
let assist = self.assists.get(&assist_id).context("invalid assist")?;
|
||||
|
||||
let shell = std::env::var("SHELL").ok();
|
||||
@@ -248,40 +246,28 @@ impl TerminalInlineAssistant {
|
||||
&latest_output,
|
||||
)?;
|
||||
|
||||
let contexts = assist
|
||||
.context_store
|
||||
.read(cx)
|
||||
.context()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let context_load_task = assist.workspace.update(cx, |workspace, cx| {
|
||||
let project = workspace.project();
|
||||
load_context(contexts, project, &assist.prompt_store, cx)
|
||||
})?;
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
Ok(cx.background_spawn(async move {
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
attach_context_to_message(
|
||||
&mut request_message,
|
||||
assist.context_store.read(cx).context().iter(),
|
||||
cx,
|
||||
);
|
||||
|
||||
context_load_task
|
||||
.await
|
||||
.loaded_context
|
||||
.add_to_request_message(&mut request_message);
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
request_message.content.push(prompt.into());
|
||||
|
||||
LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
}
|
||||
}))
|
||||
Ok(LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
messages: vec![request_message],
|
||||
tools: Vec::new(),
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn finish_assist(
|
||||
@@ -394,7 +380,6 @@ struct TerminalInlineAssist {
|
||||
codegen: Entity<TerminalCodegen>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -405,7 +390,6 @@ impl TerminalInlineAssist {
|
||||
prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
context_store: Entity<ContextStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
window: &mut Window,
|
||||
cx: &mut App,
|
||||
) -> Self {
|
||||
@@ -416,7 +400,6 @@ impl TerminalInlineAssist {
|
||||
codegen: codegen.clone(),
|
||||
workspace: workspace.clone(),
|
||||
context_store,
|
||||
prompt_store,
|
||||
_subscriptions: vec![
|
||||
window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
|
||||
TerminalInlineAssistant::update_global(cx, |this, cx| {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,8 +24,8 @@ use heed::types::SerdeBincode;
|
||||
use language_model::{LanguageModelToolUseId, Role, TokenUsage};
|
||||
use project::{Project, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext,
|
||||
UserRulesContext, WorktreeContext,
|
||||
ProjectContext, PromptBuilder, PromptId, PromptMetadata, PromptStore, PromptsUpdatedEvent,
|
||||
RulesFileContext, UserPromptId, UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
@@ -82,11 +82,12 @@ impl ThreadStore {
|
||||
pub fn load(
|
||||
project: Entity<Project>,
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
prompt_builder: Arc<PromptBuilder>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Entity<Self>>> {
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
let prompt_store = prompt_store.await.ok();
|
||||
let (thread_store, ready_rx) = cx.update(|cx| {
|
||||
let mut option_ready_rx = None;
|
||||
let thread_store = cx.new(|cx| {
|
||||
@@ -348,8 +349,25 @@ impl ThreadStore {
|
||||
self.context_server_manager.clone()
|
||||
}
|
||||
|
||||
pub fn prompt_store(&self) -> &Option<Entity<PromptStore>> {
|
||||
&self.prompt_store
|
||||
pub fn prompt_store(&self) -> Option<Entity<PromptStore>> {
|
||||
self.prompt_store.clone()
|
||||
}
|
||||
|
||||
pub fn load_rules(
|
||||
&self,
|
||||
prompt_id: UserPromptId,
|
||||
cx: &App,
|
||||
) -> Task<Result<(PromptMetadata, String)>> {
|
||||
let prompt_id = PromptId::User { uuid: prompt_id };
|
||||
let Some(prompt_store) = self.prompt_store.as_ref() else {
|
||||
return Task::ready(Err(anyhow!("Prompt store unexpectedly missing.")));
|
||||
};
|
||||
let prompt_store = prompt_store.read(cx);
|
||||
let Some(metadata) = prompt_store.metadata(prompt_id) else {
|
||||
return Task::ready(Err(anyhow!("User rules not found in library.")));
|
||||
};
|
||||
let text_task = prompt_store.load(prompt_id, cx);
|
||||
cx.background_spawn(async move { Ok((metadata, text_task.await?)) })
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> Entity<ToolWorkingSet> {
|
||||
@@ -361,12 +379,16 @@ impl ThreadStore {
|
||||
self.threads.len()
|
||||
}
|
||||
|
||||
pub fn reverse_chronological_threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
|
||||
let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
|
||||
threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
|
||||
threads
|
||||
}
|
||||
|
||||
pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
|
||||
self.threads().into_iter().take(limit).collect()
|
||||
}
|
||||
|
||||
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
|
||||
cx.new(|cx| {
|
||||
Thread::new(
|
||||
@@ -617,17 +639,12 @@ pub struct SerializedThread {
|
||||
}
|
||||
|
||||
impl SerializedThread {
|
||||
pub const VERSION: &'static str = "0.2.0";
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
|
||||
pub fn from_json(json: &[u8]) -> Result<Self> {
|
||||
let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
|
||||
match saved_thread_json.get("version") {
|
||||
Some(serde_json::Value::String(version)) => match version.as_str() {
|
||||
SerializedThreadV0_1_0::VERSION => {
|
||||
let saved_thread =
|
||||
serde_json::from_value::<SerializedThreadV0_1_0>(saved_thread_json)?;
|
||||
Ok(saved_thread.upgrade())
|
||||
}
|
||||
SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
|
||||
saved_thread_json,
|
||||
)?),
|
||||
@@ -649,38 +666,6 @@ impl SerializedThread {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SerializedThreadV0_1_0(
|
||||
// The structure did not change, so we are reusing the latest SerializedThread.
|
||||
// When making the next version, make sure this points to SerializedThreadV0_2_0
|
||||
SerializedThread,
|
||||
);
|
||||
|
||||
impl SerializedThreadV0_1_0 {
|
||||
pub const VERSION: &'static str = "0.1.0";
|
||||
|
||||
pub fn upgrade(self) -> SerializedThread {
|
||||
debug_assert_eq!(SerializedThread::VERSION, "0.2.0");
|
||||
|
||||
let mut messages: Vec<SerializedMessage> = Vec::with_capacity(self.0.messages.len());
|
||||
|
||||
for message in self.0.messages {
|
||||
if message.role == Role::User && !message.tool_results.is_empty() {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
debug_assert!(last_message.role == Role::Assistant);
|
||||
|
||||
last_message.tool_results = message.tool_results;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(message);
|
||||
}
|
||||
|
||||
SerializedThread { messages, ..self.0 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SerializedMessage {
|
||||
pub id: MessageId,
|
||||
|
||||
@@ -30,6 +30,7 @@ pub struct ToolUse {
|
||||
pub struct ToolUseState {
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
|
||||
tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
|
||||
tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
|
||||
pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
|
||||
tool_result_cards: HashMap<LanguageModelToolUseId, AnyToolCard>,
|
||||
@@ -41,6 +42,7 @@ impl ToolUseState {
|
||||
Self {
|
||||
tools,
|
||||
tool_uses_by_assistant_message: HashMap::default(),
|
||||
tool_uses_by_user_message: HashMap::default(),
|
||||
tool_results: HashMap::default(),
|
||||
pending_tool_uses_by_id: HashMap::default(),
|
||||
tool_result_cards: HashMap::default(),
|
||||
@@ -54,6 +56,7 @@ impl ToolUseState {
|
||||
pub fn from_serialized_messages(
|
||||
tools: Entity<ToolWorkingSet>,
|
||||
messages: &[SerializedMessage],
|
||||
mut filter_by_tool_name: impl FnMut(&str) -> bool,
|
||||
) -> Self {
|
||||
let mut this = Self::new(tools);
|
||||
let mut tool_names_by_id = HashMap::default();
|
||||
@@ -65,10 +68,10 @@ impl ToolUseState {
|
||||
let tool_uses = message
|
||||
.tool_uses
|
||||
.iter()
|
||||
.filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
|
||||
.map(|tool_use| LanguageModelToolUse {
|
||||
id: tool_use.id.clone(),
|
||||
name: tool_use.name.clone().into(),
|
||||
raw_input: tool_use.input.to_string(),
|
||||
input: tool_use.input.clone(),
|
||||
is_input_complete: true,
|
||||
})
|
||||
@@ -82,6 +85,14 @@ impl ToolUseState {
|
||||
|
||||
this.tool_uses_by_assistant_message
|
||||
.insert(message.id, tool_uses);
|
||||
}
|
||||
}
|
||||
Role::User => {
|
||||
if !message.tool_results.is_empty() {
|
||||
let tool_uses_by_user_message = this
|
||||
.tool_uses_by_user_message
|
||||
.entry(message.id)
|
||||
.or_default();
|
||||
|
||||
for tool_result in &message.tool_results {
|
||||
let tool_use_id = tool_result.tool_use_id.clone();
|
||||
@@ -90,6 +101,11 @@ impl ToolUseState {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !(filter_by_tool_name)(tool_use.as_ref()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tool_uses_by_user_message.push(tool_use_id.clone());
|
||||
this.tool_results.insert(
|
||||
tool_use_id.clone(),
|
||||
LanguageModelToolResult {
|
||||
@@ -102,7 +118,7 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::System | Role::User => {}
|
||||
Role::System => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,26 +228,20 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_results_for_message(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Vec<&LanguageModelToolResult> {
|
||||
let Some(tool_uses) = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
else {
|
||||
return Vec::new();
|
||||
};
|
||||
pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
|
||||
let empty = Vec::new();
|
||||
|
||||
tool_uses
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
.unwrap_or(&empty)
|
||||
.iter()
|
||||
.filter_map(|tool_use| self.tool_results.get(&tool_use.id))
|
||||
.filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)
|
||||
pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_user_message
|
||||
.get(&message_id)
|
||||
.map_or(false, |results| !results.is_empty())
|
||||
}
|
||||
|
||||
@@ -283,6 +293,14 @@ impl ToolUseState {
|
||||
self.tool_use_metadata_by_id
|
||||
.insert(tool_use.id.clone(), metadata);
|
||||
|
||||
// The tool use is being requested by the Assistant, so we want to
|
||||
// attach the tool results to the next user message.
|
||||
let next_user_message_id = MessageId(assistant_message_id.0 + 1);
|
||||
self.tool_uses_by_user_message
|
||||
.entry(next_user_message_id)
|
||||
.or_default()
|
||||
.push(tool_use.id.clone());
|
||||
|
||||
PendingToolUseStatus::Idle
|
||||
} else {
|
||||
PendingToolUseStatus::InputStillStreaming
|
||||
@@ -448,49 +466,31 @@ impl ToolUseState {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool {
|
||||
self.tool_uses_by_assistant_message
|
||||
.contains_key(&assistant_message_id)
|
||||
}
|
||||
|
||||
pub fn tool_results_message(
|
||||
pub fn attach_tool_results(
|
||||
&self,
|
||||
assistant_message_id: MessageId,
|
||||
) -> Option<LanguageModelRequestMessage> {
|
||||
let tool_uses = self
|
||||
.tool_uses_by_assistant_message
|
||||
.get(&assistant_message_id)?;
|
||||
|
||||
if tool_uses.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut request_message = LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![],
|
||||
cache: false,
|
||||
};
|
||||
|
||||
for tool_use in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(&tool_use.id) {
|
||||
request_message
|
||||
.content
|
||||
.push(MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_use.id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
message_id: MessageId,
|
||||
request_message: &mut LanguageModelRequestMessage,
|
||||
) {
|
||||
if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
|
||||
for tool_use_id in tool_uses {
|
||||
if let Some(tool_result) = self.tool_results.get(tool_use_id) {
|
||||
request_message.content.push(MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
tool_name: tool_result.tool_name.clone(),
|
||||
is_error: tool_result.is_error,
|
||||
content: if tool_result.content.is_empty() {
|
||||
// Surprisingly, the API fails if we return an empty string here.
|
||||
// It thinks we are sending a tool use without a tool result.
|
||||
"<Tool returned an empty string>".into()
|
||||
} else {
|
||||
tool_result.content.clone()
|
||||
},
|
||||
},
|
||||
}));
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(request_message)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
mod agent_notification;
|
||||
mod animated_label;
|
||||
mod context_pill;
|
||||
mod usage_banner;
|
||||
|
||||
pub use agent_notification::*;
|
||||
pub use animated_label::*;
|
||||
pub use context_pill::*;
|
||||
pub use usage_banner::*;
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
use gpui::{Animation, AnimationExt, FontWeight, pulsating_between};
|
||||
use std::time::Duration;
|
||||
use ui::prelude::*;
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct AnimatedLabel {
|
||||
base: Label,
|
||||
text: SharedString,
|
||||
}
|
||||
|
||||
impl AnimatedLabel {
|
||||
pub fn new(text: impl Into<SharedString>) -> Self {
|
||||
let text = text.into();
|
||||
AnimatedLabel {
|
||||
base: Label::new(text.clone()),
|
||||
text,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelCommon for AnimatedLabel {
|
||||
fn size(mut self, size: LabelSize) -> Self {
|
||||
self.base = self.base.size(size);
|
||||
self
|
||||
}
|
||||
|
||||
fn weight(mut self, weight: FontWeight) -> Self {
|
||||
self.base = self.base.weight(weight);
|
||||
self
|
||||
}
|
||||
|
||||
fn line_height_style(mut self, line_height_style: LineHeightStyle) -> Self {
|
||||
self.base = self.base.line_height_style(line_height_style);
|
||||
self
|
||||
}
|
||||
|
||||
fn color(mut self, color: Color) -> Self {
|
||||
self.base = self.base.color(color);
|
||||
self
|
||||
}
|
||||
|
||||
fn strikethrough(mut self) -> Self {
|
||||
self.base = self.base.strikethrough();
|
||||
self
|
||||
}
|
||||
|
||||
fn italic(mut self) -> Self {
|
||||
self.base = self.base.italic();
|
||||
self
|
||||
}
|
||||
|
||||
fn alpha(mut self, alpha: f32) -> Self {
|
||||
self.base = self.base.alpha(alpha);
|
||||
self
|
||||
}
|
||||
|
||||
fn underline(mut self) -> Self {
|
||||
self.base = self.base.underline();
|
||||
self
|
||||
}
|
||||
|
||||
fn truncate(mut self) -> Self {
|
||||
self.base = self.base.truncate();
|
||||
self
|
||||
}
|
||||
|
||||
fn single_line(mut self) -> Self {
|
||||
self.base = self.base.single_line();
|
||||
self
|
||||
}
|
||||
|
||||
fn buffer_font(mut self, cx: &App) -> Self {
|
||||
self.base = self.base.buffer_font(cx);
|
||||
self
|
||||
}
|
||||
|
||||
fn inline_code(mut self, cx: &App) -> Self {
|
||||
self.base = self.base.inline_code(cx);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for AnimatedLabel {
|
||||
fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
|
||||
let text = self.text.clone();
|
||||
|
||||
self.base
|
||||
.color(Color::Muted)
|
||||
.with_animations(
|
||||
"animated-label",
|
||||
vec![
|
||||
Animation::new(Duration::from_secs(1)),
|
||||
Animation::new(Duration::from_secs(1)).repeat(),
|
||||
],
|
||||
move |mut label, animation_ix, delta| {
|
||||
match animation_ix {
|
||||
0 => {
|
||||
let chars_to_show = (delta * text.len() as f32).ceil() as usize;
|
||||
let text = SharedString::from(text[0..chars_to_show].to_string());
|
||||
label.set_text(text);
|
||||
}
|
||||
1 => match delta {
|
||||
d if d < 0.25 => label.set_text(text.clone()),
|
||||
d if d < 0.5 => label.set_text(format!("{}.", text)),
|
||||
d if d < 0.75 => label.set_text(format!("{}..", text)),
|
||||
_ => label.set_text(format!("{}...", text)),
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
label
|
||||
},
|
||||
)
|
||||
.with_animation(
|
||||
"pulsating-label",
|
||||
Animation::new(Duration::from_secs(2))
|
||||
.repeat()
|
||||
.with_easing(pulsating_between(0.6, 1.)),
|
||||
|label, delta| label.map_element(|label| label.alpha(delta)),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,14 @@
|
||||
use std::sync::Arc;
|
||||
use std::{rc::Rc, time::Duration};
|
||||
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{Animation, AnimationExt as _, ClickEvent, Entity, MouseButton, pulsating_between};
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
use text::OffsetRangeExt;
|
||||
use futures::FutureExt;
|
||||
use gpui::{Animation, AnimationExt as _, Image, MouseButton, pulsating_between};
|
||||
use gpui::{ClickEvent, Task};
|
||||
use language_model::LanguageModelImage;
|
||||
use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container};
|
||||
|
||||
use crate::context::{AgentContext, ContextKind, ImageStatus};
|
||||
use crate::context::{AssistantContext, ContextId, ContextKind, ImageContext};
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub enum ContextPill {
|
||||
@@ -72,7 +73,9 @@ impl ContextPill {
|
||||
|
||||
pub fn id(&self) -> ElementId {
|
||||
match self {
|
||||
Self::Added { context, .. } => context.context.element_id("context-pill".into()),
|
||||
Self::Added { context, .. } => {
|
||||
ElementId::NamedInteger("context-pill".into(), context.id.0)
|
||||
}
|
||||
Self::Suggested { .. } => "suggested-context-pill".into(),
|
||||
}
|
||||
}
|
||||
@@ -196,17 +199,14 @@ impl RenderOnce for ContextPill {
|
||||
)
|
||||
.when_some(on_remove.as_ref(), |element, on_remove| {
|
||||
element.child(
|
||||
IconButton::new(
|
||||
context.context.element_id("remove".into()),
|
||||
IconName::Close,
|
||||
)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.tooltip(Tooltip::text("Remove Context"))
|
||||
.on_click({
|
||||
let on_remove = on_remove.clone();
|
||||
move |event, window, cx| on_remove(event, window, cx)
|
||||
}),
|
||||
IconButton::new(("remove", context.id.0), IconName::Close)
|
||||
.shape(IconButtonShape::Square)
|
||||
.icon_size(IconSize::XSmall)
|
||||
.tooltip(Tooltip::text("Remove Context"))
|
||||
.on_click({
|
||||
let on_remove = on_remove.clone();
|
||||
move |event, window, cx| on_remove(event, window, cx)
|
||||
}),
|
||||
)
|
||||
})
|
||||
.when_some(on_click.as_ref(), |element, on_click| {
|
||||
@@ -262,11 +262,9 @@ pub enum ContextStatus {
|
||||
Error { message: SharedString },
|
||||
}
|
||||
|
||||
// TODO: Component commented out due to new dependency on `Project`.
|
||||
//
|
||||
// #[derive(RegisterComponent)]
|
||||
#[derive(RegisterComponent)]
|
||||
pub struct AddedContext {
|
||||
pub context: AgentContext,
|
||||
pub id: ContextId,
|
||||
pub kind: ContextKind,
|
||||
pub name: SharedString,
|
||||
pub parent: Option<SharedString>,
|
||||
@@ -277,19 +275,10 @@ pub struct AddedContext {
|
||||
}
|
||||
|
||||
impl AddedContext {
|
||||
/// Creates an `AddedContext` by retrieving relevant details of `AgentContext`. This returns a
|
||||
/// `None` if `DirectoryContext` or `RulesContext` no longer exist.
|
||||
///
|
||||
/// TODO: `None` cases are unremovable from `ContextStore` and so are a very minor memory leak.
|
||||
pub fn new(
|
||||
context: AgentContext,
|
||||
prompt_store: Option<&Entity<PromptStore>>,
|
||||
project: &Project,
|
||||
cx: &App,
|
||||
) -> Option<AddedContext> {
|
||||
pub fn new(context: &AssistantContext, cx: &App) -> AddedContext {
|
||||
match context {
|
||||
AgentContext::File(ref file_context) => {
|
||||
let full_path = file_context.buffer.read(cx).file()?.full_path(cx);
|
||||
AssistantContext::File(file_context) => {
|
||||
let full_path = file_context.context_buffer.full_path(cx);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
@@ -300,7 +289,8 @@ impl AddedContext {
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
Some(AddedContext {
|
||||
AddedContext {
|
||||
id: file_context.id,
|
||||
kind: ContextKind::File,
|
||||
name,
|
||||
parent,
|
||||
@@ -308,16 +298,18 @@ impl AddedContext {
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
AgentContext::Directory(ref directory_context) => {
|
||||
let worktree = project
|
||||
.worktree_for_entry(directory_context.entry_id, cx)?
|
||||
.read(cx);
|
||||
let entry = worktree.entry_for_id(directory_context.entry_id)?;
|
||||
let full_path = worktree.full_path(&entry.path);
|
||||
AssistantContext::Directory(directory_context) => {
|
||||
let worktree = directory_context.worktree.read(cx);
|
||||
// If the directory no longer exists, use its last known path.
|
||||
let full_path = worktree
|
||||
.entry_for_id(directory_context.entry_id)
|
||||
.map_or_else(
|
||||
|| directory_context.last_path.clone(),
|
||||
|entry| worktree.full_path(&entry.path).into(),
|
||||
);
|
||||
let full_path_string: SharedString =
|
||||
full_path.to_string_lossy().into_owned().into();
|
||||
let name = full_path
|
||||
@@ -328,7 +320,8 @@ impl AddedContext {
|
||||
.parent()
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
Some(AddedContext {
|
||||
AddedContext {
|
||||
id: directory_context.id,
|
||||
kind: ContextKind::Directory,
|
||||
name,
|
||||
parent,
|
||||
@@ -336,34 +329,33 @@ impl AddedContext {
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
AgentContext::Symbol(ref symbol_context) => Some(AddedContext {
|
||||
AssistantContext::Symbol(symbol_context) => AddedContext {
|
||||
id: symbol_context.id,
|
||||
kind: ContextKind::Symbol,
|
||||
name: symbol_context.symbol.clone(),
|
||||
name: symbol_context.context_symbol.id.name.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
}),
|
||||
},
|
||||
|
||||
AgentContext::Selection(ref selection_context) => {
|
||||
let buffer = selection_context.buffer.read(cx);
|
||||
let full_path = buffer.file()?.full_path(cx);
|
||||
AssistantContext::Selection(selection_context) => {
|
||||
let full_path = selection_context.context_buffer.full_path(cx);
|
||||
let mut full_path_string = full_path.to_string_lossy().into_owned();
|
||||
let mut name = full_path
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| full_path_string.clone());
|
||||
|
||||
let line_range = selection_context.range.to_point(&buffer.snapshot());
|
||||
|
||||
let line_range_text =
|
||||
format!(" ({}-{})", line_range.start.row + 1, line_range.end.row + 1);
|
||||
let line_range_text = format!(
|
||||
" ({}-{})",
|
||||
selection_context.line_range.start.row + 1,
|
||||
selection_context.line_range.end.row + 1
|
||||
);
|
||||
|
||||
full_path_string.push_str(&line_range_text);
|
||||
name.push_str(&line_range_text);
|
||||
@@ -373,17 +365,16 @@ impl AddedContext {
|
||||
.and_then(|p| p.file_name())
|
||||
.map(|n| n.to_string_lossy().into_owned().into());
|
||||
|
||||
Some(AddedContext {
|
||||
AddedContext {
|
||||
id: selection_context.id,
|
||||
kind: ContextKind::Selection,
|
||||
name: name.into(),
|
||||
parent,
|
||||
tooltip: None,
|
||||
icon_path: FileIcons::get_icon(&full_path, cx),
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
/*
|
||||
render_preview: Some(Rc::new({
|
||||
let content = selection_context.text.clone();
|
||||
let content = selection_context.context_buffer.text.clone();
|
||||
move |_, cx| {
|
||||
div()
|
||||
.id("context-pill-selection-preview")
|
||||
@@ -394,12 +385,11 @@ impl AddedContext {
|
||||
.into_any_element()
|
||||
}
|
||||
})),
|
||||
*/
|
||||
context,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
AgentContext::FetchedUrl(ref fetched_url_context) => Some(AddedContext {
|
||||
AssistantContext::FetchedUrl(fetched_url_context) => AddedContext {
|
||||
id: fetched_url_context.id,
|
||||
kind: ContextKind::FetchedUrl,
|
||||
name: fetched_url_context.url.clone(),
|
||||
parent: None,
|
||||
@@ -407,12 +397,12 @@ impl AddedContext {
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
}),
|
||||
},
|
||||
|
||||
AgentContext::Thread(ref thread_context) => Some(AddedContext {
|
||||
AssistantContext::Thread(thread_context) => AddedContext {
|
||||
id: thread_context.id,
|
||||
kind: ContextKind::Thread,
|
||||
name: thread_context.name(cx),
|
||||
name: thread_context.summary(cx),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
@@ -428,41 +418,36 @@ impl AddedContext {
|
||||
ContextStatus::Ready
|
||||
},
|
||||
render_preview: None,
|
||||
context,
|
||||
}),
|
||||
},
|
||||
|
||||
AgentContext::Rules(ref user_rules_context) => {
|
||||
let name = prompt_store
|
||||
.as_ref()?
|
||||
.read(cx)
|
||||
.metadata(user_rules_context.prompt_id.into())?
|
||||
.title?;
|
||||
Some(AddedContext {
|
||||
kind: ContextKind::Rules,
|
||||
name: name.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
context,
|
||||
})
|
||||
}
|
||||
AssistantContext::Rules(user_rules_context) => AddedContext {
|
||||
id: user_rules_context.id,
|
||||
kind: ContextKind::Rules,
|
||||
name: user_rules_context.title.clone(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: ContextStatus::Ready,
|
||||
render_preview: None,
|
||||
},
|
||||
|
||||
AgentContext::Image(ref image_context) => Some(AddedContext {
|
||||
AssistantContext::Image(image_context) => AddedContext {
|
||||
id: image_context.id,
|
||||
kind: ContextKind::Image,
|
||||
name: "Image".into(),
|
||||
parent: None,
|
||||
tooltip: None,
|
||||
icon_path: None,
|
||||
status: match image_context.status() {
|
||||
ImageStatus::Loading => ContextStatus::Loading {
|
||||
status: if image_context.is_loading() {
|
||||
ContextStatus::Loading {
|
||||
message: "Loading…".into(),
|
||||
},
|
||||
ImageStatus::Error => ContextStatus::Error {
|
||||
}
|
||||
} else if image_context.is_error() {
|
||||
ContextStatus::Error {
|
||||
message: "Failed to load image".into(),
|
||||
},
|
||||
ImageStatus::Ready => ContextStatus::Ready,
|
||||
}
|
||||
} else {
|
||||
ContextStatus::Ready
|
||||
},
|
||||
render_preview: Some(Rc::new({
|
||||
let image = image_context.original_image.clone();
|
||||
@@ -473,8 +458,7 @@ impl AddedContext {
|
||||
.into_any_element()
|
||||
}
|
||||
})),
|
||||
context,
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -494,8 +478,6 @@ impl Render for ContextPillPreview {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Component commented out due to new dependency on `Project`.
|
||||
/*
|
||||
impl Component for AddedContext {
|
||||
fn scope() -> ComponentScope {
|
||||
ComponentScope::Agent
|
||||
@@ -505,13 +487,12 @@ impl Component for AddedContext {
|
||||
"AddedContext"
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
let next_context_id = ContextId::zero();
|
||||
fn preview(_window: &mut Window, cx: &mut App) -> Option<AnyElement> {
|
||||
let image_ready = (
|
||||
"Ready",
|
||||
AddedContext::new(
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(0),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(Some(LanguageModelImage::empty())).shared(),
|
||||
}),
|
||||
@@ -522,8 +503,8 @@ impl Component for AddedContext {
|
||||
let image_loading = (
|
||||
"Loading",
|
||||
AddedContext::new(
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(1),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: cx
|
||||
.background_spawn(async move {
|
||||
@@ -539,8 +520,8 @@ impl Component for AddedContext {
|
||||
let image_error = (
|
||||
"Error",
|
||||
AddedContext::new(
|
||||
AgentContext::Image(ImageContext {
|
||||
context_id: next_context_id.post_inc(),
|
||||
&AssistantContext::Image(ImageContext {
|
||||
id: ContextId(2),
|
||||
original_image: Arc::new(Image::empty()),
|
||||
image_task: Task::ready(None).shared(),
|
||||
}),
|
||||
@@ -563,8 +544,5 @@ impl Component for AddedContext {
|
||||
)
|
||||
.into_any(),
|
||||
)
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
@@ -1,30 +1,31 @@
|
||||
use client::zed_urls;
|
||||
use language_model::RequestUsage;
|
||||
use ui::{Banner, ProgressBar, Severity, prelude::*};
|
||||
use zed_llm_client::{Plan, UsageLimit};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct UsageBanner {
|
||||
plan: Plan,
|
||||
usage: RequestUsage,
|
||||
requests: i32,
|
||||
}
|
||||
|
||||
impl UsageBanner {
|
||||
pub fn new(plan: Plan, usage: RequestUsage) -> Self {
|
||||
Self { plan, usage }
|
||||
pub fn new(plan: Plan, requests: i32) -> Self {
|
||||
Self { plan, requests }
|
||||
}
|
||||
}
|
||||
|
||||
impl RenderOnce for UsageBanner {
|
||||
fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
|
||||
let used_percentage = match self.usage.limit {
|
||||
UsageLimit::Limited(limit) => Some((self.usage.amount as f32 / limit as f32) * 100.),
|
||||
let request_limit = self.plan.model_requests_limit();
|
||||
|
||||
let used_percentage = match request_limit {
|
||||
UsageLimit::Limited(limit) => Some((self.requests as f32 / limit as f32) * 100.),
|
||||
UsageLimit::Unlimited => None,
|
||||
};
|
||||
|
||||
let (severity, message) = match self.usage.limit {
|
||||
let (severity, message) = match request_limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
if self.usage.amount >= limit {
|
||||
if self.requests >= limit {
|
||||
let message = match self.plan {
|
||||
Plan::ZedPro => "Monthly request limit reached",
|
||||
Plan::ZedProTrial => "Trial request limit reached",
|
||||
@@ -32,7 +33,7 @@ impl RenderOnce for UsageBanner {
|
||||
};
|
||||
|
||||
(Severity::Error, message)
|
||||
} else if (self.usage.amount as f32 / limit as f32) >= 0.9 {
|
||||
} else if (self.requests as f32 / limit as f32) >= 0.9 {
|
||||
(Severity::Warning, "Approaching request limit")
|
||||
} else {
|
||||
let message = match self.plan {
|
||||
@@ -80,11 +81,11 @@ impl RenderOnce for UsageBanner {
|
||||
.child(ProgressBar::new("usage", percent, 100., cx))
|
||||
}))
|
||||
.child(
|
||||
Label::new(match self.usage.limit {
|
||||
Label::new(match request_limit {
|
||||
UsageLimit::Limited(limit) => {
|
||||
format!("{} / {limit}", self.usage.amount)
|
||||
format!("{} / {limit}", self.requests)
|
||||
}
|
||||
UsageLimit::Unlimited => format!("{} / ∞", self.usage.amount),
|
||||
UsageLimit::Unlimited => format!("{} / ∞", self.requests),
|
||||
})
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
@@ -103,131 +104,74 @@ impl Component for UsageBanner {
|
||||
}
|
||||
|
||||
fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
|
||||
let trial_limit = Plan::ZedProTrial.model_requests_limit();
|
||||
let trial_examples = vec![
|
||||
single_example(
|
||||
"Zed Pro Trial - New User",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedProTrial,
|
||||
RequestUsage {
|
||||
limit: trial_limit,
|
||||
amount: 10,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::ZedProTrial, 10))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro Trial - Approaching Limit",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedProTrial,
|
||||
RequestUsage {
|
||||
limit: trial_limit,
|
||||
amount: 135,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::ZedProTrial, 135))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro Trial - Request Limit Reached",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedProTrial,
|
||||
RequestUsage {
|
||||
limit: trial_limit,
|
||||
amount: 150,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::ZedProTrial, 150))
|
||||
.into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
let free_limit = Plan::Free.model_requests_limit();
|
||||
let free_examples = vec![
|
||||
single_example(
|
||||
"Free - Normal Usage",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::Free,
|
||||
RequestUsage {
|
||||
limit: free_limit,
|
||||
amount: 25,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::Free, 25))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Free - Approaching Limit",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::Free,
|
||||
RequestUsage {
|
||||
limit: free_limit,
|
||||
amount: 45,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::Free, 45))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Free - Request Limit Reached",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::Free,
|
||||
RequestUsage {
|
||||
limit: free_limit,
|
||||
amount: 50,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::Free, 50))
|
||||
.into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
let zed_pro_limit = Plan::ZedPro.model_requests_limit();
|
||||
let zed_pro_examples = vec![
|
||||
single_example(
|
||||
"Zed Pro - Normal Usage",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedPro,
|
||||
RequestUsage {
|
||||
limit: zed_pro_limit,
|
||||
amount: 250,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::ZedPro, 250))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro - Approaching Limit",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedPro,
|
||||
RequestUsage {
|
||||
limit: zed_pro_limit,
|
||||
amount: 450,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::ZedPro, 450))
|
||||
.into_any_element(),
|
||||
),
|
||||
single_example(
|
||||
"Zed Pro - Request Limit Reached",
|
||||
div()
|
||||
.size_full()
|
||||
.child(UsageBanner::new(
|
||||
Plan::ZedPro,
|
||||
RequestUsage {
|
||||
limit: zed_pro_limit,
|
||||
amount: 500,
|
||||
},
|
||||
))
|
||||
.child(UsageBanner::new(Plan::ZedPro, 500))
|
||||
.into_any_element(),
|
||||
),
|
||||
];
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
mod supported_countries;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -9,6 +11,8 @@ use serde::{Deserialize, Serialize};
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
pub use supported_countries::*;
|
||||
|
||||
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
|
||||
225
crates/anthropic/src/supported_countries.rs
Normal file
225
crates/anthropic/src/supported_countries.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Returns whether the given country code is supported by Anthropic.
|
||||
///
|
||||
/// <https://www.anthropic.com/supported-countries>
|
||||
pub fn is_supported_country(country_code: &str) -> bool {
|
||||
SUPPORTED_COUNTRIES.contains(&country_code)
|
||||
}
|
||||
|
||||
/// The list of country codes supported by Anthropic.
|
||||
///
|
||||
/// https://www.anthropic.com/supported-countries
|
||||
static SUPPORTED_COUNTRIES: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
|
||||
vec![
|
||||
"AL", // Albania
|
||||
"DZ", // Algeria
|
||||
"AS", // American Samoa (US)
|
||||
"AD", // Andorra
|
||||
"AO", // Angola
|
||||
"AI", // Anguilla (UK)
|
||||
"AG", // Antigua and Barbuda
|
||||
"AR", // Argentina
|
||||
"AM", // Armenia
|
||||
"AU", // Australia
|
||||
"AT", // Austria
|
||||
"AZ", // Azerbaijan
|
||||
"BS", // Bahamas
|
||||
"BH", // Bahrain
|
||||
"BD", // Bangladesh
|
||||
"BB", // Barbados
|
||||
"BE", // Belgium
|
||||
"BZ", // Belize
|
||||
"BJ", // Benin
|
||||
"BM", // Bermuda (UK)
|
||||
"BT", // Bhutan
|
||||
"BO", // Bolivia
|
||||
"BA", // Bosnia and Herzegovina
|
||||
"BW", // Botswana
|
||||
"BR", // Brazil
|
||||
"IO", // British Indian Ocean Territory (UK)
|
||||
"BN", // Brunei
|
||||
"BG", // Bulgaria
|
||||
"BF", // Burkina Faso
|
||||
"BI", // Burundi
|
||||
"CV", // Cabo Verde
|
||||
"KH", // Cambodia
|
||||
"CM", // Cameroon
|
||||
"CA", // Canada
|
||||
"KY", // Cayman Islands (UK)
|
||||
"TD", // Chad
|
||||
"CL", // Chile
|
||||
"CX", // Christmas Island (AU)
|
||||
"CC", // Cocos (Keeling) Islands (AU)
|
||||
"CO", // Colombia
|
||||
"KM", // Comoros
|
||||
"CG", // Congo (Brazzaville)
|
||||
"CK", // Cook Islands (NZ)
|
||||
"CR", // Costa Rica
|
||||
"CI", // Côte d'Ivoire
|
||||
"HR", // Croatia
|
||||
"CY", // Cyprus
|
||||
"CZ", // Czechia (Czech Republic)
|
||||
"DK", // Denmark
|
||||
"DJ", // Djibouti
|
||||
"DM", // Dominica
|
||||
"DO", // Dominican Republic
|
||||
"EC", // Ecuador
|
||||
"EG", // Egypt
|
||||
"SV", // El Salvador
|
||||
"GQ", // Equatorial Guinea
|
||||
"EE", // Estonia
|
||||
"SZ", // Eswatini
|
||||
"FK", // Falkland Islands (UK)
|
||||
"FJ", // Fiji
|
||||
"FI", // Finland
|
||||
"FR", // France
|
||||
"GF", // French Guiana (FR)
|
||||
"PF", // French Polynesia (FR)
|
||||
"TF", // French Southern Territories
|
||||
"GA", // Gabon
|
||||
"GM", // Gambia
|
||||
"GE", // Georgia
|
||||
"DE", // Germany
|
||||
"GH", // Ghana
|
||||
"GI", // Gibraltar (UK)
|
||||
"GR", // Greece
|
||||
"GD", // Grenada
|
||||
"GT", // Guatemala
|
||||
"GU", // Guam (US)
|
||||
"GN", // Guinea
|
||||
"GW", // Guinea-Bissau
|
||||
"GY", // Guyana
|
||||
"HT", // Haiti
|
||||
"HM", // Heard Island and McDonald Islands (AU)
|
||||
"HN", // Honduras
|
||||
"HU", // Hungary
|
||||
"IS", // Iceland
|
||||
"IN", // India
|
||||
"ID", // Indonesia
|
||||
"IQ", // Iraq
|
||||
"IE", // Ireland
|
||||
"IL", // Israel
|
||||
"IT", // Italy
|
||||
"JM", // Jamaica
|
||||
"JP", // Japan
|
||||
"JO", // Jordan
|
||||
"KZ", // Kazakhstan
|
||||
"KE", // Kenya
|
||||
"KI", // Kiribati
|
||||
"KW", // Kuwait
|
||||
"KG", // Kyrgyzstan
|
||||
"LA", // Laos
|
||||
"LV", // Latvia
|
||||
"LB", // Lebanon
|
||||
"LS", // Lesotho
|
||||
"LR", // Liberia
|
||||
"LI", // Liechtenstein
|
||||
"LT", // Lithuania
|
||||
"LU", // Luxembourg
|
||||
"MG", // Madagascar
|
||||
"MW", // Malawi
|
||||
"MY", // Malaysia
|
||||
"MV", // Maldives
|
||||
"MT", // Malta
|
||||
"MH", // Marshall Islands
|
||||
"MR", // Mauritania
|
||||
"MU", // Mauritius
|
||||
"MX", // Mexico
|
||||
"FM", // Micronesia
|
||||
"MD", // Moldova
|
||||
"MC", // Monaco
|
||||
"MN", // Mongolia
|
||||
"MS", // Montserrat (UK)
|
||||
"ME", // Montenegro
|
||||
"MA", // Morocco
|
||||
"MZ", // Mozambique
|
||||
"NA", // Namibia
|
||||
"NR", // Nauru
|
||||
"NP", // Nepal
|
||||
"NL", // Netherlands
|
||||
"NZ", // New Zealand
|
||||
"NE", // Niger
|
||||
"NG", // Nigeria
|
||||
"NF", // Norfolk Island (AU)
|
||||
"MK", // North Macedonia
|
||||
"MI", // Northern Mariana Islands (UK)
|
||||
"NO", // Norway
|
||||
"NU", // Niue (NZ)
|
||||
"OM", // Oman
|
||||
"PK", // Pakistan
|
||||
"PW", // Palau
|
||||
"PS", // Palestine
|
||||
"PA", // Panama
|
||||
"PG", // Papua New Guinea
|
||||
"PY", // Paraguay
|
||||
"PE", // Peru
|
||||
"PH", // Philippines
|
||||
"PN", // Pitcairn (UK)
|
||||
"PL", // Poland
|
||||
"PT", // Portugal
|
||||
"PR", // Puerto Rico (US)
|
||||
"QA", // Qatar
|
||||
"RO", // Romania
|
||||
"RW", // Rwanda
|
||||
"BL", // Saint Barthélemy (FR)
|
||||
"KN", // Saint Kitts and Nevis
|
||||
"LC", // Saint Lucia
|
||||
"MF", // Saint Martin (FR)
|
||||
"PM", // Saint Pierre and Miquelon (FR)
|
||||
"VC", // Saint Vincent and the Grenadines
|
||||
"WS", // Samoa
|
||||
"SM", // San Marino
|
||||
"ST", // São Tomé and Príncipe
|
||||
"SA", // Saudi Arabia
|
||||
"SN", // Senegal
|
||||
"RS", // Serbia
|
||||
"SC", // Seychelles
|
||||
"SH", // Saint Helena, Ascension and Tristan da Cunha (UK)
|
||||
"SL", // Sierra Leone
|
||||
"SG", // Singapore
|
||||
"SK", // Slovakia
|
||||
"SI", // Slovenia
|
||||
"SB", // Solomon Islands
|
||||
"ZA", // South Africa
|
||||
"KR", // South Korea
|
||||
"ES", // Spain
|
||||
"LK", // Sri Lanka
|
||||
"SR", // Suriname
|
||||
"SE", // Sweden
|
||||
"CH", // Switzerland
|
||||
"TW", // Taiwan
|
||||
"TJ", // Tajikistan
|
||||
"TZ", // Tanzania
|
||||
"TH", // Thailand
|
||||
"TL", // Timor-Leste
|
||||
"TG", // Togo
|
||||
"TK", // Tokelau (NZ)
|
||||
"TO", // Tonga
|
||||
"TT", // Trinidad and Tobago
|
||||
"TN", // Tunisia
|
||||
"TR", // Türkiye (Turkey)
|
||||
"TM", // Turkmenistan
|
||||
"TC", // Turks and Caicos Islands (UK)
|
||||
"TV", // Tuvalu
|
||||
"UG", // Uganda
|
||||
"UA", // Ukraine (except Crimea, Donetsk, and Luhansk regions)
|
||||
"AE", // United Arab Emirates
|
||||
"GB", // United Kingdom
|
||||
"UM", // United States Minor Outlying Islands (US)
|
||||
"US", // United States of America
|
||||
"UY", // Uruguay
|
||||
"UZ", // Uzbekistan
|
||||
"VU", // Vanuatu
|
||||
"VA", // Vatican City
|
||||
"VN", // Vietnam
|
||||
"VI", // Virgin Islands (US)
|
||||
"VG", // Virgin Islands (UK)
|
||||
"WF", // Wallis and Futuna (FR)
|
||||
"ZM", // Zambia
|
||||
"ZW", // Zimbabwe
|
||||
]
|
||||
.into_iter()
|
||||
.collect()
|
||||
});
|
||||
@@ -49,7 +49,7 @@ menu.workspace = true
|
||||
multi_buffer.workspace = true
|
||||
parking_lot.workspace = true
|
||||
project.workspace = true
|
||||
rules_library.workspace = true
|
||||
prompt_library.workspace = true
|
||||
prompt_store.workspace = true
|
||||
proto.workspace = true
|
||||
rope.workspace = true
|
||||
|
||||
@@ -101,7 +101,7 @@ pub fn init(
|
||||
SlashCommandSettings::register(cx);
|
||||
|
||||
assistant_context_editor::init(client.clone(), cx);
|
||||
rules_library::init(cx);
|
||||
prompt_library::init(cx);
|
||||
init_language_model_settings(cx);
|
||||
assistant_slash_command::init(cx);
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -23,10 +23,11 @@ use gpui::{
|
||||
use language::LanguageRegistry;
|
||||
use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModelProviderId, LanguageModelRegistry,
|
||||
ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::{PromptBuilder, UserPromptId};
|
||||
use rules_library::{RulesLibrary, open_rules_library};
|
||||
use prompt_library::{PromptLibrary, open_prompt_library};
|
||||
use prompt_store::{PromptBuilder, PromptId, UserPromptId};
|
||||
|
||||
use search::{BufferSearchBar, buffer_search::DivRegistrar};
|
||||
use settings::{Settings, update_settings_file};
|
||||
@@ -43,7 +44,7 @@ use workspace::{
|
||||
dock::{DockPosition, Panel, PanelEvent},
|
||||
pane,
|
||||
};
|
||||
use zed_actions::assistant::{InlineAssist, OpenRulesLibrary, ShowConfiguration, ToggleFocus};
|
||||
use zed_actions::assistant::{InlineAssist, OpenPromptLibrary, ShowConfiguration, ToggleFocus};
|
||||
|
||||
pub fn init(cx: &mut App) {
|
||||
workspace::FollowableViewRegistry::register::<ContextEditor>(cx);
|
||||
@@ -57,11 +58,11 @@ pub fn init(cx: &mut App) {
|
||||
.register_action(AssistantPanel::show_configuration)
|
||||
.register_action(AssistantPanel::create_new_context)
|
||||
.register_action(AssistantPanel::restart_context_servers)
|
||||
.register_action(|workspace, action: &OpenRulesLibrary, window, cx| {
|
||||
.register_action(|workspace, action: &OpenPromptLibrary, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AssistantPanel>(cx) {
|
||||
workspace.focus_panel::<AssistantPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| {
|
||||
panel.deploy_rules_library(action, window, cx)
|
||||
panel.deploy_prompt_library(action, window, cx)
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -272,8 +273,8 @@ impl AssistantPanel {
|
||||
.action("New Chat", Box::new(NewChat))
|
||||
.action("History", Box::new(DeployHistory))
|
||||
.action(
|
||||
"Rules Library",
|
||||
Box::new(OpenRulesLibrary::default()),
|
||||
"Prompt Library",
|
||||
Box::new(OpenPromptLibrary::default()),
|
||||
)
|
||||
.action("Configure", Box::new(ShowConfiguration))
|
||||
.action(zoom_label, Box::new(ToggleZoom))
|
||||
@@ -488,8 +489,8 @@ impl AssistantPanel {
|
||||
|
||||
// If we're signed out and don't have a provider configured, or we're signed-out AND Zed.dev is
|
||||
// the provider, we want to show a nudge to sign in.
|
||||
let show_zed_ai_notice =
|
||||
client_status.is_signed_out() && model.map_or(true, |model| model.is_provided_by_zed());
|
||||
let show_zed_ai_notice = client_status.is_signed_out()
|
||||
&& model.map_or(true, |model| model.provider.id().0 == ZED_CLOUD_PROVIDER_ID);
|
||||
|
||||
self.show_zed_ai_notice = show_zed_ai_notice;
|
||||
cx.notify();
|
||||
@@ -1043,13 +1044,13 @@ impl AssistantPanel {
|
||||
}
|
||||
}
|
||||
|
||||
fn deploy_rules_library(
|
||||
fn deploy_prompt_library(
|
||||
&mut self,
|
||||
action: &OpenRulesLibrary,
|
||||
action: &OpenPromptLibrary,
|
||||
_window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
open_rules_library(
|
||||
open_prompt_library(
|
||||
self.languages.clone(),
|
||||
Box::new(PromptLibraryInlineAssist),
|
||||
Arc::new(|| {
|
||||
@@ -1059,9 +1060,9 @@ impl AssistantPanel {
|
||||
None,
|
||||
))
|
||||
}),
|
||||
action
|
||||
.prompt_to_select
|
||||
.map(|uuid| UserPromptId(uuid).into()),
|
||||
action.prompt_to_select.map(|uuid| PromptId::User {
|
||||
uuid: UserPromptId(uuid),
|
||||
}),
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
@@ -1235,7 +1236,7 @@ impl Render for AssistantPanel {
|
||||
this.show_configuration_tab(window, cx)
|
||||
}))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_history))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_rules_library))
|
||||
.on_action(cx.listener(AssistantPanel::deploy_prompt_library))
|
||||
.child(registrar.size_full().child(self.pane.clone()))
|
||||
.into_any_element()
|
||||
}
|
||||
@@ -1350,13 +1351,13 @@ impl Focusable for AssistantPanel {
|
||||
|
||||
struct PromptLibraryInlineAssist;
|
||||
|
||||
impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
impl prompt_library::InlineAssistDelegate for PromptLibraryInlineAssist {
|
||||
fn assist(
|
||||
&self,
|
||||
prompt_editor: &Entity<Editor>,
|
||||
initial_prompt: Option<String>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<RulesLibrary>,
|
||||
cx: &mut Context<PromptLibrary>,
|
||||
) {
|
||||
InlineAssistant::update_global(cx, |assistant, cx| {
|
||||
assistant.assist(&prompt_editor, None, None, initial_prompt, window, cx)
|
||||
|
||||
@@ -18,11 +18,11 @@ use editor::{
|
||||
},
|
||||
};
|
||||
use feature_flags::{
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedProFeatureFlag,
|
||||
Assistant2FeatureFlag, FeatureFlagAppExt as _, FeatureFlagViewExt as _, ZedPro,
|
||||
};
|
||||
use fs::Fs;
|
||||
use futures::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt as _,
|
||||
SinkExt, Stream, StreamExt,
|
||||
channel::mpsc,
|
||||
future::{BoxFuture, LocalBoxFuture},
|
||||
join,
|
||||
@@ -1226,7 +1226,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<InlineAssist>(
|
||||
row_range,
|
||||
cx.theme().status().info_background,
|
||||
Default::default(),
|
||||
false,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
@@ -1291,7 +1291,7 @@ impl InlineAssistant {
|
||||
editor.highlight_rows::<DeletedLines>(
|
||||
Anchor::min()..Anchor::max(),
|
||||
cx.theme().status().deleted_background,
|
||||
Default::default(),
|
||||
false,
|
||||
cx,
|
||||
);
|
||||
editor
|
||||
@@ -1652,7 +1652,7 @@ impl Render for PromptEditor {
|
||||
|
||||
let error_message = SharedString::from(error.to_string());
|
||||
if error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& cx.has_flag::<ZedProFeatureFlag>()
|
||||
&& cx.has_flag::<ZedPro>()
|
||||
{
|
||||
el.child(
|
||||
v_flex()
|
||||
@@ -1966,7 +1966,7 @@ impl PromptEditor {
|
||||
.update(cx, |editor, _| editor.set_read_only(false));
|
||||
}
|
||||
CodegenStatus::Error(error) => {
|
||||
if cx.has_flag::<ZedProFeatureFlag>()
|
||||
if cx.has_flag::<ZedPro>()
|
||||
&& error.error_code() == proto::ErrorCode::RateLimitExceeded
|
||||
&& !dismissed_rate_limit_notice()
|
||||
{
|
||||
@@ -3056,8 +3056,7 @@ impl CodegenAlternative {
|
||||
let mut response_latency = None;
|
||||
let request_start = Instant::now();
|
||||
let diff = async {
|
||||
let chunks =
|
||||
StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
|
||||
let chunks = StripInvalidSpans::new(stream?.stream);
|
||||
futures::pin_mut!(chunks);
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -44,6 +44,4 @@ impl Settings for SlashCommandSettings {
|
||||
.chain(sources.server),
|
||||
)
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
|
||||
@@ -112,27 +112,13 @@ impl AssistantSettings {
|
||||
}
|
||||
|
||||
/// Assistant panel settings
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
|
||||
pub struct AssistantSettingsContent {
|
||||
#[serde(flatten)]
|
||||
pub inner: Option<AssistantSettingsContentInner>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum AssistantSettingsContentInner {
|
||||
pub enum AssistantSettingsContent {
|
||||
Versioned(Box<VersionedAssistantSettingsContent>),
|
||||
Legacy(LegacyAssistantSettingsContent),
|
||||
}
|
||||
|
||||
impl AssistantSettingsContentInner {
|
||||
fn for_v2(content: AssistantSettingsContentV2) -> Self {
|
||||
AssistantSettingsContentInner::Versioned(Box::new(VersionedAssistantSettingsContent::V2(
|
||||
content,
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchema for AssistantSettingsContent {
|
||||
fn schema_name() -> String {
|
||||
VersionedAssistantSettingsContent::schema_name()
|
||||
@@ -147,21 +133,26 @@ impl JsonSchema for AssistantSettingsContent {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AssistantSettingsContent {
|
||||
fn default() -> Self {
|
||||
Self::Versioned(Box::new(VersionedAssistantSettingsContent::default()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AssistantSettingsContent {
|
||||
pub fn is_version_outdated(&self) -> bool {
|
||||
match &self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(_) => true,
|
||||
VersionedAssistantSettingsContent::V2(_) => false,
|
||||
},
|
||||
Some(AssistantSettingsContentInner::Legacy(_)) => true,
|
||||
None => false,
|
||||
AssistantSettingsContent::Legacy(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn upgrade(&self) -> AssistantSettingsContentV2 {
|
||||
match &self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref settings) => AssistantSettingsContentV2 {
|
||||
enabled: settings.enabled,
|
||||
button: settings.button,
|
||||
@@ -221,7 +212,7 @@ impl AssistantSettingsContent {
|
||||
},
|
||||
VersionedAssistantSettingsContent::V2(ref settings) => settings.clone(),
|
||||
},
|
||||
Some(AssistantSettingsContentInner::Legacy(settings)) => AssistantSettingsContentV2 {
|
||||
AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV2 {
|
||||
enabled: None,
|
||||
button: settings.button,
|
||||
dock: settings.dock,
|
||||
@@ -246,13 +237,12 @@ impl AssistantSettingsContent {
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
},
|
||||
None => AssistantSettingsContentV2::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_dock(&mut self, dock: AssistantDockPosition) {
|
||||
match &mut self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref mut settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
@@ -260,17 +250,9 @@ impl AssistantSettingsContent {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
},
|
||||
Some(AssistantSettingsContentInner::Legacy(settings)) => {
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
settings.dock = Some(dock);
|
||||
}
|
||||
None => {
|
||||
self.inner = Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
dock: Some(dock),
|
||||
..Default::default()
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -278,8 +260,8 @@ impl AssistantSettingsContent {
|
||||
let model = language_model.id().0.to_string();
|
||||
let provider = language_model.provider_id().0.to_string();
|
||||
|
||||
match &mut self.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(settings)) => match **settings {
|
||||
match self {
|
||||
AssistantSettingsContent::Versioned(settings) => match **settings {
|
||||
VersionedAssistantSettingsContent::V1(ref mut settings) => {
|
||||
match provider.as_ref() {
|
||||
"zed.dev" => {
|
||||
@@ -355,80 +337,56 @@ impl AssistantSettingsContent {
|
||||
settings.default_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
},
|
||||
Some(AssistantSettingsContentInner::Legacy(settings)) => {
|
||||
AssistantSettingsContent::Legacy(settings) => {
|
||||
if let Ok(model) = OpenAiModel::from_id(&language_model.id().0) {
|
||||
settings.default_open_ai_model = Some(model);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.inner = Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
default_model: Some(LanguageModelSelection { provider, model }),
|
||||
..Default::default()
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_inline_assistant_model(&mut self, provider: String, model: String) {
|
||||
self.v2_setting(|setting| {
|
||||
setting.inline_assistant_model = Some(LanguageModelSelection { provider, model });
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.inline_assistant_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_commit_message_model(&mut self, provider: String, model: String) {
|
||||
self.v2_setting(|setting| {
|
||||
setting.commit_message_model = Some(LanguageModelSelection { provider, model });
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
pub fn v2_setting(
|
||||
&mut self,
|
||||
f: impl FnOnce(&mut AssistantSettingsContentV2) -> anyhow::Result<()>,
|
||||
) -> anyhow::Result<()> {
|
||||
match self.inner.get_or_insert_with(|| {
|
||||
AssistantSettingsContentInner::for_v2(AssistantSettingsContentV2 {
|
||||
..Default::default()
|
||||
})
|
||||
}) {
|
||||
AssistantSettingsContentInner::Versioned(boxed) => {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
f(settings)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.commit_message_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_thread_summary_model(&mut self, provider: String, model: String) {
|
||||
self.v2_setting(|setting| {
|
||||
setting.thread_summary_model = Some(LanguageModelSelection { provider, model });
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
if let AssistantSettingsContent::Versioned(boxed) = self {
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.thread_summary_model = Some(LanguageModelSelection { provider, model });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_always_allow_tool_actions(&mut self, allow: bool) {
|
||||
self.v2_setting(|setting| {
|
||||
setting.always_allow_tool_actions = Some(allow);
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.always_allow_tool_actions = Some(allow);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_profile(&mut self, profile_id: AgentProfileId) {
|
||||
self.v2_setting(|setting| {
|
||||
setting.default_profile = Some(profile_id);
|
||||
Ok(())
|
||||
})
|
||||
.ok();
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
settings.default_profile = Some(profile_id);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_profile(
|
||||
@@ -436,7 +394,11 @@ impl AssistantSettingsContent {
|
||||
profile_id: AgentProfileId,
|
||||
profile: AgentProfile,
|
||||
) -> Result<()> {
|
||||
self.v2_setting(|settings| {
|
||||
let AssistantSettingsContent::Versioned(boxed) = self else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let VersionedAssistantSettingsContent::V2(ref mut settings) = **boxed {
|
||||
let profiles = settings.profiles.get_or_insert_default();
|
||||
if profiles.contains_key(&profile_id) {
|
||||
bail!("profile with ID '{profile_id}' already exists");
|
||||
@@ -462,9 +424,9 @@ impl AssistantSettingsContent {
|
||||
.collect(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -499,7 +461,7 @@ impl Default for VersionedAssistantSettingsContent {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug, Default)]
|
||||
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
|
||||
pub struct AssistantSettingsContentV2 {
|
||||
/// Whether the Assistant is enabled.
|
||||
///
|
||||
@@ -746,39 +708,6 @@ impl Settings for AssistantSettings {
|
||||
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
if let Some(b) = vscode
|
||||
.read_value("chat.agent.enabled")
|
||||
.and_then(|b| b.as_bool())
|
||||
{
|
||||
match &mut current.inner {
|
||||
Some(AssistantSettingsContentInner::Versioned(versioned)) => {
|
||||
match versioned.as_mut() {
|
||||
VersionedAssistantSettingsContent::V1(setting) => {
|
||||
setting.enabled = Some(b);
|
||||
setting.button = Some(b);
|
||||
}
|
||||
|
||||
VersionedAssistantSettingsContent::V2(setting) => {
|
||||
setting.enabled = Some(b);
|
||||
setting.button = Some(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(AssistantSettingsContentInner::Legacy(setting)) => setting.button = Some(b),
|
||||
None => {
|
||||
current.inner = Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
enabled: Some(b),
|
||||
button: Some(b),
|
||||
..Default::default()
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn merge<T>(target: &mut T, value: Option<T>) {
|
||||
@@ -822,30 +751,28 @@ mod tests {
|
||||
settings::SettingsStore::global(cx).update_settings_file::<AssistantSettings>(
|
||||
fs.clone(),
|
||||
|settings, _| {
|
||||
*settings = AssistantSettingsContent {
|
||||
inner: Some(AssistantSettingsContentInner::for_v2(
|
||||
AssistantSettingsContentV2 {
|
||||
default_model: Some(LanguageModelSelection {
|
||||
provider: "test-provider".into(),
|
||||
model: "gpt-99".into(),
|
||||
}),
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enabled: None,
|
||||
button: None,
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
default_profile: None,
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
},
|
||||
)),
|
||||
}
|
||||
*settings = AssistantSettingsContent::Versioned(Box::new(
|
||||
VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
|
||||
default_model: Some(LanguageModelSelection {
|
||||
provider: "test-provider".into(),
|
||||
model: "gpt-99".into(),
|
||||
}),
|
||||
inline_assistant_model: None,
|
||||
commit_message_model: None,
|
||||
thread_summary_model: None,
|
||||
inline_alternatives: None,
|
||||
enabled: None,
|
||||
button: None,
|
||||
dock: None,
|
||||
default_width: None,
|
||||
default_height: None,
|
||||
enable_experimental_live_diffs: None,
|
||||
default_profile: None,
|
||||
profiles: None,
|
||||
always_allow_tool_actions: None,
|
||||
notify_when_agent_waiting: None,
|
||||
}),
|
||||
))
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -28,7 +28,6 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
text.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -39,9 +39,10 @@ impl ActionLog {
|
||||
self.edited_since_project_diagnostics_check
|
||||
}
|
||||
|
||||
fn track_buffer_internal(
|
||||
fn track_buffer(
|
||||
&mut self,
|
||||
buffer: Entity<Buffer>,
|
||||
created: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut TrackedBuffer {
|
||||
let tracked_buffer = self
|
||||
@@ -58,11 +59,7 @@ impl ActionLog {
|
||||
let base_text;
|
||||
let status;
|
||||
let unreviewed_changes;
|
||||
if buffer
|
||||
.read(cx)
|
||||
.file()
|
||||
.map_or(true, |file| !file.disk_state().exists())
|
||||
{
|
||||
if created {
|
||||
base_text = Rope::default();
|
||||
status = TrackedBufferStatus::Created;
|
||||
unreviewed_changes = Patch::new(vec![Edit {
|
||||
@@ -149,7 +146,7 @@ impl ActionLog {
|
||||
// resurrected externally, we want to clear the changes we
|
||||
// were tracking and reset the buffer's state.
|
||||
self.tracked_buffers.remove(&buffer);
|
||||
self.track_buffer_internal(buffer, cx);
|
||||
self.track_buffer(buffer, false, cx);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
@@ -263,15 +260,26 @@ impl ActionLog {
|
||||
}
|
||||
|
||||
/// Track a buffer as read, so we can notify the model about user edits.
|
||||
pub fn track_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer_internal(buffer, cx);
|
||||
pub fn buffer_read(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer, false, cx);
|
||||
}
|
||||
|
||||
/// Track a buffer that was added as context, so we can notify the model about user edits.
|
||||
pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer, false, cx);
|
||||
}
|
||||
|
||||
/// Track a buffer as read, so we can notify the model about user edits.
|
||||
pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.track_buffer(buffer.clone(), true, cx);
|
||||
self.buffer_edited(buffer, cx)
|
||||
}
|
||||
|
||||
/// Mark a buffer as edited, so we can refresh it in the context
|
||||
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
|
||||
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
|
||||
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
|
||||
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
|
||||
tracked_buffer.status = TrackedBufferStatus::Modified;
|
||||
}
|
||||
@@ -279,7 +287,7 @@ impl ActionLog {
|
||||
}
|
||||
|
||||
pub fn will_delete_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
let tracked_buffer = self.track_buffer_internal(buffer.clone(), cx);
|
||||
let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
|
||||
match tracked_buffer.status {
|
||||
TrackedBufferStatus::Created => {
|
||||
self.tracked_buffers.remove(&buffer);
|
||||
@@ -389,7 +397,7 @@ impl ActionLog {
|
||||
|
||||
// Clear all tracked changes for this buffer and start over as if we just read it.
|
||||
self.tracked_buffers.remove(&buffer);
|
||||
self.track_buffer_internal(buffer.clone(), cx);
|
||||
self.track_buffer(buffer.clone(), false, cx);
|
||||
cx.notify();
|
||||
save
|
||||
}
|
||||
@@ -687,20 +695,12 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
|
||||
.unwrap();
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 1)..Point::new(1, 2), "E")], None, cx)
|
||||
@@ -765,23 +765,12 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
path!("/dir"),
|
||||
json!({"file": "abc\ndef\nghi\njkl\nmno\npqr"}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
|
||||
.unwrap();
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno\npqr", cx));
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 0)..Point::new(2, 0), "")], None, cx)
|
||||
@@ -850,20 +839,12 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/dir"), json!({"file": "abc\ndef\nghi\njkl\nmno"}))
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
|
||||
.unwrap();
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl\nmno", cx));
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 2)..Point::new(2, 3), "F\nGHI")], None, cx)
|
||||
@@ -947,21 +928,25 @@ mod tests {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/dir"), json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/dir"), json!({})).await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
|
||||
let file_path = project
|
||||
.read_with(cx, |project, cx| project.find_project_path("dir/file1", cx))
|
||||
.unwrap();
|
||||
|
||||
// Simulate file2 being recreated by a tool.
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(file_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text("lorem", cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
|
||||
});
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
@@ -1082,9 +1067,8 @@ mod tests {
|
||||
.update(cx, |project, cx| project.open_buffer(file2_path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer2.clone(), cx));
|
||||
buffer2.update(cx, |buffer, cx| buffer.set_text("IPSUM", cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer2.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.will_create_buffer(buffer2.clone(), cx));
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer2.clone(), cx))
|
||||
.await
|
||||
@@ -1129,7 +1113,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
|
||||
@@ -1264,7 +1248,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer
|
||||
.edit([(Point::new(1, 1)..Point::new(1, 2), "E\nXYZ")], None, cx)
|
||||
@@ -1397,9 +1381,8 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text("content", cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.will_create_buffer(buffer.clone(), cx));
|
||||
});
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
@@ -1455,7 +1438,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
|
||||
action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
|
||||
|
||||
for _ in 0..operations {
|
||||
match rng.gen_range(0..100) {
|
||||
@@ -1507,7 +1490,7 @@ mod tests {
|
||||
log::info!("quiescing...");
|
||||
cx.run_until_parked();
|
||||
action_log.update(cx, |log, cx| {
|
||||
let tracked_buffer = log.track_buffer_internal(buffer.clone(), cx);
|
||||
let tracked_buffer = log.track_buffer(buffer.clone(), false, cx);
|
||||
let mut old_text = tracked_buffer.base_text.clone();
|
||||
let new_text = buffer.read(cx).as_rope();
|
||||
for edit in tracked_buffer.unreviewed_changes.edits() {
|
||||
|
||||
@@ -10,16 +10,14 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::AnyElement;
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::Context;
|
||||
use gpui::IntoElement;
|
||||
use gpui::Window;
|
||||
use gpui::{App, Entity, SharedString, Task, WeakEntity};
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use icons::IconName;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
use project::Project;
|
||||
use workspace::Workspace;
|
||||
|
||||
pub use crate::action_log::*;
|
||||
pub use crate::tool_registry::*;
|
||||
@@ -67,7 +65,6 @@ pub trait ToolCard: 'static + Sized {
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement;
|
||||
}
|
||||
@@ -79,7 +76,6 @@ pub struct AnyToolCard {
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> AnyElement,
|
||||
}
|
||||
@@ -90,14 +86,11 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
|
||||
entity: gpui::AnyEntity,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
let entity = entity.downcast::<T>().unwrap();
|
||||
entity.update(cx, |entity, cx| {
|
||||
entity
|
||||
.render(status, window, workspace, cx)
|
||||
.into_any_element()
|
||||
entity.render(status, window, cx).into_any_element()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -109,14 +102,8 @@ impl<T: ToolCard> From<Entity<T>> for AnyToolCard {
|
||||
}
|
||||
|
||||
impl AnyToolCard {
|
||||
pub fn render(
|
||||
&self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut App,
|
||||
) -> AnyElement {
|
||||
(self.render)(self.entity.clone(), status, window, workspace, cx)
|
||||
pub fn render(&self, status: &ToolUseStatus, window: &mut Window, cx: &mut App) -> AnyElement {
|
||||
(self.render)(self.entity.clone(), status, window, cx)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,7 +163,6 @@ pub trait Tool: 'static + Send + Sync {
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult;
|
||||
}
|
||||
|
||||
@@ -10,11 +10,6 @@ pub fn adapt_schema_to_format(
|
||||
json: &mut Value,
|
||||
format: LanguageModelToolSchemaFormat,
|
||||
) -> Result<()> {
|
||||
if let Value::Object(obj) = json {
|
||||
obj.remove("$schema");
|
||||
obj.remove("title");
|
||||
}
|
||||
|
||||
match format {
|
||||
LanguageModelToolSchemaFormat::JsonSchema => Ok(()),
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset => adapt_to_json_schema_subset(json),
|
||||
@@ -35,7 +30,10 @@ fn adapt_to_json_schema_subset(json: &mut Value) -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
obj.remove("format");
|
||||
const KEYS_TO_REMOVE: [&str; 2] = ["format", "$schema"];
|
||||
for key in KEYS_TO_REMOVE {
|
||||
obj.remove(key);
|
||||
}
|
||||
|
||||
if let Some(default) = obj.get("default") {
|
||||
let is_null = default.is_null();
|
||||
|
||||
@@ -14,14 +14,12 @@ path = "src/assistant_tools.rs"
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
@@ -32,32 +30,23 @@ linkme.workspace = true
|
||||
open.workspace = true
|
||||
project.workspace = true
|
||||
regex.workspace = true
|
||||
rust-embed.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
smallvec.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
web_search.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client = { workspace = true, features = ["test-support"] }
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
collections = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
fs = { workspace = true, features = ["test-support"] }
|
||||
language = { workspace = true, features = ["test-support"] }
|
||||
language_models.workspace = true
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
rand.workspace = true
|
||||
pretty_assertions.workspace = true
|
||||
reqwest_client.workspace = true
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
tree-sitter-rust.workspace = true
|
||||
workspace = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -7,20 +7,19 @@ mod create_directory_tool;
|
||||
mod create_file_tool;
|
||||
mod delete_path_tool;
|
||||
mod diagnostics_tool;
|
||||
mod edit_agent;
|
||||
mod edit_file_tool;
|
||||
mod fetch_tool;
|
||||
mod find_path_tool;
|
||||
mod grep_tool;
|
||||
mod list_directory_tool;
|
||||
mod move_path_tool;
|
||||
mod now_tool;
|
||||
mod open_tool;
|
||||
mod path_search_tool;
|
||||
mod read_file_tool;
|
||||
mod rename_tool;
|
||||
mod replace;
|
||||
mod schema;
|
||||
mod symbol_info_tool;
|
||||
mod templates;
|
||||
mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod ui;
|
||||
@@ -30,14 +29,12 @@ use std::sync::Arc;
|
||||
|
||||
use assistant_tool::ToolRegistry;
|
||||
use copy_path_tool::CopyPathTool;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
use gpui::App;
|
||||
use http_client::HttpClientWithUrl;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use move_path_tool::MovePathTool;
|
||||
use web_search_tool::WebSearchTool;
|
||||
|
||||
pub(crate) use templates::*;
|
||||
|
||||
use crate::batch_tool::BatchTool;
|
||||
use crate::code_action_tool::CodeActionTool;
|
||||
use crate::code_symbols_tool::CodeSymbolsTool;
|
||||
@@ -48,22 +45,17 @@ use crate::delete_path_tool::DeletePathTool;
|
||||
use crate::diagnostics_tool::DiagnosticsTool;
|
||||
use crate::edit_file_tool::EditFileTool;
|
||||
use crate::fetch_tool::FetchTool;
|
||||
use crate::find_path_tool::FindPathTool;
|
||||
use crate::grep_tool::GrepTool;
|
||||
use crate::list_directory_tool::ListDirectoryTool;
|
||||
use crate::now_tool::NowTool;
|
||||
use crate::open_tool::OpenTool;
|
||||
use crate::path_search_tool::PathSearchTool;
|
||||
use crate::read_file_tool::ReadFileTool;
|
||||
use crate::rename_tool::RenameTool;
|
||||
use crate::symbol_info_tool::SymbolInfoTool;
|
||||
use crate::terminal_tool::TerminalTool;
|
||||
use crate::thinking_tool::ThinkingTool;
|
||||
|
||||
pub use create_file_tool::CreateFileToolInput;
|
||||
pub use edit_file_tool::EditFileToolInput;
|
||||
pub use find_path_tool::FindPathToolInput;
|
||||
pub use read_file_tool::ReadFileToolInput;
|
||||
|
||||
pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
assistant_tool::init(cx);
|
||||
|
||||
@@ -84,79 +76,41 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
registry.register_tool(OpenTool);
|
||||
registry.register_tool(CodeSymbolsTool);
|
||||
registry.register_tool(ContentsTool);
|
||||
registry.register_tool(FindPathTool);
|
||||
registry.register_tool(PathSearchTool);
|
||||
registry.register_tool(ReadFileTool);
|
||||
registry.register_tool(GrepTool);
|
||||
registry.register_tool(RenameTool);
|
||||
registry.register_tool(ThinkingTool);
|
||||
registry.register_tool(FetchTool::new(http_client));
|
||||
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
move |registry, event, cx| match event {
|
||||
language_model::Event::DefaultModelChanged => {
|
||||
let using_zed_provider = registry
|
||||
.read(cx)
|
||||
.default_model()
|
||||
.map_or(false, |default| default.is_provided_by_zed());
|
||||
if using_zed_provider {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
}
|
||||
cx.observe_flag::<feature_flags::ZedProWebSearchTool, _>({
|
||||
move |is_enabled, cx| {
|
||||
if is_enabled {
|
||||
ToolRegistry::global(cx).register_tool(WebSearchTool);
|
||||
} else {
|
||||
ToolRegistry::global(cx).unregister_tool(WebSearchTool);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use client::Client;
|
||||
use clock::FakeSystemClock;
|
||||
use http_client::FakeHttpClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Serialize;
|
||||
|
||||
#[test]
|
||||
fn test_json_schema() {
|
||||
#[derive(Serialize, JsonSchema)]
|
||||
struct GetWeatherTool {
|
||||
location: String,
|
||||
}
|
||||
|
||||
let schema = schema::json_schema_for::<GetWeatherTool>(
|
||||
language_model::LanguageModelToolSchemaFormat::JsonSchema,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
schema,
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
})
|
||||
);
|
||||
}
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_builtin_tool_schema_compatibility(cx: &mut App) {
|
||||
settings::init(cx);
|
||||
|
||||
let client = Client::new(
|
||||
Arc::new(FakeSystemClock::new()),
|
||||
FakeHttpClient::with_200_response(),
|
||||
crate::init(
|
||||
Arc::new(http_client::HttpClientWithUrl::new(
|
||||
FakeHttpClient::with_200_response(),
|
||||
"https://zed.dev",
|
||||
None,
|
||||
)),
|
||||
cx,
|
||||
);
|
||||
language_model::init(client.clone(), cx);
|
||||
crate::init(client.http_client(), cx);
|
||||
|
||||
for tool in ToolRegistry::global(cx).tools() {
|
||||
let actual_schema = tool
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult, ToolWorkingSet};
|
||||
use futures::future::join_all;
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -97,7 +97,7 @@ pub struct BatchToolInput {
|
||||
/// }
|
||||
/// },
|
||||
/// {
|
||||
/// "name": "find_path",
|
||||
/// "name": "path_search",
|
||||
/// "input": {
|
||||
/// "glob": "**/*test*.rs"
|
||||
/// }
|
||||
@@ -218,7 +218,6 @@ impl Tool for BatchTool {
|
||||
messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<BatchToolInput>(input) {
|
||||
@@ -259,9 +258,7 @@ impl Tool for BatchTool {
|
||||
let action_log = action_log.clone();
|
||||
let messages = messages.clone();
|
||||
let tool_result = cx
|
||||
.update(|cx| {
|
||||
tool.run(invocation.input, &messages, project, action_log, window, cx)
|
||||
})
|
||||
.update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
|
||||
.map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
|
||||
|
||||
tasks.push(tool_result.output);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{self, Anchor, Buffer, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{self, LspAction, Project};
|
||||
@@ -140,7 +140,6 @@ impl Tool for CodeActionTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CodeActionToolInput>(input) {
|
||||
@@ -160,7 +159,7 @@ impl Tool for CodeActionTool {
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let range = {
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use collections::IndexMap;
|
||||
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use language::{OutlineItem, ParseStatus, Point};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, Symbol};
|
||||
@@ -128,7 +128,6 @@ impl Tool for CodeSymbolsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
|
||||
@@ -175,7 +174,7 @@ pub async fn file_outline(
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
// Wait until the buffer has been fully parsed, so that we can read its outline.
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -102,7 +102,6 @@ impl Tool for ContentsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ContentsToolInput>(input) {
|
||||
@@ -210,7 +209,7 @@ impl Tool for ContentsTool {
|
||||
})?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.track_buffer(buffer, cx);
|
||||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
@@ -222,7 +221,7 @@ impl Tool for ContentsTool {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.track_buffer(buffer, cx);
|
||||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -77,7 +76,6 @@ impl Tool for CopyPathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CopyPathToolInput>(input) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -68,7 +67,6 @@ impl Tool for CreateDirectoryTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CreateDirectoryToolInput>(input) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::AnyWindowHandle;
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
use language_model::LanguageModelToolSchemaFormat;
|
||||
@@ -24,9 +23,6 @@ pub struct CreateFileToolInput {
|
||||
///
|
||||
/// You can create a new file by providing a path of "directory1/new_file.txt"
|
||||
/// </example>
|
||||
///
|
||||
/// Make sure to include this field before the `contents` field in the input object
|
||||
/// so that we can display it immediately.
|
||||
pub path: String,
|
||||
|
||||
/// The text contents of the file to create.
|
||||
@@ -93,7 +89,6 @@ impl Tool for CreateFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<CreateFileToolInput>(input) {
|
||||
@@ -117,12 +112,9 @@ impl Tool for CreateFileTool {
|
||||
.await
|
||||
.map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
|
||||
cx.update(|cx| {
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx)
|
||||
});
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.buffer_edited(buffer.clone(), cx)
|
||||
action_log.will_create_buffer(buffer.clone(), cx)
|
||||
});
|
||||
})?;
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::{SinkExt, StreamExt, channel::mpsc};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{Project, ProjectPath};
|
||||
use schemars::JsonSchema;
|
||||
@@ -62,7 +62,6 @@ impl Tool for DeletePathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let path_str = match serde_json::from_value::<DeletePathToolInput>(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -82,7 +82,6 @@ impl Tool for DiagnosticsTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
match serde_json::from_value::<DiagnosticsToolInput>(input)
|
||||
|
||||
@@ -1,526 +0,0 @@
|
||||
mod edit_parser;
|
||||
|
||||
use crate::{Template, Templates};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::ActionLog;
|
||||
use edit_parser::EditParser;
|
||||
use futures::{Stream, StreamExt, stream};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Bias, Buffer, BufferSnapshot};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use smallvec::SmallVec;
|
||||
use std::{ops::Range, path::PathBuf, sync::Arc};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct EditAgentTemplate {
|
||||
path: Option<PathBuf>,
|
||||
file_content: String,
|
||||
instructions: String,
|
||||
}
|
||||
|
||||
impl Template for EditAgentTemplate {
|
||||
const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
|
||||
}
|
||||
|
||||
pub struct EditAgent {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
}
|
||||
|
||||
impl EditAgent {
|
||||
pub fn new(
|
||||
model: Arc<dyn LanguageModel>,
|
||||
action_log: Entity<ActionLog>,
|
||||
templates: Arc<Templates>,
|
||||
) -> Self {
|
||||
EditAgent {
|
||||
model,
|
||||
action_log,
|
||||
templates,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn edit(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
instructions: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
let edits = self.stream_edits(buffer.clone(), instructions, cx).await?;
|
||||
self.apply_edits(buffer, edits, cx).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn apply_edits(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
edits: impl Stream<Item = Result<(Range<Anchor>, String)>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<()> {
|
||||
// todo!("group all edits into one transaction")
|
||||
// todo!("add tests for this")
|
||||
|
||||
// Ensure the buffer is tracked by the action log.
|
||||
self.action_log
|
||||
.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
|
||||
|
||||
futures::pin_mut!(edits);
|
||||
while let Some(edit) = edits.next().await {
|
||||
let (range, content) = edit?;
|
||||
// Edit the buffer and report the edit as part of the same effect cycle, otherwise
|
||||
// the edit will be reported as if the user made it.
|
||||
cx.update(|cx| {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(range, content)], None, cx));
|
||||
self.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stream_edits(
|
||||
&self,
|
||||
buffer: Entity<Buffer>,
|
||||
instructions: String,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<impl use<> + Stream<Item = Result<(Range<Anchor>, String)>>> {
|
||||
println!("{}\n\n", instructions);
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
|
||||
// todo!("move to background")
|
||||
let prompt = EditAgentTemplate {
|
||||
path,
|
||||
file_content: snapshot.text(),
|
||||
instructions,
|
||||
}
|
||||
.render(&self.templates)?;
|
||||
let request = LanguageModelRequest {
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text(prompt)],
|
||||
cache: false,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let mut parser = EditParser::new();
|
||||
let stream = self.model.stream_completion_text(request, cx).await?.stream;
|
||||
Ok(stream.flat_map(move |chunk| {
|
||||
let mut edits = SmallVec::new();
|
||||
let mut error = None;
|
||||
let snapshot = snapshot.clone();
|
||||
match chunk {
|
||||
Ok(chunk) => edits = parser.push(&chunk),
|
||||
Err(err) => error = Some(Err(anyhow!(err))),
|
||||
}
|
||||
stream::iter(
|
||||
edits
|
||||
.into_iter()
|
||||
.map(move |edit| {
|
||||
dbg!(&edit);
|
||||
let range = Self::resolve_location(&snapshot, &edit.old_text);
|
||||
Ok((range, edit.new_text))
|
||||
})
|
||||
.chain(error),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
|
||||
const INSERTION_COST: u32 = 3;
|
||||
const DELETION_COST: u32 = 10;
|
||||
const WHITESPACE_INSERTION_COST: u32 = 1;
|
||||
const WHITESPACE_DELETION_COST: u32 = 1;
|
||||
|
||||
let buffer_len = buffer.len();
|
||||
let query_len = search_query.len();
|
||||
let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
|
||||
let mut leading_deletion_cost = 0_u32;
|
||||
for (row, query_byte) in search_query.bytes().enumerate() {
|
||||
let deletion_cost = if query_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_DELETION_COST
|
||||
} else {
|
||||
DELETION_COST
|
||||
};
|
||||
|
||||
leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
|
||||
matrix.set(
|
||||
row + 1,
|
||||
0,
|
||||
SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
|
||||
);
|
||||
|
||||
for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
|
||||
let insertion_cost = if buffer_byte.is_ascii_whitespace() {
|
||||
WHITESPACE_INSERTION_COST
|
||||
} else {
|
||||
INSERTION_COST
|
||||
};
|
||||
|
||||
let up = SearchState::new(
|
||||
matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
|
||||
SearchDirection::Up,
|
||||
);
|
||||
let left = SearchState::new(
|
||||
matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
|
||||
SearchDirection::Left,
|
||||
);
|
||||
let diagonal = SearchState::new(
|
||||
if query_byte == *buffer_byte {
|
||||
matrix.get(row, col).cost
|
||||
} else {
|
||||
matrix
|
||||
.get(row, col)
|
||||
.cost
|
||||
.saturating_add(deletion_cost + insertion_cost)
|
||||
},
|
||||
SearchDirection::Diagonal,
|
||||
);
|
||||
matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
|
||||
}
|
||||
}
|
||||
|
||||
// Traceback to find the best match
|
||||
let mut best_buffer_end = buffer_len;
|
||||
let mut best_cost = u32::MAX;
|
||||
for col in 1..=buffer_len {
|
||||
let cost = matrix.get(query_len, col).cost;
|
||||
if cost < best_cost {
|
||||
best_cost = cost;
|
||||
best_buffer_end = col;
|
||||
}
|
||||
}
|
||||
|
||||
let mut query_ix = query_len;
|
||||
let mut buffer_ix = best_buffer_end;
|
||||
while query_ix > 0 && buffer_ix > 0 {
|
||||
let current = matrix.get(query_ix, buffer_ix);
|
||||
match current.direction {
|
||||
SearchDirection::Diagonal => {
|
||||
query_ix -= 1;
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
SearchDirection::Up => {
|
||||
query_ix -= 1;
|
||||
}
|
||||
SearchDirection::Left => {
|
||||
buffer_ix -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
|
||||
start.column = 0;
|
||||
let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
|
||||
if end.column > 0 {
|
||||
end.column = buffer.line_len(end.row);
|
||||
}
|
||||
|
||||
buffer.anchor_after(start)..buffer.anchor_before(end)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum SearchDirection {
|
||||
Up,
|
||||
Left,
|
||||
Diagonal,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct SearchState {
|
||||
cost: u32,
|
||||
direction: SearchDirection,
|
||||
}
|
||||
|
||||
impl SearchState {
|
||||
fn new(cost: u32, direction: SearchDirection) -> Self {
|
||||
Self { cost, direction }
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchMatrix {
|
||||
cols: usize,
|
||||
data: Vec<SearchState>,
|
||||
}
|
||||
|
||||
impl SearchMatrix {
|
||||
fn new(rows: usize, cols: usize) -> Self {
|
||||
SearchMatrix {
|
||||
cols,
|
||||
data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> SearchState {
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, cost: SearchState) {
|
||||
self.data[row * self.cols + col] = cost;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use client::{Client, UserStore};
|
||||
use collections::HashSet;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::LanguageModelRegistry;
|
||||
use project::Project;
|
||||
use rand::prelude::*;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde_json::json;
|
||||
use std::{fmt::Write as _, io::Write as _, path::Path, sync::mpsc};
|
||||
use util::path;
|
||||
|
||||
#[test]
|
||||
fn test_delete_run_git_blame() {
|
||||
eval(
|
||||
100,
|
||||
0.9,
|
||||
Eval {
|
||||
input_path: "root/blame.rs".into(),
|
||||
input_content: include_str!("fixtures/delete_run_git_blame/before.rs").into(),
|
||||
instructions: indoc! {r#"
|
||||
Let's delete the `run_git_blame` function while keeping all other code intact:
|
||||
|
||||
// ... existing code ...
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
// ... existing code ...
|
||||
"#}
|
||||
.into(),
|
||||
expected_output: include_str!("fixtures/delete_run_git_blame/after.rs").into(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_handle_command_output() {
|
||||
eval(
|
||||
100,
|
||||
0.9,
|
||||
Eval {
|
||||
input_path: "root/blame.rs".into(),
|
||||
input_content: include_str!("fixtures/extract_handle_command_output/before.rs").into(),
|
||||
instructions: indoc! {r#"
|
||||
Extract `handle_command_output` method from `run_git_blame`.
|
||||
|
||||
// ... existing code ...
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
handle_command_output(output)
|
||||
}
|
||||
|
||||
fn handle_command_output(output: std::process::Output) -> Result<String> {
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
// ... existing code ...
|
||||
"#}
|
||||
.into(),
|
||||
expected_output: include_str!("fixtures/extract_handle_command_output/after.rs").into()
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Eval {
|
||||
input_path: PathBuf,
|
||||
input_content: String,
|
||||
instructions: String,
|
||||
expected_output: String,
|
||||
}
|
||||
|
||||
fn eval(iterations: usize, expected_pass_ratio: f32, eval: Eval) {
|
||||
let executor = gpui::background_executor();
|
||||
let (tx, rx) = mpsc::channel();
|
||||
for _ in 0..iterations {
|
||||
let eval = eval.clone();
|
||||
let tx = tx.clone();
|
||||
executor
|
||||
.spawn(async move {
|
||||
let dispatcher = gpui::TestDispatcher::new(StdRng::from_entropy());
|
||||
let mut cx = TestAppContext::build(dispatcher, None);
|
||||
let output = cx.executor().block_test(async {
|
||||
let test = agent_test(&mut cx).await;
|
||||
apply_edits(
|
||||
eval.input_path,
|
||||
eval.input_content,
|
||||
eval.instructions,
|
||||
&test,
|
||||
&mut cx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
tx.send(output).unwrap();
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
drop(tx);
|
||||
|
||||
let mut evaluated_count = 0;
|
||||
report_progress(evaluated_count, iterations);
|
||||
|
||||
let mut failed_count = 0;
|
||||
let mut failed_message = String::new();
|
||||
let mut failed_outputs = HashSet::default();
|
||||
while let Ok(output) = rx.recv() {
|
||||
if output != eval.expected_output {
|
||||
failed_count += 1;
|
||||
if failed_outputs.insert(output.clone()) {
|
||||
writeln!(
|
||||
failed_message,
|
||||
"=======\n{}\n=======",
|
||||
pretty_assertions::StrComparison::new(&output, &eval.expected_output)
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
evaluated_count += 1;
|
||||
report_progress(evaluated_count, iterations);
|
||||
}
|
||||
|
||||
let actual_pass_ratio = (iterations - failed_count) as f32 / iterations as f32;
|
||||
println!("Actual pass ratio: {}\n", actual_pass_ratio);
|
||||
assert!(
|
||||
actual_pass_ratio >= expected_pass_ratio,
|
||||
"Expected pass ratio: {}\nActual pass ratio: {}\nFailures: {}",
|
||||
expected_pass_ratio,
|
||||
actual_pass_ratio,
|
||||
failed_message
|
||||
);
|
||||
}
|
||||
|
||||
fn report_progress(evaluated_count: usize, iterations: usize) {
|
||||
print!("\r\x1b[KEvaluated {}/{}", evaluated_count, iterations);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
|
||||
async fn apply_edits(
|
||||
path: impl AsRef<Path>,
|
||||
content: impl Into<Arc<str>>,
|
||||
instructions: impl Into<String>,
|
||||
test: &EditAgentTest,
|
||||
cx: &mut TestAppContext,
|
||||
) -> String {
|
||||
let path = test
|
||||
.project
|
||||
.read_with(cx, |project, cx| project.find_project_path(path, cx))
|
||||
.unwrap();
|
||||
let buffer = test
|
||||
.project
|
||||
.update(cx, |project, cx| project.open_buffer(path, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, cx| buffer.set_text(content, cx));
|
||||
test.agent
|
||||
.edit(buffer.clone(), instructions.into(), &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
buffer.update(cx, |buffer, _cx| buffer.text())
|
||||
}
|
||||
|
||||
struct EditAgentTest {
|
||||
agent: EditAgent,
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
async fn agent_test(cx: &mut TestAppContext) -> EditAgentTest {
|
||||
cx.executor().allow_parking();
|
||||
cx.update(settings::init);
|
||||
cx.update(Project::init_settings);
|
||||
cx.update(language::init);
|
||||
cx.update(gpui_tokio::init);
|
||||
cx.update(client::init_settings);
|
||||
|
||||
let fs = FakeFs::new(cx.executor().clone());
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
|
||||
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "gemini-2.0-flash")
|
||||
.unwrap();
|
||||
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
cx.spawn(async move |_| {
|
||||
authenticated.await.unwrap();
|
||||
model
|
||||
})
|
||||
})
|
||||
.await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
|
||||
EditAgentTest {
|
||||
agent: EditAgent::new(model, action_log, Templates::new()),
|
||||
project,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,246 +0,0 @@
|
||||
use smallvec::SmallVec;
|
||||
use std::mem;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq)]
|
||||
pub struct Edit {
|
||||
pub old_text: String,
|
||||
pub new_text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditParser {
|
||||
state: EditParserState,
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum EditParserState {
|
||||
Pending,
|
||||
WithinOldText,
|
||||
AfterOldText { old_text: String },
|
||||
WithinNewText { old_text: String },
|
||||
}
|
||||
|
||||
impl EditParser {
|
||||
pub fn new() -> Self {
|
||||
EditParser {
|
||||
state: EditParserState::Pending,
|
||||
buffer: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push(&mut self, chunk: &str) -> SmallVec<[Edit; 1]> {
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
let mut edits = SmallVec::new();
|
||||
loop {
|
||||
match &mut self.state {
|
||||
EditParserState::Pending => {
|
||||
if let Some(start) = self.buffer.find("<old_text>") {
|
||||
self.buffer.drain(..start + "<old_text>".len());
|
||||
self.state = EditParserState::WithinOldText;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::WithinOldText => {
|
||||
if let Some(end) = self.buffer.find("</old_text>") {
|
||||
let mut start = 0;
|
||||
if self.buffer.starts_with('\n') {
|
||||
start = 1;
|
||||
}
|
||||
let mut old_text = self.buffer[start..end].to_string();
|
||||
if old_text.ends_with('\n') {
|
||||
old_text.pop();
|
||||
}
|
||||
|
||||
self.buffer.drain(..end + "</old_text>".len());
|
||||
self.state = EditParserState::AfterOldText { old_text };
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::AfterOldText { old_text } => {
|
||||
if let Some(start) = self.buffer.find("<new_text>") {
|
||||
self.buffer.drain(..start + "<new_text>".len());
|
||||
self.state = EditParserState::WithinNewText {
|
||||
old_text: mem::take(old_text),
|
||||
};
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
EditParserState::WithinNewText { old_text } => {
|
||||
if let Some(end) = self.buffer.find("</new_text>") {
|
||||
let mut start = 0;
|
||||
if self.buffer.starts_with('\n') {
|
||||
start = 1;
|
||||
}
|
||||
let mut new_text = self.buffer[start..end].to_string();
|
||||
if new_text.ends_with('\n') {
|
||||
new_text.pop();
|
||||
}
|
||||
edits.push(Edit {
|
||||
old_text: mem::take(old_text),
|
||||
new_text,
|
||||
});
|
||||
|
||||
self.buffer.drain(..end + "</new_text>".len());
|
||||
self.state = EditParserState::Pending;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edits
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
use rand::prelude::*;
|
||||
use std::cmp;
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_single_edit(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>original</old_text><new_text>updated</new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "original".to_string(),
|
||||
new_text: "updated".to_string(),
|
||||
}]
|
||||
)
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_multiple_edits(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
indoc! {"
|
||||
<old_text>
|
||||
first old
|
||||
</old_text><new_text>first new</new_text>
|
||||
<old_text>second old</old_text><new_text>
|
||||
second new
|
||||
</new_text>
|
||||
"},
|
||||
&mut rng
|
||||
),
|
||||
vec![
|
||||
Edit {
|
||||
old_text: "first old".to_string(),
|
||||
new_text: "first new".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "second old".to_string(),
|
||||
new_text: "second new".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_edits_with_extra_text(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
indoc! {"
|
||||
ignore this <old_text>
|
||||
content</old_text>extra stuff<new_text>updated content</new_text>trailing data
|
||||
more text <old_text>second item
|
||||
</old_text>middle text<new_text>modified second item</new_text>end
|
||||
<old_text>third case</old_text><new_text>improved third case</new_text> with trailing text
|
||||
"},
|
||||
&mut rng
|
||||
),
|
||||
vec![
|
||||
Edit {
|
||||
old_text: "content".to_string(),
|
||||
new_text: "updated content".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "second item".to_string(),
|
||||
new_text: "modified second item".to_string(),
|
||||
},
|
||||
Edit {
|
||||
old_text: "third case".to_string(),
|
||||
new_text: "improved third case".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_nested_tags(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>code with <tag>nested</tag> elements</old_text><new_text>new <code>content</code></new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "code with <tag>nested</tag> elements".to_string(),
|
||||
new_text: "new <code>content</code>".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_empty_old_and_new_text(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse("<old_text></old_text><new_text></new_text>", &mut rng),
|
||||
vec![Edit {
|
||||
old_text: "".to_string(),
|
||||
new_text: "".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 1000)]
|
||||
fn test_with_special_characters(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>function(x) { return x * 2; }</old_text><new_text>function(x) { return x ** 2; }</new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "function(x) { return x * 2; }".to_string(),
|
||||
new_text: "function(x) { return x ** 2; }".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test(iterations = 100)]
|
||||
fn test_multiline_content(mut rng: StdRng) {
|
||||
assert_eq!(
|
||||
parse(
|
||||
"<old_text>line1\nline2\nline3</old_text><new_text>line1\nmodified line2\nline3</new_text>",
|
||||
&mut rng
|
||||
),
|
||||
vec![Edit {
|
||||
old_text: "line1\nline2\nline3".to_string(),
|
||||
new_text: "line1\nmodified line2\nline3".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
fn parse(input: &str, rng: &mut StdRng) -> Vec<Edit> {
|
||||
let mut parser = EditParser::new();
|
||||
let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
|
||||
let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
|
||||
chunk_indices.sort();
|
||||
|
||||
let mut edits = Vec::new();
|
||||
let mut last_ix = 0;
|
||||
for chunk_ix in chunk_indices {
|
||||
edits.extend(parser.push(&input[last_ix..chunk_ix]));
|
||||
last_ix = chunk_ix;
|
||||
}
|
||||
edits.extend(parser.push(&input[last_ix..]));
|
||||
edits
|
||||
}
|
||||
}
|
||||
@@ -1,39 +1,18 @@
|
||||
use crate::{Templates, edit_agent::EditAgent, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
|
||||
};
|
||||
use language::{
|
||||
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
|
||||
};
|
||||
use language_model::{
|
||||
LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolSchemaFormat,
|
||||
};
|
||||
use crate::{replace::replace_with_flexible_indent, schema::json_schema_for};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use ui::{Disclosure, Tooltip, Window, prelude::*};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
use ui::IconName;
|
||||
|
||||
use crate::replace::replace_exact;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct EditFileToolInput {
|
||||
/// A user-friendly markdown description of the edit. This will be shown in the UI.
|
||||
///
|
||||
/// <example>Fix API endpoint URLs</example>
|
||||
/// <example>Update copyright year in `page_footer`</example>
|
||||
///
|
||||
/// Make sure to include this field before all the others in the input object
|
||||
/// so that we can display it immediately.
|
||||
pub display_description: String,
|
||||
|
||||
/// The full path of the file to modify in the project.
|
||||
///
|
||||
/// WARNING: When specifying which file path need changing, you MUST
|
||||
@@ -55,35 +34,17 @@ pub struct EditFileToolInput {
|
||||
/// </example>
|
||||
pub path: PathBuf,
|
||||
|
||||
/// Edit instructions that will be interpreted by a less intelligent model,
|
||||
/// which will quickly apply the edits. You should make it clear what the
|
||||
/// edits are, while also minimizing the unchanged code you write. The model
|
||||
/// does not have access to this conversation, so you must make sure the
|
||||
/// instructions are self-contained and do not rely on external context.
|
||||
/// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
|
||||
///
|
||||
/// Insert `// ... existing code ...` comments in your output to represent
|
||||
/// unchanged code ABOVE, BELOW, and IN BETWEEN edited lines.
|
||||
///
|
||||
/// Bias towards repeating as few lines of the original file as possible to
|
||||
/// convey the change. However, each edit should contain sufficient context
|
||||
/// of unchanged lines to resolve ambiguity. When you want to delete a piece
|
||||
/// of code, indicate a few lines above and below the code you want to
|
||||
/// delete (surrounded by `// ... existing code ...`).
|
||||
///
|
||||
/// Never forget to include `// ... existing code ...` comments to represent
|
||||
/// unchanged lines, otherwise the small model may not understand the
|
||||
/// context of your edit and will delete important code!
|
||||
///
|
||||
/// <your_output>
|
||||
/// // ... existing code ...
|
||||
/// FIRST_EDIT
|
||||
/// // ... existing code ...
|
||||
/// SECOND_EDIT
|
||||
/// // ... existing code ...
|
||||
/// THIRD_EDIT
|
||||
/// // ... existing code ...
|
||||
/// </your_output>
|
||||
pub edit_instructions: String,
|
||||
/// <example>Fix API endpoint URLs</example>
|
||||
/// <example>Update copyright year in `page_footer`</example>
|
||||
pub display_description: String,
|
||||
|
||||
/// The text to replace.
|
||||
pub old_string: String,
|
||||
|
||||
/// The text to replace it with.
|
||||
pub new_string: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -92,11 +53,15 @@ struct PartialInput {
|
||||
path: String,
|
||||
#[serde(default)]
|
||||
display_description: String,
|
||||
#[serde(default)]
|
||||
old_string: String,
|
||||
#[serde(default)]
|
||||
new_string: String,
|
||||
}
|
||||
|
||||
pub struct EditFileTool;
|
||||
|
||||
const DEFAULT_UI_TEXT: &str = "Editing file";
|
||||
const DEFAULT_UI_TEXT: &str = "Edit file";
|
||||
|
||||
impl Tool for EditFileTool {
|
||||
fn name(&self) -> String {
|
||||
@@ -122,7 +87,7 @@ impl Tool for EditFileTool {
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<EditFileToolInput>(input.clone()) {
|
||||
Ok(input) => input.display_description,
|
||||
Err(_) => "Editing file".to_string(),
|
||||
Err(_) => "Edit file".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,7 +113,6 @@ impl Tool for EditFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<EditFileToolInput>(input) {
|
||||
@@ -156,454 +120,98 @@ impl Tool for EditFileTool {
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
|
||||
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
|
||||
return Task::ready(Err(anyhow!(
|
||||
"Path {} not found in project",
|
||||
input.path.display()
|
||||
)))
|
||||
.into();
|
||||
};
|
||||
let Some(worktree) = project
|
||||
.read(cx)
|
||||
.worktree_for_id(project_path.worktree_id, cx)
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Worktree not found for project path"))).into();
|
||||
};
|
||||
let exists = worktree.update(cx, |worktree, cx| {
|
||||
worktree.file_exists(&project_path.path, cx)
|
||||
});
|
||||
|
||||
let card = window.and_then(|window| {
|
||||
window
|
||||
.update(cx, |_, window, cx| {
|
||||
cx.new(|cx| {
|
||||
EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
|
||||
})
|
||||
})
|
||||
.ok()
|
||||
});
|
||||
|
||||
let card_clone = card.clone();
|
||||
// todo!("read model from settings...")
|
||||
let models = LanguageModelRegistry::read_global(cx);
|
||||
let model = models
|
||||
.available_models(cx)
|
||||
.find(|model| model.id().0 == "gemini-2.0-flash")
|
||||
.unwrap();
|
||||
let provider = models.provider(&model.provider_id()).unwrap();
|
||||
let authenticated = provider.authenticate(cx);
|
||||
|
||||
// todo!("reuse templates")
|
||||
let edit_agent = EditAgent::new(model, action_log, Templates::new());
|
||||
let task = cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
authenticated.await?;
|
||||
if !exists.await? {
|
||||
return Err(anyhow!("{} not found", input.path.display()));
|
||||
}
|
||||
cx.spawn(async move |cx: &mut AsyncApp| {
|
||||
let project_path = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.find_project_path(&input.path, cx)
|
||||
.context("Path not found in project")
|
||||
})??;
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
})?
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?;
|
||||
|
||||
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
edit_agent
|
||||
.edit(buffer.clone(), input.edit_instructions.clone(), cx)
|
||||
.await?;
|
||||
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer, cx))?
|
||||
.await?;
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
let old_text = cx.background_spawn({
|
||||
let old_snapshot = old_snapshot.clone();
|
||||
async move { old_snapshot.text() }
|
||||
});
|
||||
let new_text = cx.background_spawn({
|
||||
let new_snapshot = new_snapshot.clone();
|
||||
async move { new_snapshot.text() }
|
||||
});
|
||||
let diff = cx.background_spawn(async move {
|
||||
language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
|
||||
});
|
||||
let (old_text, new_text, diff) = futures::join!(old_text, new_text, diff);
|
||||
|
||||
if let Some(card) = card_clone {
|
||||
card.update(cx, |card, cx| {
|
||||
card.set_diff(project_path.path.clone(), old_text, new_text, cx);
|
||||
})
|
||||
.log_err();
|
||||
if input.old_string.is_empty() {
|
||||
return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file."));
|
||||
}
|
||||
|
||||
Ok(format!(
|
||||
"Edited {}:\n\n```diff\n{}\n```",
|
||||
input.path.display(),
|
||||
diff
|
||||
))
|
||||
});
|
||||
if input.old_string == input.new_string {
|
||||
return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made."));
|
||||
}
|
||||
|
||||
ToolResult {
|
||||
output: task,
|
||||
card: card.map(AnyToolCard::from),
|
||||
}
|
||||
}
|
||||
}
|
||||
let result = cx
|
||||
.background_spawn(async move {
|
||||
// Try to match exactly
|
||||
let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
|
||||
.await
|
||||
// If that fails, try being flexible about indentation
|
||||
.or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?;
|
||||
|
||||
pub struct EditFileToolCard {
|
||||
path: PathBuf,
|
||||
editor: Entity<Editor>,
|
||||
multibuffer: Entity<MultiBuffer>,
|
||||
project: Entity<Project>,
|
||||
diff_task: Option<Task<Result<()>>>,
|
||||
preview_expanded: bool,
|
||||
full_height_expanded: bool,
|
||||
editor_unique_id: EntityId,
|
||||
}
|
||||
if diff.edits.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
impl EditFileToolCard {
|
||||
fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
|
||||
let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
|
||||
let editor = cx.new(|cx| {
|
||||
let mut editor = Editor::new(
|
||||
EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: true,
|
||||
},
|
||||
multibuffer.clone(),
|
||||
Some(project.clone()),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_show_scrollbars(false, cx);
|
||||
editor.set_show_gutter(false, cx);
|
||||
editor.disable_inline_diagnostics();
|
||||
editor.disable_scrolling(cx);
|
||||
editor.disable_expand_excerpt_buttons(cx);
|
||||
editor.set_show_breakpoints(false, cx);
|
||||
editor.set_show_code_actions(false, cx);
|
||||
editor.set_show_git_diff_gutter(false, cx);
|
||||
editor.set_expand_all_diff_hunks(cx);
|
||||
editor
|
||||
});
|
||||
Self {
|
||||
editor_unique_id: editor.entity_id(),
|
||||
path,
|
||||
project,
|
||||
editor,
|
||||
multibuffer,
|
||||
diff_task: None,
|
||||
preview_expanded: true,
|
||||
full_height_expanded: false,
|
||||
}
|
||||
}
|
||||
let old_text = snapshot.text();
|
||||
|
||||
fn set_diff(
|
||||
&mut self,
|
||||
path: Arc<Path>,
|
||||
old_text: String,
|
||||
new_text: String,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let language_registry = self.project.read(cx).languages().clone();
|
||||
self.diff_task = Some(cx.spawn(async move |this, cx| {
|
||||
let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
|
||||
let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
|
||||
Some((old_text, diff))
|
||||
})
|
||||
.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.multibuffer.update(cx, |multibuffer, cx| {
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let diff = buffer_diff.read(cx);
|
||||
let diff_hunk_ranges = diff
|
||||
.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
|
||||
.map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
|
||||
.collect::<Vec<_>>();
|
||||
let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
|
||||
PathKey::for_buffer(&buffer, cx),
|
||||
buffer,
|
||||
diff_hunk_ranges,
|
||||
editor::DEFAULT_MULTIBUFFER_CONTEXT,
|
||||
cx,
|
||||
);
|
||||
debug_assert!(is_newly_added);
|
||||
multibuffer.add_diff(buffer_diff, cx);
|
||||
let Some((old_text, diff)) = result else {
|
||||
let err = buffer.read_with(cx, |buffer, _cx| {
|
||||
let file_exists = buffer
|
||||
.file()
|
||||
.map_or(false, |file| file.disk_state().exists());
|
||||
|
||||
if !file_exists {
|
||||
anyhow!("{} does not exist", input.path.display())
|
||||
} else if buffer.is_empty() {
|
||||
anyhow!(
|
||||
"{} is empty, so the provided `old_string` wasn't found.",
|
||||
input.path.display()
|
||||
)
|
||||
} else {
|
||||
anyhow!("Failed to match the provided `old_string`")
|
||||
}
|
||||
})?;
|
||||
|
||||
return Err(err)
|
||||
};
|
||||
|
||||
let snapshot = cx.update(|cx| {
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer.clone(), cx)
|
||||
});
|
||||
cx.notify();
|
||||
})
|
||||
}));
|
||||
let snapshot = buffer.update(cx, |buffer, cx| {
|
||||
buffer.finalize_last_transaction();
|
||||
buffer.apply_diff(diff, cx);
|
||||
buffer.finalize_last_transaction();
|
||||
buffer.snapshot()
|
||||
});
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_edited(buffer.clone(), cx)
|
||||
});
|
||||
snapshot
|
||||
})?;
|
||||
|
||||
project.update( cx, |project, cx| {
|
||||
project.save_buffer(buffer, cx)
|
||||
})?.await?;
|
||||
|
||||
let diff_str = cx.background_spawn(async move {
|
||||
let new_text = snapshot.text();
|
||||
language::unified_diff(&old_text, &new_text)
|
||||
}).await;
|
||||
|
||||
|
||||
Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str))
|
||||
|
||||
}).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCard for EditFileToolCard {
|
||||
fn render(
|
||||
&mut self,
|
||||
status: &ToolUseStatus,
|
||||
window: &mut Window,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let failed = matches!(status, ToolUseStatus::Error(_));
|
||||
|
||||
let path_label_button = h_flex()
|
||||
.id(("edit-tool-path-label-button", self.editor_unique_id))
|
||||
.w_full()
|
||||
.max_w_full()
|
||||
.px_1()
|
||||
.gap_0p5()
|
||||
.cursor_pointer()
|
||||
.rounded_sm()
|
||||
.opacity(0.8)
|
||||
.hover(|label| {
|
||||
label
|
||||
.opacity(1.)
|
||||
.bg(cx.theme().colors().element_hover.opacity(0.5))
|
||||
})
|
||||
.tooltip(Tooltip::text("Jump to File"))
|
||||
.child(
|
||||
h_flex()
|
||||
.child(
|
||||
Icon::new(IconName::Pencil)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
div()
|
||||
.text_size(rems(0.8125))
|
||||
.child(self.path.display().to_string())
|
||||
.ml_1p5()
|
||||
.mr_0p5(),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::ArrowUpRight)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Ignored),
|
||||
),
|
||||
)
|
||||
.on_click({
|
||||
let path = self.path.clone();
|
||||
let workspace = workspace.clone();
|
||||
move |_, window, cx| {
|
||||
workspace
|
||||
.update(cx, {
|
||||
|workspace, cx| {
|
||||
let Some(project_path) =
|
||||
workspace.project().read(cx).find_project_path(&path, cx)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let open_task =
|
||||
workspace.open_path(project_path, None, true, window, cx);
|
||||
window
|
||||
.spawn(cx, async move |cx| {
|
||||
let item = open_task.await?;
|
||||
if let Some(active_editor) = item.downcast::<Editor>() {
|
||||
active_editor
|
||||
.update_in(cx, |editor, window, cx| {
|
||||
editor.go_to_singleton_buffer_point(
|
||||
language::Point::new(0, 0),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
anyhow::Ok(())
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
})
|
||||
.into_any_element();
|
||||
|
||||
let codeblock_header_bg = cx
|
||||
.theme()
|
||||
.colors()
|
||||
.element_background
|
||||
.blend(cx.theme().colors().editor_foreground.opacity(0.025));
|
||||
|
||||
let codeblock_header = h_flex()
|
||||
.flex_none()
|
||||
.p_1()
|
||||
.gap_1()
|
||||
.justify_between()
|
||||
.rounded_t_md()
|
||||
.when(!failed, |header| header.bg(codeblock_header_bg))
|
||||
.child(path_label_button)
|
||||
.map(|container| {
|
||||
if failed {
|
||||
container.child(
|
||||
Icon::new(IconName::Close)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Error),
|
||||
)
|
||||
} else {
|
||||
container.child(
|
||||
Disclosure::new(
|
||||
("edit-file-disclosure", self.editor_unique_id),
|
||||
self.preview_expanded,
|
||||
)
|
||||
.opened_icon(IconName::ChevronUp)
|
||||
.closed_icon(IconName::ChevronDown)
|
||||
.on_click(cx.listener(
|
||||
move |this, _event, _window, _cx| {
|
||||
this.preview_expanded = !this.preview_expanded;
|
||||
},
|
||||
)),
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
let editor = self.editor.update(cx, |editor, cx| {
|
||||
editor.render(window, cx).into_any_element()
|
||||
});
|
||||
|
||||
let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
|
||||
(IconName::ChevronUp, "Collapse Code Block")
|
||||
} else {
|
||||
(IconName::ChevronDown, "Expand Code Block")
|
||||
};
|
||||
|
||||
let gradient_overlay = div()
|
||||
.absolute()
|
||||
.bottom_0()
|
||||
.left_0()
|
||||
.w_full()
|
||||
.h_2_5()
|
||||
.rounded_b_lg()
|
||||
.bg(gpui::linear_gradient(
|
||||
0.,
|
||||
gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
|
||||
gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
|
||||
));
|
||||
|
||||
let border_color = cx.theme().colors().border.opacity(0.6);
|
||||
|
||||
v_flex()
|
||||
.mb_2()
|
||||
.border_1()
|
||||
.when(failed, |card| card.border_dashed())
|
||||
.border_color(border_color)
|
||||
.rounded_lg()
|
||||
.overflow_hidden()
|
||||
.child(codeblock_header)
|
||||
.when(!failed && self.preview_expanded, |card| {
|
||||
card.child(
|
||||
v_flex()
|
||||
.relative()
|
||||
.overflow_hidden()
|
||||
.border_t_1()
|
||||
.border_color(border_color)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.map(|editor_container| {
|
||||
if self.full_height_expanded {
|
||||
editor_container.h_full()
|
||||
} else {
|
||||
editor_container.max_h_64()
|
||||
}
|
||||
})
|
||||
.child(div().pl_1().child(editor))
|
||||
.when(!self.full_height_expanded, |editor_container| {
|
||||
editor_container.child(gradient_overlay)
|
||||
}),
|
||||
)
|
||||
})
|
||||
.when(!failed && self.preview_expanded, |card| {
|
||||
card.child(
|
||||
h_flex()
|
||||
.id(("edit-tool-card-inner-hflex", self.editor_unique_id))
|
||||
.flex_none()
|
||||
.cursor_pointer()
|
||||
.h_5()
|
||||
.justify_center()
|
||||
.rounded_b_md()
|
||||
.border_t_1()
|
||||
.border_color(border_color)
|
||||
.bg(cx.theme().colors().editor_background)
|
||||
.hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
|
||||
.child(
|
||||
Icon::new(full_height_icon)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.tooltip(Tooltip::text(full_height_tooltip_label))
|
||||
.on_click(cx.listener(move |this, _event, _window, _cx| {
|
||||
this.full_height_expanded = !this.full_height_expanded;
|
||||
})),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_buffer(
|
||||
mut text: String,
|
||||
path: Arc<Path>,
|
||||
language_registry: &Arc<language::LanguageRegistry>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Buffer>> {
|
||||
let line_ending = LineEnding::detect(&text);
|
||||
LineEnding::normalize(&mut text);
|
||||
let text = Rope::from(text);
|
||||
let language = cx
|
||||
.update(|_cx| language_registry.language_for_file_path(&path))?
|
||||
.await
|
||||
.ok();
|
||||
let buffer = cx.new(|cx| {
|
||||
let buffer = TextBuffer::new_normalized(
|
||||
0,
|
||||
cx.entity_id().as_non_zero_u64().into(),
|
||||
line_ending,
|
||||
text,
|
||||
);
|
||||
let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
|
||||
buffer.set_language(language, cx);
|
||||
buffer
|
||||
})?;
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
async fn build_buffer_diff(
|
||||
mut old_text: String,
|
||||
buffer: &Entity<Buffer>,
|
||||
language_registry: &Arc<LanguageRegistry>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<BufferDiff>> {
|
||||
LineEnding::normalize(&mut old_text);
|
||||
|
||||
let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
|
||||
|
||||
let base_buffer = cx
|
||||
.update(|cx| {
|
||||
Buffer::build_snapshot(
|
||||
old_text.clone().into(),
|
||||
buffer.language().cloned(),
|
||||
Some(language_registry.clone()),
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
let diff_snapshot = cx
|
||||
.update(|cx| {
|
||||
BufferDiffSnapshot::new_with_base_buffer(
|
||||
buffer.text.clone(),
|
||||
Some(old_text.into()),
|
||||
base_buffer,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut diff = BufferDiff::new(&buffer.text, cx);
|
||||
diff.set_snapshot(diff_snapshot, &buffer.text, cx);
|
||||
diff
|
||||
})
|
||||
}
|
||||
|
||||
// todo!("add unit tests for failure modes of edit, like file not found, etc.")
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -7,4 +7,39 @@ Before using this tool:
|
||||
2. Verify the directory path is correct (only applicable when creating new files):
|
||||
- Use the `list_directory` tool to verify the parent directory exists and is the correct location
|
||||
|
||||
Group coherent edits together and include all of them in a single call to this tool. Add the full context needed for a small model to understand the edits.
|
||||
To make a file edit, provide the following:
|
||||
1. path: The full path to the file you wish to modify in the project. This path must include the root directory in the project.
|
||||
2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation)
|
||||
3. new_string: The edited text, which will replace the old_string in the file.
|
||||
|
||||
The tool will replace ONE occurrence of old_string with new_string in the specified file.
|
||||
|
||||
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
||||
|
||||
1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means:
|
||||
- Include AT LEAST 3-5 lines of context BEFORE the change point
|
||||
- Include AT LEAST 3-5 lines of context AFTER the change point
|
||||
- Include all whitespace, indentation, and surrounding code exactly as it appears in the file
|
||||
|
||||
2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances:
|
||||
- Make separate calls to this tool for each instance
|
||||
- Each call must uniquely identify its specific instance using extensive context
|
||||
|
||||
3. VERIFICATION: Before using this tool:
|
||||
- Check how many instances of the target text exist in the file
|
||||
- If multiple instances exist, gather enough context to uniquely identify each one
|
||||
- Plan separate tool calls for each instance
|
||||
|
||||
WARNING: If you do not follow these requirements:
|
||||
- The tool will fail if old_string matches multiple locations
|
||||
- The tool will fail if old_string doesn't match exactly (including whitespace)
|
||||
- You may change the wrong instance if you don't include enough context
|
||||
|
||||
When making edits:
|
||||
- Ensure the edit results in idiomatic, correct code
|
||||
- Do not leave the code in a broken state
|
||||
- Always use fully-qualified project paths (starting with the name of one of the project's root directories)
|
||||
|
||||
If you want to create a new file, use the `create_file` tool instead of this tool. Don't pass an empty `old_string`.
|
||||
|
||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task};
|
||||
use gpui::{App, AppContext as _, Entity, Task};
|
||||
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
@@ -145,7 +145,6 @@ impl Tool for FetchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<FetchToolInput>(input) {
|
||||
|
||||
@@ -1,328 +0,0 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -1,374 +0,0 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -1,378 +0,0 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
handle_command_output(output)
|
||||
}
|
||||
|
||||
fn handle_command_output(output: std::process::Output) -> Result<String> {
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -1,374 +0,0 @@
|
||||
use crate::commit::get_messages;
|
||||
use crate::{GitRemote, Oid};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::AsyncWriteExt;
|
||||
use gpui::SharedString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::process::Stdio;
|
||||
use std::{ops::Range, path::Path};
|
||||
use text::Rope;
|
||||
use time::OffsetDateTime;
|
||||
use time::UtcOffset;
|
||||
use time::macros::format_description;
|
||||
|
||||
pub use git2 as libgit;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Blame {
|
||||
pub entries: Vec<BlameEntry>,
|
||||
pub messages: HashMap<Oid, String>,
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ParsedCommitMessage {
|
||||
pub message: SharedString,
|
||||
pub permalink: Option<url::Url>,
|
||||
pub pull_request: Option<crate::hosting_provider::PullRequest>,
|
||||
pub remote: Option<GitRemote>,
|
||||
}
|
||||
|
||||
impl Blame {
|
||||
pub async fn for_path(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
content: &Rope,
|
||||
remote_url: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let output = run_git_blame(git_binary, working_directory, path, content).await?;
|
||||
let mut entries = parse_git_blame(&output)?;
|
||||
entries.sort_unstable_by(|a, b| a.range.start.cmp(&b.range.start));
|
||||
|
||||
let mut unique_shas = HashSet::default();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
unique_shas.insert(entry.sha);
|
||||
}
|
||||
|
||||
let shas = unique_shas.into_iter().collect::<Vec<_>>();
|
||||
let messages = get_messages(working_directory, &shas)
|
||||
.await
|
||||
.context("failed to get commit messages")?;
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
messages,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const GIT_BLAME_NO_COMMIT_ERROR: &str = "fatal: no such ref: HEAD";
|
||||
const GIT_BLAME_NO_PATH: &str = "fatal: no such path";
|
||||
|
||||
async fn run_git_blame(
|
||||
git_binary: &Path,
|
||||
working_directory: &Path,
|
||||
path: &Path,
|
||||
contents: &Rope,
|
||||
) -> Result<String> {
|
||||
let mut child = util::command::new_smol_command(git_binary)
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow!("Failed to start git blame process: {}", e))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.context("failed to get pipe to stdin of git blame command")?;
|
||||
|
||||
for chunk in contents.chunks() {
|
||||
stdin.write_all(chunk.as_bytes()).await?;
|
||||
}
|
||||
stdin.flush().await?;
|
||||
|
||||
let output = child
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to read git blame output: {}", e))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let trimmed = stderr.trim();
|
||||
if trimmed == GIT_BLAME_NO_COMMIT_ERROR || trimmed.contains(GIT_BLAME_NO_PATH) {
|
||||
return Ok(String::new());
|
||||
}
|
||||
return Err(anyhow!("git blame process failed: {}", stderr));
|
||||
}
|
||||
|
||||
Ok(String::from_utf8(output.stdout)?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BlameEntry {
|
||||
pub sha: Oid,
|
||||
|
||||
pub range: Range<u32>,
|
||||
|
||||
pub original_line_number: u32,
|
||||
|
||||
pub author: Option<String>,
|
||||
pub author_mail: Option<String>,
|
||||
pub author_time: Option<i64>,
|
||||
pub author_tz: Option<String>,
|
||||
|
||||
pub committer_name: Option<String>,
|
||||
pub committer_email: Option<String>,
|
||||
pub committer_time: Option<i64>,
|
||||
pub committer_tz: Option<String>,
|
||||
|
||||
pub summary: Option<String>,
|
||||
|
||||
pub previous: Option<String>,
|
||||
pub filename: String,
|
||||
}
|
||||
|
||||
impl BlameEntry {
|
||||
// Returns a BlameEntry by parsing the first line of a `git blame --incremental`
|
||||
// entry. The line MUST have this format:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
fn new_from_blame_line(line: &str) -> Result<BlameEntry> {
|
||||
let mut parts = line.split_whitespace();
|
||||
|
||||
let sha = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<Oid>().ok())
|
||||
.ok_or_else(|| anyhow!("failed to parse sha"))?;
|
||||
|
||||
let original_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse original line number"))?;
|
||||
let final_line_number = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let line_count = parts
|
||||
.next()
|
||||
.and_then(|line| line.parse::<u32>().ok())
|
||||
.ok_or_else(|| anyhow!("Failed to parse final line number"))?;
|
||||
|
||||
let start_line = final_line_number.saturating_sub(1);
|
||||
let end_line = start_line + line_count;
|
||||
let range = start_line..end_line;
|
||||
|
||||
Ok(Self {
|
||||
sha,
|
||||
range,
|
||||
original_line_number,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn author_offset_date_time(&self) -> Result<time::OffsetDateTime> {
|
||||
if let (Some(author_time), Some(author_tz)) = (self.author_time, &self.author_tz) {
|
||||
let format = format_description!("[offset_hour][offset_minute]");
|
||||
let offset = UtcOffset::parse(author_tz, &format)?;
|
||||
let date_time_utc = OffsetDateTime::from_unix_timestamp(author_time)?;
|
||||
|
||||
Ok(date_time_utc.to_offset(offset))
|
||||
} else {
|
||||
// Directly return current time in UTC if there's no committer time or timezone
|
||||
Ok(time::OffsetDateTime::now_utc())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse_git_blame parses the output of `git blame --incremental`, which returns
|
||||
// all the blame-entries for a given path incrementally, as it finds them.
|
||||
//
|
||||
// Each entry *always* starts with:
|
||||
//
|
||||
// <40-byte-hex-sha1> <sourceline> <resultline> <num-lines>
|
||||
//
|
||||
// Each entry *always* ends with:
|
||||
//
|
||||
// filename <whitespace-quoted-filename-goes-here>
|
||||
//
|
||||
// Line numbers are 1-indexed.
|
||||
//
|
||||
// A `git blame --incremental` entry looks like this:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 2 2 1
|
||||
// author Joe Schmoe
|
||||
// author-mail <joe.schmoe@example.com>
|
||||
// author-time 1709741400
|
||||
// author-tz +0100
|
||||
// committer Joe Schmoe
|
||||
// committer-mail <joe.schmoe@example.com>
|
||||
// committer-time 1709741400
|
||||
// committer-tz +0100
|
||||
// summary Joe's cool commit
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// If the entry has the same SHA as an entry that was already printed then no
|
||||
// signature information is printed:
|
||||
//
|
||||
// 6ad46b5257ba16d12c5ca9f0d4900320959df7f4 3 4 1
|
||||
// previous 486c2409237a2c627230589e567024a96751d475 index.js
|
||||
// filename index.js
|
||||
//
|
||||
// More about `--incremental` output: https://mirrors.edge.kernel.org/pub/software/scm/git/docs/git-blame.html
|
||||
fn parse_git_blame(output: &str) -> Result<Vec<BlameEntry>> {
|
||||
let mut entries: Vec<BlameEntry> = Vec::new();
|
||||
let mut index: HashMap<Oid, usize> = HashMap::default();
|
||||
|
||||
let mut current_entry: Option<BlameEntry> = None;
|
||||
|
||||
for line in output.lines() {
|
||||
let mut done = false;
|
||||
|
||||
match &mut current_entry {
|
||||
None => {
|
||||
let mut new_entry = BlameEntry::new_from_blame_line(line)?;
|
||||
|
||||
if let Some(existing_entry) = index
|
||||
.get(&new_entry.sha)
|
||||
.and_then(|slot| entries.get(*slot))
|
||||
{
|
||||
new_entry.author.clone_from(&existing_entry.author);
|
||||
new_entry
|
||||
.author_mail
|
||||
.clone_from(&existing_entry.author_mail);
|
||||
new_entry.author_time = existing_entry.author_time;
|
||||
new_entry.author_tz.clone_from(&existing_entry.author_tz);
|
||||
new_entry
|
||||
.committer_name
|
||||
.clone_from(&existing_entry.committer_name);
|
||||
new_entry
|
||||
.committer_email
|
||||
.clone_from(&existing_entry.committer_email);
|
||||
new_entry.committer_time = existing_entry.committer_time;
|
||||
new_entry
|
||||
.committer_tz
|
||||
.clone_from(&existing_entry.committer_tz);
|
||||
new_entry.summary.clone_from(&existing_entry.summary);
|
||||
}
|
||||
|
||||
current_entry.replace(new_entry);
|
||||
}
|
||||
Some(entry) => {
|
||||
let Some((key, value)) = line.split_once(' ') else {
|
||||
continue;
|
||||
};
|
||||
let is_committed = !entry.sha.is_zero();
|
||||
match key {
|
||||
"filename" => {
|
||||
entry.filename = value.into();
|
||||
done = true;
|
||||
}
|
||||
"previous" => entry.previous = Some(value.into()),
|
||||
|
||||
"summary" if is_committed => entry.summary = Some(value.into()),
|
||||
"author" if is_committed => entry.author = Some(value.into()),
|
||||
"author-mail" if is_committed => entry.author_mail = Some(value.into()),
|
||||
"author-time" if is_committed => {
|
||||
entry.author_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"author-tz" if is_committed => entry.author_tz = Some(value.into()),
|
||||
|
||||
"committer" if is_committed => entry.committer_name = Some(value.into()),
|
||||
"committer-mail" if is_committed => entry.committer_email = Some(value.into()),
|
||||
"committer-time" if is_committed => {
|
||||
entry.committer_time = Some(value.parse::<i64>()?)
|
||||
}
|
||||
"committer-tz" if is_committed => entry.committer_tz = Some(value.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if done {
|
||||
if let Some(entry) = current_entry.take() {
|
||||
index.insert(entry.sha, entries.len());
|
||||
|
||||
// We only want annotations that have a commit.
|
||||
if !entry.sha.is_zero() {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::BlameEntry;
|
||||
use super::parse_git_blame;
|
||||
|
||||
fn read_test_data(filename: &str) -> String {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push(filename);
|
||||
|
||||
std::fs::read_to_string(&path)
|
||||
.unwrap_or_else(|_| panic!("Could not read test data at {:?}. Is it generated?", path))
|
||||
}
|
||||
|
||||
fn assert_eq_golden(entries: &Vec<BlameEntry>, golden_filename: &str) {
|
||||
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
path.push("test_data");
|
||||
path.push("golden");
|
||||
path.push(format!("{}.json", golden_filename));
|
||||
|
||||
let mut have_json =
|
||||
serde_json::to_string_pretty(&entries).expect("could not serialize entries to JSON");
|
||||
// We always want to save with a trailing newline.
|
||||
have_json.push('\n');
|
||||
|
||||
let update = std::env::var("UPDATE_GOLDEN")
|
||||
.map(|val| val.eq_ignore_ascii_case("true"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if update {
|
||||
std::fs::create_dir_all(path.parent().unwrap())
|
||||
.expect("could not create golden test data directory");
|
||||
std::fs::write(&path, have_json).expect("could not write out golden data");
|
||||
} else {
|
||||
let want_json =
|
||||
std::fs::read_to_string(&path).unwrap_or_else(|_| {
|
||||
panic!("could not read golden test data file at {:?}. Did you run the test with UPDATE_GOLDEN=true before?", path);
|
||||
}).replace("\r\n", "\n");
|
||||
|
||||
pretty_assertions::assert_eq!(have_json, want_json, "wrong blame entries");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_not_committed() {
|
||||
let output = read_test_data("blame_incremental_not_committed");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_not_committed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_simple() {
|
||||
let output = read_test_data("blame_incremental_simple");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_simple");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_git_blame_complex() {
|
||||
let output = read_test_data("blame_incremental_complex");
|
||||
let entries = parse_git_blame(&output).unwrap();
|
||||
assert_eq_golden(&entries, "blame_incremental_complex");
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::StreamExt;
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::OffsetRangeExt;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::{
|
||||
@@ -20,8 +20,6 @@ use util::paths::PathMatcher;
|
||||
pub struct GrepToolInput {
|
||||
/// A regex pattern to search for in the entire project. Note that the regex
|
||||
/// will be parsed by the Rust `regex` crate.
|
||||
///
|
||||
/// Do NOT specify a path here! This will only be matched against the code **content**.
|
||||
pub regex: String,
|
||||
|
||||
/// A glob pattern for the paths of files to include in the search.
|
||||
@@ -98,7 +96,6 @@ impl Tool for GrepTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
const CONTEXT_LINES: u32 = 2;
|
||||
@@ -408,7 +405,7 @@ mod tests {
|
||||
) -> String {
|
||||
let tool = Arc::new(GrepTool);
|
||||
let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
|
||||
let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));
|
||||
let task = cx.update(|cx| tool.run(input, &[], project, action_log, cx));
|
||||
|
||||
match task.output.await {
|
||||
Ok(result) => result,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -76,7 +76,6 @@ impl Tool for ListDirectoryTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ListDirectoryToolInput>(input) {
|
||||
|
||||
@@ -1 +1 @@
|
||||
Lists files and directories in a given path. Prefer the `grep` or `find_path` tools when searching the codebase.
|
||||
Lists files and directories in a given path. Prefer the `grep` or `path_search` tools when searching the codebase.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -89,7 +89,6 @@ impl Tool for MovePathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<MovePathToolInput>(input) {
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use chrono::{Local, Utc};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -59,7 +59,6 @@ impl Tool for NowTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: NowToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -52,7 +52,6 @@ impl Tool for OpenTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: OpenToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -12,7 +12,7 @@ use util::paths::PathMatcher;
|
||||
use worktree::Snapshot;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FindPathToolInput {
|
||||
pub struct PathSearchToolInput {
|
||||
/// The glob to match against every path in the project.
|
||||
///
|
||||
/// <example>
|
||||
@@ -34,11 +34,11 @@ pub struct FindPathToolInput {
|
||||
|
||||
const RESULTS_PER_PAGE: usize = 50;
|
||||
|
||||
pub struct FindPathTool;
|
||||
pub struct PathSearchTool;
|
||||
|
||||
impl Tool for FindPathTool {
|
||||
impl Tool for PathSearchTool {
|
||||
fn name(&self) -> String {
|
||||
"find_path".into()
|
||||
"path_search".into()
|
||||
}
|
||||
|
||||
fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
|
||||
@@ -46,7 +46,7 @@ impl Tool for FindPathTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> String {
|
||||
include_str!("./find_path_tool/description.md").into()
|
||||
include_str!("./path_search_tool/description.md").into()
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
@@ -54,11 +54,11 @@ impl Tool for FindPathTool {
|
||||
}
|
||||
|
||||
fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
|
||||
json_schema_for::<FindPathToolInput>(format)
|
||||
json_schema_for::<PathSearchToolInput>(format)
|
||||
}
|
||||
|
||||
fn ui_text(&self, input: &serde_json::Value) -> String {
|
||||
match serde_json::from_value::<FindPathToolInput>(input.clone()) {
|
||||
match serde_json::from_value::<PathSearchToolInput>(input.clone()) {
|
||||
Ok(input) => format!("Find paths matching “`{}`”", input.glob),
|
||||
Err(_) => "Search paths".to_string(),
|
||||
}
|
||||
@@ -70,10 +70,9 @@ impl Tool for FindPathTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let (offset, glob) = match serde_json::from_value::<FindPathToolInput>(input) {
|
||||
let (offset, glob) = match serde_json::from_value::<PathSearchToolInput>(input) {
|
||||
Ok(input) => (input.offset, input.glob),
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))).into(),
|
||||
};
|
||||
@@ -144,7 +143,7 @@ mod test {
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_find_path_tool(cx: &mut TestAppContext) {
|
||||
async fn test_path_search_tool(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
@@ -1,4 +1,4 @@
|
||||
Fast file path pattern matching tool that works with any codebase size
|
||||
Fast file pattern matching tool that works with any codebase size
|
||||
|
||||
- Supports glob patterns like "**/*.js" or "src/**/*.ts"
|
||||
- Returns matching file paths sorted alphabetically
|
||||
@@ -1,8 +1,7 @@
|
||||
use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
|
||||
use gpui::{App, Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use itertools::Itertools;
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
@@ -88,7 +87,6 @@ impl Tool for ReadFileTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<ReadFileToolInput>(input) {
|
||||
@@ -136,7 +134,7 @@ impl Tool for ReadFileTool {
|
||||
})?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.track_buffer(buffer, cx);
|
||||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
@@ -149,7 +147,7 @@ impl Tool for ReadFileTool {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.track_buffer(buffer, cx);
|
||||
log.buffer_read(buffer, cx);
|
||||
})?;
|
||||
|
||||
Ok(result)
|
||||
@@ -195,7 +193,7 @@ mod test {
|
||||
"path": "root/nonexistent_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -225,7 +223,7 @@ mod test {
|
||||
"path": "root/small_file.txt"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -255,7 +253,7 @@ mod test {
|
||||
"path": "root/large_file.rs"
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log.clone(), None, cx)
|
||||
.run(input, &[], project.clone(), action_log.clone(), cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -279,7 +277,7 @@ mod test {
|
||||
"offset": 1
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
@@ -325,7 +323,7 @@ mod test {
|
||||
"end_line": 4
|
||||
});
|
||||
Arc::new(ReadFileTool)
|
||||
.run(input, &[], project.clone(), action_log, None, cx)
|
||||
.run(input, &[], project.clone(), action_log, cx)
|
||||
.output
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{self, Buffer, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -87,7 +87,6 @@ impl Tool for RenameTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<RenameToolInput>(input) {
|
||||
@@ -107,7 +106,7 @@ impl Tool for RenameTool {
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let position = {
|
||||
|
||||
872
crates/assistant_tools/src/replace.rs
Normal file
872
crates/assistant_tools/src/replace.rs
Normal file
@@ -0,0 +1,872 @@
|
||||
use language::{BufferSnapshot, Diff, Point, ToOffset};
|
||||
use project::search::SearchQuery;
|
||||
use std::iter;
|
||||
use util::{ResultExt as _, paths::PathMatcher};
|
||||
|
||||
/// Performs an exact string replacement in a buffer, requiring precise character-for-character matching.
|
||||
/// Uses the search functionality to locate the first occurrence of the exact string.
|
||||
/// Returns None if no exact match is found in the buffer.
|
||||
pub async fn replace_exact(old: &str, new: &str, snapshot: &BufferSnapshot) -> Option<Diff> {
|
||||
let query = SearchQuery::text(
|
||||
old,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
PathMatcher::new(iter::empty::<&str>()).ok()?,
|
||||
PathMatcher::new(iter::empty::<&str>()).ok()?,
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.log_err()?;
|
||||
|
||||
let matches = query.search(&snapshot, None).await;
|
||||
|
||||
if matches.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let edit_range = matches[0].clone();
|
||||
let diff = language::text_diff(&old, &new);
|
||||
|
||||
let edits = diff
|
||||
.into_iter()
|
||||
.map(|(old_range, text)| {
|
||||
let start = edit_range.start + old_range.start;
|
||||
let end = edit_range.start + old_range.end;
|
||||
(start..end, text)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let diff = language::Diff {
|
||||
base_version: snapshot.version().clone(),
|
||||
line_ending: snapshot.line_ending(),
|
||||
edits,
|
||||
};
|
||||
|
||||
Some(diff)
|
||||
}
|
||||
|
||||
/// Performs a replacement that's indentation-aware - matches text content ignoring leading whitespace differences.
|
||||
/// When replacing, preserves the indentation level found in the buffer at each matching line.
|
||||
/// Returns None if no match found or if indentation is offset inconsistently across matched lines.
|
||||
pub fn replace_with_flexible_indent(old: &str, new: &str, buffer: &BufferSnapshot) -> Option<Diff> {
|
||||
let (old_lines, old_min_indent) = lines_with_min_indent(old);
|
||||
let (new_lines, new_min_indent) = lines_with_min_indent(new);
|
||||
let min_indent = old_min_indent.min(new_min_indent);
|
||||
|
||||
let old_lines = drop_lines_prefix(&old_lines, min_indent);
|
||||
let new_lines = drop_lines_prefix(&new_lines, min_indent);
|
||||
|
||||
let max_row = buffer.max_point().row;
|
||||
|
||||
'windows: for start_row in 0..max_row + 1 {
|
||||
let end_row = start_row + old_lines.len().saturating_sub(1) as u32;
|
||||
|
||||
if end_row > max_row {
|
||||
// The buffer ends before fully matching the pattern
|
||||
return None;
|
||||
}
|
||||
|
||||
let start_point = Point::new(start_row, 0);
|
||||
let end_point = Point::new(end_row, buffer.line_len(end_row));
|
||||
let range = start_point.to_offset(buffer)..end_point.to_offset(buffer);
|
||||
|
||||
let window_text = buffer.text_for_range(range.clone());
|
||||
let mut window_lines = window_text.lines();
|
||||
let mut old_lines_iter = old_lines.iter();
|
||||
|
||||
let mut common_mismatch = None;
|
||||
|
||||
#[derive(Eq, PartialEq)]
|
||||
enum Mismatch {
|
||||
OverIndented(String),
|
||||
UnderIndented(String),
|
||||
}
|
||||
|
||||
while let (Some(window_line), Some(old_line)) = (window_lines.next(), old_lines_iter.next())
|
||||
{
|
||||
let line_trimmed = window_line.trim_start();
|
||||
|
||||
if line_trimmed != old_line.trim_start() {
|
||||
continue 'windows;
|
||||
}
|
||||
|
||||
if line_trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let line_mismatch = if window_line.len() > old_line.len() {
|
||||
let prefix = window_line[..window_line.len() - old_line.len()].to_string();
|
||||
Mismatch::UnderIndented(prefix)
|
||||
} else {
|
||||
let prefix = old_line[..old_line.len() - window_line.len()].to_string();
|
||||
Mismatch::OverIndented(prefix)
|
||||
};
|
||||
|
||||
match &common_mismatch {
|
||||
Some(common_mismatch) if common_mismatch != &line_mismatch => {
|
||||
continue 'windows;
|
||||
}
|
||||
Some(_) => (),
|
||||
None => common_mismatch = Some(line_mismatch),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(common_mismatch) = &common_mismatch {
|
||||
let line_ending = buffer.line_ending();
|
||||
let replacement = new_lines
|
||||
.iter()
|
||||
.map(|new_line| {
|
||||
if new_line.trim().is_empty() {
|
||||
new_line.to_string()
|
||||
} else {
|
||||
match common_mismatch {
|
||||
Mismatch::UnderIndented(prefix) => prefix.to_string() + new_line,
|
||||
Mismatch::OverIndented(prefix) => new_line
|
||||
.strip_prefix(prefix)
|
||||
.unwrap_or(new_line)
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(line_ending.as_str());
|
||||
|
||||
let diff = Diff {
|
||||
base_version: buffer.version().clone(),
|
||||
line_ending,
|
||||
edits: vec![(range, replacement.into())],
|
||||
};
|
||||
|
||||
return Some(diff);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn drop_lines_prefix<'a>(lines: &'a [&str], prefix_len: usize) -> Vec<&'a str> {
|
||||
lines
|
||||
.iter()
|
||||
.map(|line| line.get(prefix_len..).unwrap_or(""))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn lines_with_min_indent(input: &str) -> (Vec<&str>, usize) {
|
||||
let mut lines = Vec::new();
|
||||
let mut min_indent: Option<usize> = None;
|
||||
|
||||
for line in input.lines() {
|
||||
lines.push(line);
|
||||
if !line.trim().is_empty() {
|
||||
let indent = line.len() - line.trim_start().len();
|
||||
min_indent = Some(min_indent.map_or(indent, |m| m.min(indent)));
|
||||
}
|
||||
}
|
||||
|
||||
(lines, min_indent.unwrap_or(0))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod replace_exact_tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use gpui::prelude::*;
|
||||
|
||||
#[gpui::test]
|
||||
async fn basic(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, Some("let x = 42;".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn no_match(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", "let y = 42;", "let y = 43;").await;
|
||||
assert_eq!(result, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn multi_line(cx: &mut TestAppContext) {
|
||||
let whole = "fn example() {\n let x = 41;\n println!(\"x = {}\", x);\n}";
|
||||
let old_text = " let x = 41;\n println!(\"x = {}\", x);";
|
||||
let new_text = " let x = 42;\n println!(\"x = {}\", x);";
|
||||
let result = test_replace_exact(cx, whole, old_text, new_text).await;
|
||||
assert_eq!(
|
||||
result,
|
||||
Some("fn example() {\n let x = 42;\n println!(\"x = {}\", x);\n}".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn multiple_occurrences(cx: &mut TestAppContext) {
|
||||
let whole = "let x = 41;\nlet y = 41;\nlet z = 41;";
|
||||
let result = test_replace_exact(cx, whole, "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(
|
||||
result,
|
||||
Some("let x = 42;\nlet y = 41;\nlet z = 41;".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn empty_buffer(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "", "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn partial_match(cx: &mut TestAppContext) {
|
||||
let whole = "let x = 41; let y = 42;";
|
||||
let result = test_replace_exact(cx, whole, "let x = 41", "let x = 42").await;
|
||||
assert_eq!(result, Some("let x = 42; let y = 42;".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn whitespace_sensitive(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", " let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn entire_buffer(cx: &mut TestAppContext) {
|
||||
let result = test_replace_exact(cx, "let x = 41;", "let x = 41;", "let x = 42;").await;
|
||||
assert_eq!(result, Some("let x = 42;".to_string()));
|
||||
}
|
||||
|
||||
async fn test_replace_exact(
|
||||
cx: &mut TestAppContext,
|
||||
whole: &str,
|
||||
old: &str,
|
||||
new: &str,
|
||||
) -> Option<String> {
|
||||
let buffer = cx.new(|cx| language::Buffer::local(whole, cx));
|
||||
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact(old, new, &buffer_snapshot).await;
|
||||
diff.map(|diff| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod flexible_indent_tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use gpui::prelude::*;
|
||||
use unindent::Unindent;
|
||||
|
||||
#[gpui::test]
|
||||
fn test_underindented_single_line(cx: &mut TestAppContext) {
|
||||
let cur = " let a = 41;".to_string();
|
||||
let old = " let a = 41;".to_string();
|
||||
let new = " let a = 42;".to_string();
|
||||
let exp = " let a = 42;".to_string();
|
||||
|
||||
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
|
||||
|
||||
assert_eq!(result, Some(exp.to_string()))
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_overindented_single_line(cx: &mut TestAppContext) {
|
||||
let cur = " let a = 41;".to_string();
|
||||
let old = " let a = 41;".to_string();
|
||||
let new = " let a = 42;".to_string();
|
||||
let exp = " let a = 42;".to_string();
|
||||
|
||||
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
|
||||
|
||||
assert_eq!(result, Some(exp.to_string()))
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_underindented_multi_line(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
let y = 10;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 42;
|
||||
println!("New value: {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let expected = r#"
|
||||
fn test() {
|
||||
let x = 42;
|
||||
println!("New value: {}", x);
|
||||
let y = 10;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_overindented_multi_line(cx: &mut TestAppContext) {
|
||||
let cur = r#"
|
||||
fn foo() {
|
||||
let a = 41;
|
||||
let b = 3.13;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
// 6 space indent instead of 4
|
||||
let old = " let a = 41;\n let b = 3.13;";
|
||||
let new = " let a = 42;\n let b = 3.14;";
|
||||
|
||||
let expected = r#"
|
||||
fn foo() {
|
||||
let a = 42;
|
||||
let b = 3.14;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let result = test_replace_with_flexible_indent(cx, &cur, &old, &new);
|
||||
|
||||
assert_eq!(result, Some(expected.to_string()))
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_inconsistent_indentation(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
if condition {
|
||||
println!("{}", 43);
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
if condition {
|
||||
println!("{}", 43);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
if condition {
|
||||
println!("{}", 42);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_with_empty_lines(cx: &mut TestAppContext) {
|
||||
// Test with empty lines
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
|
||||
println!("x = {}", x);
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 10;
|
||||
|
||||
println!("New x: {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let expected = r#"
|
||||
fn test() {
|
||||
let x = 10;
|
||||
|
||||
println!("New x: {}", x);
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_no_match(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let y = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let y = 20;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_whole_ends_before_matching_old(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 10;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
// Should return None because whole doesn't fully contain the old text
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_whole_is_shorter_than_old(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
let x = 5;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
let y = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 5;
|
||||
let y = 20;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_old_is_empty(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
let x = 5;
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = "";
|
||||
let new = r#"
|
||||
let y = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_whole_is_empty(cx: &mut TestAppContext) {
|
||||
let whole = "";
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 10;
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lines_with_min_indent() {
|
||||
// Empty string
|
||||
assert_eq!(lines_with_min_indent(""), (vec![], 0));
|
||||
|
||||
// Single line without indentation
|
||||
assert_eq!(lines_with_min_indent("hello"), (vec!["hello"], 0));
|
||||
|
||||
// Multiple lines with no indentation
|
||||
assert_eq!(
|
||||
lines_with_min_indent("line1\nline2\nline3"),
|
||||
(vec!["line1", "line2", "line3"], 0)
|
||||
);
|
||||
|
||||
// Multiple lines with consistent indentation
|
||||
assert_eq!(
|
||||
lines_with_min_indent(" line1\n line2\n line3"),
|
||||
(vec![" line1", " line2", " line3"], 2)
|
||||
);
|
||||
|
||||
// Multiple lines with varying indentation
|
||||
assert_eq!(
|
||||
lines_with_min_indent(" line1\n line2\n line3"),
|
||||
(vec![" line1", " line2", " line3"], 2)
|
||||
);
|
||||
|
||||
// Lines with mixed indentation and empty lines
|
||||
assert_eq!(
|
||||
lines_with_min_indent(" line1\n\n line2"),
|
||||
(vec![" line1", "", " line2"], 2)
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_with_missing_indent_uneven_match(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
fn test() {
|
||||
if true {
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
let x = 5;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
let x = 42;
|
||||
println!("x = {}", x);
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let expected = r#"
|
||||
fn test() {
|
||||
if true {
|
||||
let x = 42;
|
||||
println!("x = {}", x);
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_replace_big_example(cx: &mut TestAppContext) {
|
||||
let whole = r#"
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let old = r#"
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
|
||||
let new = r#"
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_group_people_by_age() {
|
||||
let people = vec![
|
||||
Person::new("Young One", 5, "young@example.com").unwrap(),
|
||||
Person::new("Teen One", 15, "teen@example.com").unwrap(),
|
||||
Person::new("Teen Two", 18, "teen2@example.com").unwrap(),
|
||||
Person::new("Adult One", 25, "adult@example.com").unwrap(),
|
||||
];
|
||||
|
||||
let groups = group_people_by_age(&people);
|
||||
|
||||
assert_eq!(groups.get(&0).unwrap().len(), 1); // One person in 0-9
|
||||
assert_eq!(groups.get(&10).unwrap().len(), 2); // Two people in 10-19
|
||||
assert_eq!(groups.get(&20).unwrap().len(), 1); // One person in 20-29
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
let expected = r#"
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_age() {
|
||||
assert!(is_valid_age(0));
|
||||
assert!(!is_valid_age(151));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_group_people_by_age() {
|
||||
let people = vec![
|
||||
Person::new("Young One", 5, "young@example.com").unwrap(),
|
||||
Person::new("Teen One", 15, "teen@example.com").unwrap(),
|
||||
Person::new("Teen Two", 18, "teen2@example.com").unwrap(),
|
||||
Person::new("Adult One", 25, "adult@example.com").unwrap(),
|
||||
];
|
||||
|
||||
let groups = group_people_by_age(&people);
|
||||
|
||||
assert_eq!(groups.get(&0).unwrap().len(), 1); // One person in 0-9
|
||||
assert_eq!(groups.get(&10).unwrap().len(), 2); // Two people in 10-19
|
||||
assert_eq!(groups.get(&20).unwrap().len(), 1); // One person in 20-29
|
||||
}
|
||||
}
|
||||
"#
|
||||
.unindent();
|
||||
assert_eq!(
|
||||
test_replace_with_flexible_indent(cx, &whole, &old, &new),
|
||||
Some(expected.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_drop_lines_prefix() {
|
||||
// Empty array
|
||||
assert_eq!(drop_lines_prefix(&[], 2), Vec::<&str>::new());
|
||||
|
||||
// Zero prefix length
|
||||
assert_eq!(
|
||||
drop_lines_prefix(&["line1", "line2"], 0),
|
||||
vec!["line1", "line2"]
|
||||
);
|
||||
|
||||
// Normal prefix drop
|
||||
assert_eq!(
|
||||
drop_lines_prefix(&[" line1", " line2"], 2),
|
||||
vec!["line1", "line2"]
|
||||
);
|
||||
|
||||
// Prefix longer than some lines
|
||||
assert_eq!(drop_lines_prefix(&[" line1", "a"], 2), vec!["line1", ""]);
|
||||
|
||||
// Prefix longer than all lines
|
||||
assert_eq!(drop_lines_prefix(&["a", "b"], 5), vec!["", ""]);
|
||||
|
||||
// Mixed length lines
|
||||
assert_eq!(
|
||||
drop_lines_prefix(&[" line1", " line2", " line3"], 2),
|
||||
vec![" line1", "line2", " line3"]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_basic(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
assert_eq!(diff.edits.len(), 1);
|
||||
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42;");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_no_match(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let y = 42;", "let y = 43;", &snapshot).await;
|
||||
assert!(diff.is_none());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_multi_line(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| {
|
||||
language::Buffer::local(
|
||||
"fn example() {\n let x = 41;\n println!(\"x = {}\", x);\n}",
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let old_text = " let x = 41;\n println!(\"x = {}\", x);";
|
||||
let new_text = " let x = 42;\n println!(\"x = {}\", x);";
|
||||
let diff = replace_exact(old_text, new_text, &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
"fn example() {\n let x = 42;\n println!(\"x = {}\", x);\n}"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_multiple_occurrences(cx: &mut TestAppContext) {
|
||||
let buffer =
|
||||
cx.new(|cx| language::Buffer::local("let x = 41;\nlet y = 41;\nlet z = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
// Should replace only the first occurrence
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42;\nlet y = 41;\nlet z = 41;");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_empty_buffer(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_none());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_partial_match(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41; let y = 42;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
// Verify substring replacement actually works
|
||||
let diff = replace_exact("let x = 41", "let x = 42", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42; let y = 42;");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_whitespace_sensitive(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact(" let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_none());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_replace_exact_entire_buffer(cx: &mut TestAppContext) {
|
||||
let buffer = cx.new(|cx| language::Buffer::local("let x = 41;", cx));
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
let diff = replace_exact("let x = 41;", "let x = 42;", &snapshot).await;
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
let result = buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
});
|
||||
|
||||
assert_eq!(result, "let x = 42;");
|
||||
}
|
||||
|
||||
fn test_replace_with_flexible_indent(
|
||||
cx: &mut TestAppContext,
|
||||
whole: &str,
|
||||
old: &str,
|
||||
new: &str,
|
||||
) -> Option<String> {
|
||||
// Create a local buffer with the test content
|
||||
let buffer = cx.new(|cx| language::Buffer::local(whole, cx));
|
||||
|
||||
// Get the buffer snapshot
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
|
||||
// Call replace_flexible and transform the result
|
||||
replace_with_flexible_indent(old, new, &buffer_snapshot).map(|diff| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
let _ = buffer.apply_diff(diff, cx);
|
||||
buffer.text()
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
|
||||
use gpui::{App, AsyncApp, Entity, Task};
|
||||
use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
@@ -121,7 +121,6 @@ impl Tool for SymbolInfoTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
|
||||
@@ -141,7 +140,7 @@ impl Tool for SymbolInfoTool {
|
||||
};
|
||||
|
||||
action_log.update(cx, |action_log, cx| {
|
||||
action_log.track_buffer(buffer.clone(), cx);
|
||||
action_log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
let position = {
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use handlebars::Handlebars;
|
||||
use rust_embed::RustEmbed;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(RustEmbed)]
|
||||
#[folder = "src/templates"]
|
||||
#[include = "*.hbs"]
|
||||
struct Assets;
|
||||
|
||||
pub struct Templates(Handlebars<'static>);
|
||||
|
||||
impl Templates {
|
||||
pub fn new() -> Arc<Self> {
|
||||
let mut handlebars = Handlebars::new();
|
||||
handlebars.register_embed_templates::<Assets>().unwrap();
|
||||
handlebars.register_escape_fn(|text| text.into());
|
||||
Arc::new(Self(handlebars))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Template: Sized {
|
||||
const TEMPLATE_NAME: &'static str;
|
||||
|
||||
fn render(&self, templates: &Templates) -> Result<String>
|
||||
where
|
||||
Self: Serialize + Sized,
|
||||
{
|
||||
Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
You are an expert text editor. Taking the following file as an input:
|
||||
|
||||
```{{path}}
|
||||
{{file_content}}
|
||||
```
|
||||
|
||||
Your response must be a series of edits in the following format:
|
||||
|
||||
<edits>
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 2 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 2 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 3 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 3 HERE
|
||||
</new_text>
|
||||
</edits>
|
||||
|
||||
Rules for editing:
|
||||
|
||||
- `old_text` represents full lines (including indentation) in the input file that will be replaced with `new_text`
|
||||
- It is crucial that `old_text` is unique and unambiguous.
|
||||
- Always include enough context around the lines you want to replace in `old_text` such that it's impossible to mistake them for other lines.
|
||||
- If you want to replace many occurrences of the same text, repeat the same `old_text`/`new_text` pair multiple times and I will apply them sequentially, one occurrence at a time.
|
||||
- Don't explain why you made a change, just report the edits.
|
||||
- Never do MORE than what the user has requested.
|
||||
- Never do LESS than what the user has requested.
|
||||
|
||||
<user_instructions>
|
||||
{{instructions}}
|
||||
</user_instructions>
|
||||
|
||||
<edits>
|
||||
@@ -3,7 +3,7 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use futures::io::BufReader;
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, FutureExt};
|
||||
use gpui::{AnyWindowHandle, App, AppContext, Entity, Task};
|
||||
use gpui::{App, AppContext, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -78,7 +78,6 @@ impl Tool for TerminalTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input: TerminalToolInput = match serde_json::from_value(input) {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use crate::schema::json_schema_for;
|
||||
use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolResult};
|
||||
use gpui::{AnyWindowHandle, App, Entity, Task};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
@@ -50,7 +50,6 @@ impl Tool for ThinkingTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
_cx: &mut App,
|
||||
) -> ToolResult {
|
||||
// This tool just "thinks out loud" and doesn't perform any actions.
|
||||
|
||||
@@ -5,16 +5,13 @@ use crate::ui::ToolCallCardHeader;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
|
||||
use futures::{Future, FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
};
|
||||
use gpui::{App, AppContext, Context, Entity, IntoElement, Task, Window};
|
||||
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::{IconName, Tooltip, prelude::*};
|
||||
use web_search::WebSearchRegistry;
|
||||
use workspace::Workspace;
|
||||
use zed_llm_client::{WebSearchCitation, WebSearchResponse};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -57,7 +54,6 @@ impl Tool for WebSearchTool {
|
||||
_messages: &[LanguageModelRequestMessage],
|
||||
_project: Entity<Project>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
) -> ToolResult {
|
||||
let input = match serde_json::from_value::<WebSearchToolInput>(input) {
|
||||
@@ -115,7 +111,6 @@ impl ToolCard for WebSearchToolCard {
|
||||
&mut self,
|
||||
_status: &ToolUseStatus,
|
||||
_window: &mut Window,
|
||||
_workspace: WeakEntity<Workspace>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let header = match self.response.as_ref() {
|
||||
@@ -225,13 +220,8 @@ impl Component for WebSearchTool {
|
||||
div()
|
||||
.size_full()
|
||||
.child(in_progress_search.update(cx, |tool, cx| {
|
||||
tool.render(
|
||||
&ToolUseStatus::Pending,
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
tool.render(&ToolUseStatus::Pending, window, cx)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
@@ -240,13 +230,8 @@ impl Component for WebSearchTool {
|
||||
div()
|
||||
.size_full()
|
||||
.child(successful_search.update(cx, |tool, cx| {
|
||||
tool.render(
|
||||
&ToolUseStatus::Finished("".into()),
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
tool.render(&ToolUseStatus::Finished("".into()), window, cx)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
@@ -255,13 +240,8 @@ impl Component for WebSearchTool {
|
||||
div()
|
||||
.size_full()
|
||||
.child(error_search.update(cx, |tool, cx| {
|
||||
tool.render(
|
||||
&ToolUseStatus::Error("".into()),
|
||||
window,
|
||||
WeakEntity::new_invalid(),
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
tool.render(&ToolUseStatus::Error("".into()), window, cx)
|
||||
.into_any_element()
|
||||
}))
|
||||
.into_any_element(),
|
||||
),
|
||||
|
||||
@@ -118,13 +118,6 @@ impl Settings for AutoUpdateSetting {
|
||||
|
||||
Ok(Self(auto_update.0))
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
vscode.enum_setting("update.mode", current, |s| match s {
|
||||
"none" | "manual" => Some(AutoUpdateSettingContent(false)),
|
||||
_ => Some(AutoUpdateSettingContent(true)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
||||
@@ -32,6 +32,4 @@ impl Settings for CallSettings {
|
||||
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
|
||||
sources.json_merge()
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
|
||||
@@ -104,8 +104,6 @@ impl Settings for ClientSettings {
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {}
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -132,10 +130,6 @@ impl Settings for ProxySettings {
|
||||
.or(sources.default.proxy.clone()),
|
||||
})
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
vscode.string_setting("http.proxy", &mut current.proxy);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init_settings(cx: &mut App) {
|
||||
@@ -524,18 +518,6 @@ impl settings::Settings for TelemetrySettings {
|
||||
.unwrap_or(sources.default.metrics.ok_or_else(Self::missing_default)?),
|
||||
})
|
||||
}
|
||||
|
||||
fn import_from_vscode(vscode: &settings::VsCodeSettings, current: &mut Self::FileContent) {
|
||||
vscode.enum_setting("telemetry.telemetryLevel", &mut current.metrics, |s| {
|
||||
Some(s == "all")
|
||||
});
|
||||
vscode.enum_setting("telemetry.telemetryLevel", &mut current.diagnostics, |s| {
|
||||
Some(matches!(s, "all" | "error" | "crash"))
|
||||
});
|
||||
// we could translate telemetry.telemetryLevel, but just because users didn't want
|
||||
// to send microsoft telemetry doesn't mean they don't want to send it to zed. their
|
||||
// all/error/crash/off correspond to combinations of our "diagnostics" and "metrics".
|
||||
}
|
||||
}
|
||||
|
||||
impl Client {
|
||||
|
||||
@@ -34,12 +34,14 @@ dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
envy = "0.4.2"
|
||||
futures.workspace = true
|
||||
google_ai.workspace = true
|
||||
hex.workspace = true
|
||||
http_client.workspace = true
|
||||
jsonwebtoken.workspace = true
|
||||
livekit_api.workspace = true
|
||||
log.workspace = true
|
||||
nanoid.workspace = true
|
||||
open_ai.workspace = true
|
||||
parking_lot.workspace = true
|
||||
prometheus = "0.14"
|
||||
prost.workspace = true
|
||||
@@ -126,7 +128,6 @@ serde_json.workspace = true
|
||||
session = { workspace = true, features = ["test-support"] }
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
sqlx = { version = "0.8", features = ["sqlite"] }
|
||||
task.workspace = true
|
||||
theme.workspace = true
|
||||
unindent.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
@@ -492,8 +492,7 @@ CREATE TABLE IF NOT EXISTS billing_customers (
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
user_id INTEGER NOT NULL REFERENCES users (id),
|
||||
has_overdue_invoices BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
stripe_customer_id TEXT NOT NULL,
|
||||
trial_started_at TIMESTAMP
|
||||
stripe_customer_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "uix_billing_customers_on_user_id" ON billing_customers (user_id);
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
alter table billing_customers
|
||||
add column trial_started_at timestamp without time zone;
|
||||
@@ -1,2 +0,0 @@
|
||||
alter table project_repositories
|
||||
add column head_commit_details varchar;
|
||||
@@ -287,7 +287,7 @@ async fn create_billing_subscription(
|
||||
}
|
||||
}
|
||||
|
||||
let customer_id = if let Some(existing_customer) = &existing_billing_customer {
|
||||
let customer_id = if let Some(existing_customer) = existing_billing_customer {
|
||||
CustomerId::from_str(&existing_customer.stripe_customer_id)
|
||||
.context("failed to parse customer ID")?
|
||||
} else {
|
||||
@@ -320,15 +320,6 @@ async fn create_billing_subscription(
|
||||
.await?
|
||||
}
|
||||
Some(ProductCode::ZedProTrial) => {
|
||||
if let Some(existing_billing_customer) = &existing_billing_customer {
|
||||
if existing_billing_customer.trial_started_at.is_some() {
|
||||
return Err(Error::http(
|
||||
StatusCode::FORBIDDEN,
|
||||
"user already used free trial".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
stripe_billing
|
||||
.checkout_with_zed_pro_trial(
|
||||
app.config.zed_pro_price_id()?,
|
||||
@@ -826,24 +817,6 @@ async fn handle_customer_subscription_event(
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("billing customer not found"))?;
|
||||
|
||||
if let Some(SubscriptionKind::ZedProTrial) = subscription_kind {
|
||||
if subscription.status == SubscriptionStatus::Trialing {
|
||||
let current_period_start =
|
||||
DateTime::from_timestamp(subscription.current_period_start, 0)
|
||||
.ok_or_else(|| anyhow!("No trial subscription period start"))?;
|
||||
|
||||
app.db
|
||||
.update_billing_customer(
|
||||
billing_customer.id,
|
||||
&UpdateBillingCustomerParams {
|
||||
trial_started_at: ActiveValue::set(Some(current_period_start.naive_utc())),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
let was_canceled_due_to_payment_failure = subscription.status == SubscriptionStatus::Canceled
|
||||
&& subscription
|
||||
.cancellation_details
|
||||
@@ -870,28 +843,6 @@ async fn handle_customer_subscription_event(
|
||||
.get_billing_subscription_by_stripe_subscription_id(&subscription.id)
|
||||
.await?
|
||||
{
|
||||
let llm_db = app
|
||||
.llm_db
|
||||
.clone()
|
||||
.ok_or_else(|| anyhow!("LLM DB not initialized"))?;
|
||||
|
||||
let new_period_start_at =
|
||||
chrono::DateTime::from_timestamp(subscription.current_period_start, 0)
|
||||
.ok_or_else(|| anyhow!("No subscription period start"))?;
|
||||
let new_period_end_at =
|
||||
chrono::DateTime::from_timestamp(subscription.current_period_end, 0)
|
||||
.ok_or_else(|| anyhow!("No subscription period end"))?;
|
||||
|
||||
llm_db
|
||||
.transfer_existing_subscription_usage(
|
||||
billing_customer.user_id,
|
||||
&existing_subscription,
|
||||
subscription_kind,
|
||||
new_period_start_at,
|
||||
new_period_end_at,
|
||||
)
|
||||
.await?;
|
||||
|
||||
app.db
|
||||
.update_billing_subscription(
|
||||
existing_subscription.id,
|
||||
|
||||
@@ -14,6 +14,7 @@ pub mod messages;
|
||||
pub mod notifications;
|
||||
pub mod processed_stripe_events;
|
||||
pub mod projects;
|
||||
pub mod rate_buckets;
|
||||
pub mod rooms;
|
||||
pub mod servers;
|
||||
pub mod users;
|
||||
|
||||
@@ -11,7 +11,6 @@ pub struct UpdateBillingCustomerParams {
|
||||
pub user_id: ActiveValue<UserId>,
|
||||
pub stripe_customer_id: ActiveValue<String>,
|
||||
pub has_overdue_invoices: ActiveValue<bool>,
|
||||
pub trial_started_at: ActiveValue<Option<DateTime>>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
@@ -46,8 +45,7 @@ impl Database {
|
||||
user_id: params.user_id.clone(),
|
||||
stripe_customer_id: params.stripe_customer_id.clone(),
|
||||
has_overdue_invoices: params.has_overdue_invoices.clone(),
|
||||
trial_started_at: params.trial_started_at.clone(),
|
||||
created_at: ActiveValue::not_set(),
|
||||
..Default::default()
|
||||
})
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
58
crates/collab/src/db/queries/rate_buckets.rs
Normal file
58
crates/collab/src/db/queries/rate_buckets.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use super::*;
|
||||
use crate::db::tables::rate_buckets;
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
|
||||
impl Database {
|
||||
/// Saves the rate limit for the given user and rate limit name if the last_refill is later
|
||||
/// than the currently saved timestamp.
|
||||
pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
|
||||
if buckets.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.transaction(|tx| async move {
|
||||
rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
|
||||
rate_buckets::ActiveModel {
|
||||
user_id: ActiveValue::Set(bucket.user_id),
|
||||
rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
|
||||
token_count: ActiveValue::Set(bucket.token_count),
|
||||
last_refill: ActiveValue::Set(bucket.last_refill),
|
||||
}
|
||||
}))
|
||||
.on_conflict(
|
||||
OnConflict::columns([
|
||||
rate_buckets::Column::UserId,
|
||||
rate_buckets::Column::RateLimitName,
|
||||
])
|
||||
.update_columns([
|
||||
rate_buckets::Column::TokenCount,
|
||||
rate_buckets::Column::LastRefill,
|
||||
])
|
||||
.to_owned(),
|
||||
)
|
||||
.exec(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retrieves the rate limit for the given user and rate limit name.
|
||||
pub async fn get_rate_bucket(
|
||||
&self,
|
||||
user_id: UserId,
|
||||
rate_limit_name: &str,
|
||||
) -> Result<Option<rate_buckets::Model>> {
|
||||
self.transaction(|tx| async move {
|
||||
let rate_limit = rate_buckets::Entity::find()
|
||||
.filter(rate_buckets::Column::UserId.eq(user_id))
|
||||
.filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
|
||||
.one(&*tx)
|
||||
.await?;
|
||||
|
||||
Ok(rate_limit)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,7 @@ pub mod project;
|
||||
pub mod project_collaborator;
|
||||
pub mod project_repository;
|
||||
pub mod project_repository_statuses;
|
||||
pub mod rate_buckets;
|
||||
pub mod room;
|
||||
pub mod room_participant;
|
||||
pub mod server;
|
||||
|
||||
@@ -10,7 +10,6 @@ pub struct Model {
|
||||
pub user_id: UserId,
|
||||
pub stripe_customer_id: String,
|
||||
pub has_overdue_invoices: bool,
|
||||
pub trial_started_at: Option<DateTime>,
|
||||
pub created_at: DateTime,
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user