Compare commits
67 Commits
resource-l
...
message-ed
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d36304963e | ||
|
|
0d71351b02 | ||
|
|
b06fe288f3 | ||
|
|
4a35498829 | ||
|
|
fd0ffb737f | ||
|
|
e52f148304 | ||
|
|
cb0bc463f1 | ||
|
|
9a375f1419 | ||
|
|
4238e640fa | ||
|
|
0b9c9f5f2d | ||
|
|
2da80e4641 | ||
|
|
d9a94a5496 | ||
|
|
a7442d8880 | ||
|
|
6c1f19571a | ||
|
|
23cd5b59b2 | ||
|
|
f4b0332f78 | ||
|
|
abde7306e3 | ||
|
|
2b3dbe8815 | ||
|
|
7f1a5c6ad7 | ||
|
|
6307105976 | ||
|
|
8d63312eca | ||
|
|
81474a3de0 | ||
|
|
db497ac867 | ||
|
|
8ff2e3e195 | ||
|
|
96093aa465 | ||
|
|
dc87f4b32e | ||
|
|
1957e1f642 | ||
|
|
d78bd8f1d7 | ||
|
|
32975c4208 | ||
|
|
658d56bd72 | ||
|
|
13a2c53381 | ||
|
|
cd234e28ce | ||
|
|
b564b1d5d0 | ||
|
|
48ae02c1ca | ||
|
|
255bb0a3f8 | ||
|
|
628b1058be | ||
|
|
7167f193c0 | ||
|
|
7ff0f1525e | ||
|
|
7df8e05ad9 | ||
|
|
d030bb6281 | ||
|
|
b62f959528 | ||
|
|
3a04657730 | ||
|
|
42b7dbeaee | ||
|
|
bfbb18476f | ||
|
|
978b75bba9 | ||
|
|
1f20d5bf54 | ||
|
|
9de04ce215 | ||
|
|
d8fc53608e | ||
|
|
39c19abdfd | ||
|
|
b105028c05 | ||
|
|
d2162446d0 | ||
|
|
360d4db87c | ||
|
|
44953375cc | ||
|
|
2444321756 | ||
|
|
13bf45dd4a | ||
|
|
b61b71405d | ||
|
|
cc5eb24066 | ||
|
|
52a9101970 | ||
|
|
1a798830cb | ||
|
|
481e3e5092 | ||
|
|
b35e69692d | ||
|
|
add67bde43 | ||
|
|
fa3d0aaed4 | ||
|
|
094e878ccf | ||
|
|
54d4665100 | ||
|
|
2c84e33b7b | ||
|
|
bb6ea22944 |
35
.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml
vendored
Normal file
35
.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Bug Report (Windows Alpha)
|
||||
description: Zed Windows Alpha Related Bugs
|
||||
type: "Bug"
|
||||
labels: ["windows"]
|
||||
title: "Windows Alpha: <a short description of the Windows bug>"
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Summary
|
||||
description: Describe the bug with a one-line summary, and provide detailed reproduction steps
|
||||
value: |
|
||||
<!-- Please insert a one-line summary of the issue below -->
|
||||
SUMMARY_SENTENCE_HERE
|
||||
|
||||
### Description
|
||||
<!-- Describe with sufficient detail to reproduce from a clean Zed install. -->
|
||||
Steps to trigger the problem:
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
**Expected Behavior**:
|
||||
**Actual Behavior**:
|
||||
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: environment
|
||||
attributes:
|
||||
label: Zed Version and System Specs
|
||||
description: 'Open Zed, and in the command palette select "zed: copy system specs into clipboard"'
|
||||
placeholder: |
|
||||
Output of "zed: copy system specs into clipboard"
|
||||
validations:
|
||||
required: true
|
||||
49
Cargo.lock
generated
49
Cargo.lock
generated
@@ -10,6 +10,7 @@ dependencies = [
|
||||
"agent-client-protocol",
|
||||
"anyhow",
|
||||
"buffer_diff",
|
||||
"collections",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"futures 0.3.31",
|
||||
@@ -17,7 +18,6 @@ dependencies = [
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"markdown",
|
||||
"parking_lot",
|
||||
"project",
|
||||
@@ -29,7 +29,10 @@ dependencies = [
|
||||
"tempfile",
|
||||
"terminal",
|
||||
"ui",
|
||||
"url",
|
||||
"util",
|
||||
"uuid",
|
||||
"watch",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
@@ -196,6 +199,7 @@ dependencies = [
|
||||
"clock",
|
||||
"cloud_llm_client",
|
||||
"collections",
|
||||
"context_server",
|
||||
"ctor",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
@@ -204,6 +208,8 @@ dependencies = [
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"handlebars 4.5.0",
|
||||
"html_to_markdown",
|
||||
"http_client",
|
||||
"indoc",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
@@ -227,6 +233,7 @@ dependencies = [
|
||||
"task",
|
||||
"tempfile",
|
||||
"terminal",
|
||||
"text",
|
||||
"theme",
|
||||
"tree-sitter-rust",
|
||||
"ui",
|
||||
@@ -6440,6 +6447,7 @@ dependencies = [
|
||||
"log",
|
||||
"parking_lot",
|
||||
"pretty_assertions",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rope",
|
||||
"schemars",
|
||||
@@ -11144,14 +11152,13 @@ dependencies = [
|
||||
"ai_onboarding",
|
||||
"anyhow",
|
||||
"client",
|
||||
"command_palette_hooks",
|
||||
"component",
|
||||
"db",
|
||||
"documented",
|
||||
"editor",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"fuzzy",
|
||||
"git",
|
||||
"gpui",
|
||||
"itertools 0.14.0",
|
||||
"language",
|
||||
@@ -11163,6 +11170,7 @@ dependencies = [
|
||||
"schemars",
|
||||
"serde",
|
||||
"settings",
|
||||
"telemetry",
|
||||
"theme",
|
||||
"ui",
|
||||
"util",
|
||||
@@ -11238,6 +11246,7 @@ dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"http_client",
|
||||
"log",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -18023,6 +18032,7 @@ dependencies = [
|
||||
"command_palette_hooks",
|
||||
"db",
|
||||
"editor",
|
||||
"env_logger 0.11.8",
|
||||
"futures 0.3.31",
|
||||
"git_ui",
|
||||
"gpui",
|
||||
@@ -18876,33 +18886,6 @@ version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
|
||||
|
||||
[[package]]
|
||||
name = "welcome"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"component",
|
||||
"db",
|
||||
"documented",
|
||||
"editor",
|
||||
"fuzzy",
|
||||
"gpui",
|
||||
"install_cli",
|
||||
"language",
|
||||
"picker",
|
||||
"project",
|
||||
"serde",
|
||||
"settings",
|
||||
"telemetry",
|
||||
"ui",
|
||||
"util",
|
||||
"vim_mode_setting",
|
||||
"workspace",
|
||||
"workspace-hack",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
@@ -20517,7 +20500,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.200.0"
|
||||
version = "0.201.0"
|
||||
dependencies = [
|
||||
"activity_indicator",
|
||||
"agent",
|
||||
@@ -20657,7 +20640,6 @@ dependencies = [
|
||||
"watch",
|
||||
"web_search",
|
||||
"web_search_providers",
|
||||
"welcome",
|
||||
"windows 0.61.1",
|
||||
"winresource",
|
||||
"workspace",
|
||||
@@ -20681,7 +20663,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed_emmet"
|
||||
version = "0.0.4"
|
||||
version = "0.0.6"
|
||||
dependencies = [
|
||||
"zed_extension_api 0.1.0",
|
||||
]
|
||||
@@ -20920,6 +20902,7 @@ dependencies = [
|
||||
"menu",
|
||||
"postage",
|
||||
"project",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"release_channel",
|
||||
"reqwest_client",
|
||||
|
||||
@@ -185,7 +185,6 @@ members = [
|
||||
"crates/watch",
|
||||
"crates/web_search",
|
||||
"crates/web_search_providers",
|
||||
"crates/welcome",
|
||||
"crates/workspace",
|
||||
"crates/worktree",
|
||||
"crates/x_ai",
|
||||
@@ -412,7 +411,6 @@ vim_mode_setting = { path = "crates/vim_mode_setting" }
|
||||
watch = { path = "crates/watch" }
|
||||
web_search = { path = "crates/web_search" }
|
||||
web_search_providers = { path = "crates/web_search_providers" }
|
||||
welcome = { path = "crates/welcome" }
|
||||
workspace = { path = "crates/workspace" }
|
||||
worktree = { path = "crates/worktree" }
|
||||
x_ai = { path = "crates/x_ai" }
|
||||
@@ -566,6 +564,7 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
|
||||
"socks",
|
||||
"stream",
|
||||
] }
|
||||
rodio = { version = "0.21.1", default-features = false }
|
||||
rsa = "0.9.6"
|
||||
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
|
||||
"async-dispatcher-runtime",
|
||||
@@ -714,6 +713,7 @@ features = [
|
||||
"Win32_System_LibraryLoader",
|
||||
"Win32_System_Memory",
|
||||
"Win32_System_Ole",
|
||||
"Win32_System_Performance",
|
||||
"Win32_System_Pipes",
|
||||
"Win32_System_SystemInformation",
|
||||
"Win32_System_SystemServices",
|
||||
|
||||
@@ -239,6 +239,7 @@
|
||||
"ctrl-shift-a": "agent::ToggleContextPicker",
|
||||
"ctrl-shift-j": "agent::ToggleNavigationMenu",
|
||||
"ctrl-shift-i": "agent::ToggleOptionsMenu",
|
||||
"ctrl-alt-shift-n": "agent::ToggleNewThreadMenu",
|
||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||
"ctrl->": "assistant::QuoteSelection",
|
||||
"ctrl-alt-e": "agent::RemoveAllContext",
|
||||
@@ -330,8 +331,6 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"up": "agent::PreviousHistoryMessage",
|
||||
"down": "agent::NextHistoryMessage",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"ctrl-shift-y": "agent::KeepAll",
|
||||
"ctrl-shift-n": "agent::RejectAll"
|
||||
|
||||
@@ -279,6 +279,7 @@
|
||||
"cmd-shift-a": "agent::ToggleContextPicker",
|
||||
"cmd-shift-j": "agent::ToggleNavigationMenu",
|
||||
"cmd-shift-i": "agent::ToggleOptionsMenu",
|
||||
"cmd-alt-shift-n": "agent::ToggleNewThreadMenu",
|
||||
"shift-alt-escape": "agent::ExpandMessageEditor",
|
||||
"cmd->": "assistant::QuoteSelection",
|
||||
"cmd-alt-e": "agent::RemoveAllContext",
|
||||
@@ -382,8 +383,6 @@
|
||||
"use_key_equivalents": true,
|
||||
"bindings": {
|
||||
"enter": "agent::Chat",
|
||||
"up": "agent::PreviousHistoryMessage",
|
||||
"down": "agent::NextHistoryMessage",
|
||||
"shift-ctrl-r": "agent::OpenAgentDiff",
|
||||
"cmd-shift-y": "agent::KeepAll",
|
||||
"cmd-shift-n": "agent::RejectAll"
|
||||
|
||||
@@ -333,10 +333,14 @@
|
||||
"ctrl-x ctrl-c": "editor::ShowEditPrediction", // zed specific
|
||||
"ctrl-x ctrl-l": "editor::ToggleCodeActions", // zed specific
|
||||
"ctrl-x ctrl-z": "editor::Cancel",
|
||||
"ctrl-x ctrl-e": "vim::LineDown",
|
||||
"ctrl-x ctrl-y": "vim::LineUp",
|
||||
"ctrl-w": "editor::DeleteToPreviousWordStart",
|
||||
"ctrl-u": "editor::DeleteToBeginningOfLine",
|
||||
"ctrl-t": "vim::Indent",
|
||||
"ctrl-d": "vim::Outdent",
|
||||
"ctrl-y": "vim::InsertFromAbove",
|
||||
"ctrl-e": "vim::InsertFromBelow",
|
||||
"ctrl-k": ["vim::PushDigraph", {}],
|
||||
"ctrl-v": ["vim::PushLiteral", {}],
|
||||
"ctrl-shift-v": "editor::Paste", // note: this is *very* similar to ctrl-v in vim, but ctrl-shift-v on linux is the typical shortcut for paste when ctrl-v is already in use.
|
||||
|
||||
@@ -82,10 +82,10 @@
|
||||
// Layout mode of the bottom dock. Defaults to "contained"
|
||||
// choices: contained, full, left_aligned, right_aligned
|
||||
"bottom_dock_layout": "contained",
|
||||
// The direction that you want to split panes horizontally. Defaults to "up"
|
||||
"pane_split_direction_horizontal": "up",
|
||||
// The direction that you want to split panes vertically. Defaults to "left"
|
||||
"pane_split_direction_vertical": "left",
|
||||
// The direction that you want to split panes horizontally. Defaults to "down"
|
||||
"pane_split_direction_horizontal": "down",
|
||||
// The direction that you want to split panes vertically. Defaults to "right"
|
||||
"pane_split_direction_vertical": "right",
|
||||
// Centered layout related settings.
|
||||
"centered_layout": {
|
||||
// The relative width of the left padding of the central pane from the
|
||||
|
||||
@@ -20,12 +20,12 @@ action_log.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
collections.workspace = true
|
||||
editor.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
markdown.workspace = true
|
||||
project.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -34,7 +34,10 @@ settings.workspace = true
|
||||
smol.workspace = true
|
||||
terminal.workspace = true
|
||||
ui.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,18 +1,78 @@
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use crate::AcpThread;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language_model::LanguageModel;
|
||||
use collections::IndexMap;
|
||||
use gpui::{AsyncApp, Entity, SharedString, Task};
|
||||
use project::Project;
|
||||
use ui::App;
|
||||
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::AcpThread;
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct UserMessageId(Arc<str>);
|
||||
|
||||
impl UserMessageId {
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string().into())
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
user_message_id: Option<UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
_session_id: &acp::SessionId,
|
||||
_cx: &mut App,
|
||||
) -> Option<Rc<dyn AgentSessionEditor>> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
/// This allows sharing the selector in UI components.
|
||||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AgentSessionEditor {
|
||||
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for agents that support listing, selecting, and querying language models.
|
||||
///
|
||||
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
|
||||
pub trait ModelSelector: 'static {
|
||||
pub trait AgentModelSelector: 'static {
|
||||
/// Lists all available language models for this agent.
|
||||
///
|
||||
/// # Parameters
|
||||
@@ -20,7 +80,7 @@ pub trait ModelSelector: 'static {
|
||||
///
|
||||
/// # Returns
|
||||
/// A task resolving to the list of models or an error (e.g., if no models are configured).
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
|
||||
fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
|
||||
|
||||
/// Selects a model for a specific session (thread).
|
||||
///
|
||||
@@ -37,8 +97,8 @@ pub trait ModelSelector: 'static {
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
model_id: AgentModelId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>>;
|
||||
|
||||
/// Retrieves the currently selected model for a specific session (thread).
|
||||
@@ -52,42 +112,51 @@ pub trait ModelSelector: 'static {
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>>;
|
||||
cx: &mut App,
|
||||
) -> Task<Result<AgentModelInfo>>;
|
||||
|
||||
/// Whenever the model list is updated the receiver will be notified.
|
||||
fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
project: Entity<Project>,
|
||||
cwd: &Path,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Entity<AcpThread>>>;
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct AgentModelId(pub SharedString);
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod];
|
||||
impl std::ops::Deref for AgentModelId {
|
||||
type Target = SharedString;
|
||||
|
||||
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
|
||||
|
||||
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
|
||||
-> Task<Result<acp::PromptResponse>>;
|
||||
|
||||
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
|
||||
|
||||
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
|
||||
///
|
||||
/// If the agent does not support model selection, returns [None].
|
||||
/// This allows sharing the selector in UI components.
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
None // Default impl for agents that don't support it
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthRequired;
|
||||
|
||||
impl Error for AuthRequired {}
|
||||
impl fmt::Display for AuthRequired {
|
||||
impl fmt::Display for AgentModelId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "AuthRequired")
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AgentModelInfo {
|
||||
pub id: AgentModelId,
|
||||
pub name: SharedString,
|
||||
pub icon: Option<IconName>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct AgentModelGroupName(pub SharedString);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AgentModelList {
|
||||
Flat(Vec<AgentModelInfo>),
|
||||
Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
|
||||
}
|
||||
|
||||
impl AgentModelList {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
AgentModelList::Flat(models) => models.is_empty(),
|
||||
AgentModelList::Grouped(groups) => groups.is_empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
125
crates/acp_thread/src/mention.rs
Normal file
125
crates/acp_thread/src/mention.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Result, bail};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum MentionUri {
|
||||
File(PathBuf),
|
||||
Symbol(PathBuf, String),
|
||||
Thread(acp::SessionId),
|
||||
Rule(String),
|
||||
}
|
||||
|
||||
impl MentionUri {
|
||||
pub fn parse(input: &str) -> Result<Self> {
|
||||
let url = url::Url::parse(input)?;
|
||||
let path = url.path();
|
||||
match url.scheme() {
|
||||
"file" => {
|
||||
if let Some(fragment) = url.fragment() {
|
||||
Ok(Self::Symbol(path.into(), fragment.into()))
|
||||
} else {
|
||||
let file_path =
|
||||
PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
|
||||
|
||||
Ok(Self::File(file_path))
|
||||
}
|
||||
}
|
||||
"zed" => {
|
||||
if let Some(thread) = path.strip_prefix("/agent/thread/") {
|
||||
Ok(Self::Thread(acp::SessionId(thread.into())))
|
||||
} else if let Some(rule) = path.strip_prefix("/agent/rule/") {
|
||||
Ok(Self::Rule(rule.into()))
|
||||
} else {
|
||||
bail!("invalid zed url: {:?}", input);
|
||||
}
|
||||
}
|
||||
other => bail!("unrecognized scheme {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
MentionUri::File(path) => path.file_name().unwrap().to_string_lossy().into_owned(),
|
||||
MentionUri::Symbol(_path, name) => name.clone(),
|
||||
MentionUri::Thread(thread) => thread.to_string(),
|
||||
MentionUri::Rule(rule) => rule.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_link(&self) -> String {
|
||||
let name = self.name();
|
||||
let uri = self.to_uri();
|
||||
format!("[{name}]({uri})")
|
||||
}
|
||||
|
||||
pub fn to_uri(&self) -> String {
|
||||
match self {
|
||||
MentionUri::File(path) => {
|
||||
format!("file://{}", path.display())
|
||||
}
|
||||
MentionUri::Symbol(path, name) => {
|
||||
format!("file://{}#{}", path.display(), name)
|
||||
}
|
||||
MentionUri::Thread(thread) => {
|
||||
format!("zed:///agent/thread/{}", thread.0)
|
||||
}
|
||||
MentionUri::Rule(rule) => {
|
||||
format!("zed:///agent/rule/{}", rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mention_uri_parse_and_display() {
|
||||
// Test file URI
|
||||
let file_uri = "file:///path/to/file.rs";
|
||||
let parsed = MentionUri::parse(file_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"),
|
||||
_ => panic!("Expected File variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), file_uri);
|
||||
|
||||
// Test symbol URI
|
||||
let symbol_uri = "file:///path/to/file.rs#MySymbol";
|
||||
let parsed = MentionUri::parse(symbol_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Symbol(path, symbol) => {
|
||||
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
|
||||
assert_eq!(symbol, "MySymbol");
|
||||
}
|
||||
_ => panic!("Expected Symbol variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), symbol_uri);
|
||||
|
||||
// Test thread URI
|
||||
let thread_uri = "zed:///agent/thread/session123";
|
||||
let parsed = MentionUri::parse(thread_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Thread(session_id) => assert_eq!(session_id.0.as_ref(), "session123"),
|
||||
_ => panic!("Expected Thread variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), thread_uri);
|
||||
|
||||
// Test rule URI
|
||||
let rule_uri = "zed:///agent/rule/my_rule";
|
||||
let parsed = MentionUri::parse(rule_uri).unwrap();
|
||||
match &parsed {
|
||||
MentionUri::Rule(rule) => assert_eq!(rule, "my_rule"),
|
||||
_ => panic!("Expected Rule variant"),
|
||||
}
|
||||
assert_eq!(parsed.to_uri(), rule_uri);
|
||||
|
||||
// Test invalid scheme
|
||||
assert!(MentionUri::parse("http://example.com").is_err());
|
||||
|
||||
// Test invalid zed path
|
||||
assert!(MentionUri::parse("zed:///invalid/path").is_err());
|
||||
}
|
||||
}
|
||||
@@ -29,8 +29,14 @@ impl Terminal {
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
Self {
|
||||
command: cx
|
||||
.new(|cx| Markdown::new(command.into(), Some(language_registry.clone()), None, cx)),
|
||||
command: cx.new(|cx| {
|
||||
Markdown::new(
|
||||
format!("```\n{}\n```", command).into(),
|
||||
Some(language_registry.clone()),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
working_dir,
|
||||
terminal,
|
||||
started_at: Instant::now(),
|
||||
|
||||
@@ -17,8 +17,6 @@ use util::{
|
||||
pub struct ActionLog {
|
||||
/// Buffers that we want to notify the model about when they change.
|
||||
tracked_buffers: BTreeMap<Entity<Buffer>, TrackedBuffer>,
|
||||
/// Has the model edited a file since it last checked diagnostics?
|
||||
edited_since_project_diagnostics_check: bool,
|
||||
/// The project this action log is associated with
|
||||
project: Entity<Project>,
|
||||
}
|
||||
@@ -28,7 +26,6 @@ impl ActionLog {
|
||||
pub fn new(project: Entity<Project>) -> Self {
|
||||
Self {
|
||||
tracked_buffers: BTreeMap::default(),
|
||||
edited_since_project_diagnostics_check: false,
|
||||
project,
|
||||
}
|
||||
}
|
||||
@@ -37,16 +34,6 @@ impl ActionLog {
|
||||
&self.project
|
||||
}
|
||||
|
||||
/// Notifies a diagnostics check
|
||||
pub fn checked_project_diagnostics(&mut self) {
|
||||
self.edited_since_project_diagnostics_check = false;
|
||||
}
|
||||
|
||||
/// Returns true if any files have been edited since the last project diagnostics check
|
||||
pub fn has_edited_files_since_project_diagnostics_check(&self) -> bool {
|
||||
self.edited_since_project_diagnostics_check
|
||||
}
|
||||
|
||||
pub fn latest_snapshot(&self, buffer: &Entity<Buffer>) -> Option<text::BufferSnapshot> {
|
||||
Some(self.tracked_buffers.get(buffer)?.snapshot.clone())
|
||||
}
|
||||
@@ -543,14 +530,11 @@ impl ActionLog {
|
||||
|
||||
/// Mark a buffer as created by agent, so we can refresh it in the context
|
||||
pub fn buffer_created(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
self.track_buffer_internal(buffer.clone(), true, cx);
|
||||
}
|
||||
|
||||
/// Mark a buffer as edited by agent, so we can refresh it in the context
|
||||
pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
|
||||
self.edited_since_project_diagnostics_check = true;
|
||||
|
||||
let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
|
||||
if let TrackedBufferStatus::Deleted = tracked_buffer.status {
|
||||
tracked_buffer.status = TrackedBufferStatus::Modified;
|
||||
|
||||
@@ -716,18 +716,10 @@ impl ActivityIndicator {
|
||||
})),
|
||||
tooltip_message: Some(Self::version_tooltip_message(&version)),
|
||||
}),
|
||||
AutoUpdateStatus::Updated {
|
||||
binary_path,
|
||||
version,
|
||||
} => Some(Content {
|
||||
AutoUpdateStatus::Updated { version } => Some(Content {
|
||||
icon: None,
|
||||
message: "Click to restart and update Zed".to_string(),
|
||||
on_click: Some(Arc::new({
|
||||
let reload = workspace::Reload {
|
||||
binary_path: Some(binary_path.clone()),
|
||||
};
|
||||
move |_, _, cx| workspace::reload(&reload, cx)
|
||||
})),
|
||||
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
|
||||
tooltip_message: Some(Self::version_tooltip_message(&version)),
|
||||
}),
|
||||
AutoUpdateStatus::Errored => Some(Content {
|
||||
|
||||
@@ -2268,6 +2268,15 @@ impl Thread {
|
||||
max_attempts: 3,
|
||||
})
|
||||
}
|
||||
Other(err)
|
||||
if err.is::<PaymentRequiredError>()
|
||||
|| err.is::<ModelRequestLimitReachedError>() =>
|
||||
{
|
||||
// Retrying won't help for Payment Required or Model Request Limit errors (where
|
||||
// the user must upgrade to usage-based billing to get more requests, or else wait
|
||||
// for a significant amount of time for the request limit to reset).
|
||||
None
|
||||
}
|
||||
// Conservatively assume that any other errors are non-retryable
|
||||
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
|
||||
delay: BASE_RETRY_DELAY,
|
||||
|
||||
@@ -23,10 +23,13 @@ assistant_tools.workspace = true
|
||||
chrono.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
handlebars = { workspace = true, features = ["rust-embed"] }
|
||||
html_to_markdown.workspace = true
|
||||
http_client.workspace = true
|
||||
indoc.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
@@ -46,6 +49,7 @@ settings.workspace = true
|
||||
smol.workspace = true
|
||||
task.workspace = true
|
||||
terminal.workspace = true
|
||||
text.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
@@ -58,6 +62,7 @@ workspace-hack.workspace = true
|
||||
ctor.workspace = true
|
||||
client = { workspace = true, "features" = ["test-support"] }
|
||||
clock = { workspace = true, "features" = ["test-support"] }
|
||||
context_server = { workspace = true, "features" = ["test-support"] }
|
||||
editor = { workspace = true, "features" = ["test-support"] }
|
||||
env_logger.workspace = true
|
||||
fs = { workspace = true, "features" = ["test-support"] }
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
use crate::{AgentResponseEvent, Thread, templates::Templates};
|
||||
use crate::{
|
||||
CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, GrepTool, ListDirectoryTool,
|
||||
MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool,
|
||||
ToolCallAuthorization, WebSearchTool,
|
||||
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
|
||||
FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
|
||||
ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
|
||||
WebSearchTool,
|
||||
};
|
||||
use acp_thread::ModelSelector;
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{HashSet, IndexMap};
|
||||
use fs::Fs;
|
||||
use futures::{StreamExt, future};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{LanguageModel, LanguageModelRegistry};
|
||||
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
|
||||
};
|
||||
use settings::update_settings_file;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
@@ -48,6 +53,104 @@ struct Session {
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
pub struct LanguageModels {
|
||||
/// Access language model by ID
|
||||
models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
|
||||
/// Cached list for returning language model information
|
||||
model_list: acp_thread::AgentModelList,
|
||||
refresh_models_rx: watch::Receiver<()>,
|
||||
refresh_models_tx: watch::Sender<()>,
|
||||
}
|
||||
|
||||
impl LanguageModels {
|
||||
fn new(cx: &App) -> Self {
|
||||
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
|
||||
let mut this = Self {
|
||||
models: HashMap::default(),
|
||||
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
|
||||
refresh_models_rx,
|
||||
refresh_models_tx,
|
||||
};
|
||||
this.refresh_list(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn refresh_list(&mut self, cx: &App) {
|
||||
let providers = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.into_iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut language_model_list = IndexMap::default();
|
||||
let mut recommended_models = HashSet::default();
|
||||
|
||||
let mut recommended = Vec::new();
|
||||
for provider in &providers {
|
||||
for model in provider.recommended_models(cx) {
|
||||
recommended_models.insert(model.id());
|
||||
recommended.push(Self::map_language_model_to_info(&model, &provider));
|
||||
}
|
||||
}
|
||||
if !recommended.is_empty() {
|
||||
language_model_list.insert(
|
||||
acp_thread::AgentModelGroupName("Recommended".into()),
|
||||
recommended,
|
||||
);
|
||||
}
|
||||
|
||||
let mut models = HashMap::default();
|
||||
for provider in providers {
|
||||
let mut provider_models = Vec::new();
|
||||
for model in provider.provided_models(cx) {
|
||||
let model_info = Self::map_language_model_to_info(&model, &provider);
|
||||
let model_id = model_info.id.clone();
|
||||
if !recommended_models.contains(&model.id()) {
|
||||
provider_models.push(model_info);
|
||||
}
|
||||
models.insert(model_id, model);
|
||||
}
|
||||
if !provider_models.is_empty() {
|
||||
language_model_list.insert(
|
||||
acp_thread::AgentModelGroupName(provider.name().0.clone()),
|
||||
provider_models,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
self.models = models;
|
||||
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
|
||||
self.refresh_models_tx.send(()).ok();
|
||||
}
|
||||
|
||||
fn watch(&self) -> watch::Receiver<()> {
|
||||
self.refresh_models_rx.clone()
|
||||
}
|
||||
|
||||
pub fn model_from_id(
|
||||
&self,
|
||||
model_id: &acp_thread::AgentModelId,
|
||||
) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.models.get(model_id).cloned()
|
||||
}
|
||||
|
||||
fn map_language_model_to_info(
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
) -> acp_thread::AgentModelInfo {
|
||||
acp_thread::AgentModelInfo {
|
||||
id: Self::model_id(model),
|
||||
name: model.name().0,
|
||||
icon: Some(provider.icon()),
|
||||
}
|
||||
}
|
||||
|
||||
fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
|
||||
acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NativeAgent {
|
||||
/// Session ID -> Session mapping
|
||||
sessions: HashMap<acp::SessionId, Session>,
|
||||
@@ -55,10 +158,14 @@ pub struct NativeAgent {
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
project_context_needs_refresh: watch::Sender<()>,
|
||||
_maintain_project_context: Task<Result<()>>,
|
||||
context_server_registry: Entity<ContextServerRegistry>,
|
||||
/// Shared templates for all threads
|
||||
templates: Arc<Templates>,
|
||||
/// Cached model information
|
||||
models: LanguageModels,
|
||||
project: Entity<Project>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
@@ -67,6 +174,7 @@ impl NativeAgent {
|
||||
project: Entity<Project>,
|
||||
templates: Arc<Templates>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
fs: Arc<dyn Fs>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<NativeAgent>> {
|
||||
log::info!("Creating new NativeAgent");
|
||||
@@ -76,7 +184,13 @@ impl NativeAgent {
|
||||
.await;
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
|
||||
let mut subscriptions = vec![
|
||||
cx.subscribe(&project, Self::handle_project_event),
|
||||
cx.subscribe(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
Self::handle_models_updated_event,
|
||||
),
|
||||
];
|
||||
if let Some(prompt_store) = prompt_store.as_ref() {
|
||||
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
|
||||
}
|
||||
@@ -90,14 +204,23 @@ impl NativeAgent {
|
||||
_maintain_project_context: cx.spawn(async move |this, cx| {
|
||||
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
|
||||
}),
|
||||
context_server_registry: cx.new(|cx| {
|
||||
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
|
||||
}),
|
||||
templates,
|
||||
models: LanguageModels::new(cx),
|
||||
project,
|
||||
prompt_store,
|
||||
fs,
|
||||
_subscriptions: subscriptions,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn models(&self) -> &LanguageModels {
|
||||
&self.models
|
||||
}
|
||||
|
||||
async fn maintain_project_context(
|
||||
this: WeakEntity<Self>,
|
||||
mut needs_refresh: watch::Receiver<()>,
|
||||
@@ -293,75 +416,104 @@ impl NativeAgent {
|
||||
) {
|
||||
self.project_context_needs_refresh.send(()).ok();
|
||||
}
|
||||
|
||||
fn handle_models_updated_event(
|
||||
&mut self,
|
||||
_registry: Entity<LanguageModelRegistry>,
|
||||
_event: &language_model::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.models.refresh_list(cx);
|
||||
for session in self.sessions.values_mut() {
|
||||
session.thread.update(cx, |thread, _| {
|
||||
let model_id = LanguageModels::model_id(&thread.selected_model);
|
||||
if let Some(model) = self.models.model_from_id(&model_id) {
|
||||
thread.selected_model = model.clone();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper struct that implements the AgentConnection trait
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
|
||||
|
||||
impl ModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
|
||||
impl AgentModelSelector for NativeAgentConnection {
|
||||
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
|
||||
log::debug!("NativeAgentConnection::list_models called");
|
||||
cx.spawn(async move |cx| {
|
||||
cx.update(|cx| {
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
let models = registry.available_models(cx).collect::<Vec<_>>();
|
||||
log::info!("Found {} available models", models.len());
|
||||
if models.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(models)
|
||||
}
|
||||
})?
|
||||
let list = self.0.read(cx).models.model_list.clone();
|
||||
Task::ready(if list.is_empty() {
|
||||
Err(anyhow::anyhow!("No models available"))
|
||||
} else {
|
||||
Ok(list)
|
||||
})
|
||||
}
|
||||
|
||||
fn select_model(
|
||||
&self,
|
||||
session_id: acp::SessionId,
|
||||
model: Arc<dyn LanguageModel>,
|
||||
cx: &mut AsyncApp,
|
||||
model_id: acp_thread::AgentModelId,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!(
|
||||
"Setting model for session {}: {:?}",
|
||||
session_id,
|
||||
model.name()
|
||||
);
|
||||
let agent = self.0.clone();
|
||||
log::info!("Setting model for session {}: {}", session_id, model_id);
|
||||
let Some(thread) = self
|
||||
.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
agent.update(cx, |agent, cx| {
|
||||
if let Some(session) = agent.sessions.get(&session_id) {
|
||||
session.thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model;
|
||||
});
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow!("Session not found"))
|
||||
}
|
||||
})?
|
||||
})
|
||||
let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
|
||||
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
|
||||
};
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.selected_model = model.clone();
|
||||
});
|
||||
|
||||
update_settings_file::<AgentSettings>(
|
||||
self.0.read(cx).fs.clone(),
|
||||
cx,
|
||||
move |settings, _cx| {
|
||||
settings.set_model(model);
|
||||
},
|
||||
);
|
||||
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn selected_model(
|
||||
&self,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Task<Result<Arc<dyn LanguageModel>>> {
|
||||
let agent = self.0.clone();
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp_thread::AgentModelInfo>> {
|
||||
let session_id = session_id.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let thread = agent
|
||||
.read_with(cx, |agent, _| {
|
||||
agent
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
})?
|
||||
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
|
||||
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
|
||||
Ok(selected)
|
||||
})
|
||||
|
||||
let Some(thread) = self
|
||||
.0
|
||||
.read(cx)
|
||||
.sessions
|
||||
.get(&session_id)
|
||||
.map(|session| session.thread.clone())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Session not found")));
|
||||
};
|
||||
let model = thread.read(cx).selected_model.clone();
|
||||
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("Provider not found")));
|
||||
};
|
||||
Task::ready(Ok(LanguageModels::map_language_model_to_info(
|
||||
&model, &provider,
|
||||
)))
|
||||
}
|
||||
|
||||
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
|
||||
self.0.read(cx).models.watch()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,7 +537,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
// Create AcpThread
|
||||
let acp_thread = cx.update(|cx| {
|
||||
cx.new(|cx| {
|
||||
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
|
||||
acp_thread::AcpThread::new(
|
||||
"agent2",
|
||||
self.clone(),
|
||||
project.clone(),
|
||||
session_id.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
|
||||
@@ -403,28 +561,37 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
|
||||
let default_model = registry
|
||||
.default_model()
|
||||
.map(|configured| {
|
||||
log::info!(
|
||||
"Using configured default model: {:?} from provider: {:?}",
|
||||
configured.model.name(),
|
||||
configured.provider.name()
|
||||
);
|
||||
configured.model
|
||||
.and_then(|default_model| {
|
||||
agent
|
||||
.models
|
||||
.model_from_id(&LanguageModels::model_id(&default_model.model))
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
log::warn!("No default model configured in settings");
|
||||
anyhow!("No default model configured. Please configure a default model in settings.")
|
||||
anyhow!(
|
||||
"No default model. Please configure a default model in settings."
|
||||
)
|
||||
})?;
|
||||
|
||||
let thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
|
||||
let mut thread = Thread::new(
|
||||
project.clone(),
|
||||
agent.project_context.clone(),
|
||||
agent.context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
agent.templates.clone(),
|
||||
default_model,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(CreateDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(CopyPathTool::new(project.clone()));
|
||||
thread.add_tool(DiagnosticsTool::new(project.clone()));
|
||||
thread.add_tool(MovePathTool::new(project.clone()));
|
||||
thread.add_tool(ListDirectoryTool::new(project.clone()));
|
||||
thread.add_tool(OpenTool::new(project.clone()));
|
||||
thread.add_tool(ThinkingTool);
|
||||
thread.add_tool(FindPathTool::new(project.clone()));
|
||||
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
|
||||
thread.add_tool(GrepTool::new(project.clone()));
|
||||
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
|
||||
thread.add_tool(EditFileTool::new(cx.entity()));
|
||||
@@ -448,7 +615,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
acp_thread: acp_thread.downgrade(),
|
||||
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
|
||||
this.sessions.remove(acp_thread.session_id());
|
||||
})
|
||||
}),
|
||||
},
|
||||
);
|
||||
})?;
|
||||
@@ -465,15 +632,17 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
|
||||
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
|
||||
Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
|
||||
}
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
let id = id.expect("UserMessageId is required");
|
||||
let session_id = params.session_id.clone();
|
||||
let agent = self.0.clone();
|
||||
log::info!("Received prompt request for session: {}", session_id);
|
||||
@@ -494,10 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
})?;
|
||||
log::debug!("Found session for: {}", session_id);
|
||||
|
||||
// Convert prompt to message
|
||||
let message = convert_prompt_to_message(params.prompt);
|
||||
log::info!("Converted prompt to message: {} chars", message.len());
|
||||
log::debug!("Message content: {}", message);
|
||||
let content: Vec<UserMessageContent> = params
|
||||
.prompt
|
||||
.into_iter()
|
||||
.map(Into::into)
|
||||
.collect::<Vec<_>>();
|
||||
log::info!("Converted prompt to message: {} chars", content.len());
|
||||
log::debug!("Message id: {:?}", id);
|
||||
log::debug!("Message content: {:?}", content);
|
||||
|
||||
// Get model using the ModelSelector capability (always available for agent2)
|
||||
// Get the selected model from the thread directly
|
||||
@@ -505,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
|
||||
// Send to thread
|
||||
log::info!("Sending message to thread with model: {:?}", model.name());
|
||||
let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
|
||||
let mut response_stream =
|
||||
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
|
||||
|
||||
// Handle response stream and forward to session.acp_thread
|
||||
while let Some(result) = response_stream.next().await {
|
||||
@@ -599,44 +773,33 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn session_editor(
|
||||
&self,
|
||||
session_id: &agent_client_protocol::SessionId,
|
||||
cx: &mut App,
|
||||
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
|
||||
self.0.update(cx, |agent, _cx| {
|
||||
agent
|
||||
.sessions
|
||||
.get(session_id)
|
||||
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert ACP content blocks to a message string
|
||||
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
|
||||
log::debug!("Converting {} content blocks to message", blocks.len());
|
||||
let mut message = String::new();
|
||||
struct NativeAgentSessionEditor(Entity<Thread>);
|
||||
|
||||
for block in blocks {
|
||||
match block {
|
||||
acp::ContentBlock::Text(text) => {
|
||||
log::trace!("Processing text block: {} chars", text.text.len());
|
||||
message.push_str(&text.text);
|
||||
}
|
||||
acp::ContentBlock::ResourceLink(link) => {
|
||||
log::trace!("Processing resource link: {}", link.uri);
|
||||
message.push_str(&format!(" @{} ", link.uri));
|
||||
}
|
||||
acp::ContentBlock::Image(_) => {
|
||||
log::trace!("Processing image block");
|
||||
message.push_str(" [image] ");
|
||||
}
|
||||
acp::ContentBlock::Audio(_) => {
|
||||
log::trace!("Processing audio block");
|
||||
message.push_str(" [audio] ");
|
||||
}
|
||||
acp::ContentBlock::Resource(resource) => {
|
||||
log::trace!("Processing resource block: {:?}", resource.resource);
|
||||
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
|
||||
}
|
||||
}
|
||||
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
|
||||
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
|
||||
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
|
||||
}
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
|
||||
use fs::FakeFs;
|
||||
use gpui::TestAppContext;
|
||||
use serde_json::json;
|
||||
@@ -654,9 +817,15 @@ mod tests {
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
agent.read_with(cx, |agent, _| {
|
||||
assert_eq!(agent.project_context.borrow().worktrees, vec![])
|
||||
});
|
||||
@@ -697,13 +866,131 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_listing_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/", json!({ "a": {} })).await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
let connection = NativeAgentConnection(
|
||||
NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
|
||||
|
||||
let acp_thread::AgentModelList::Grouped(models) = models else {
|
||||
panic!("Unexpected model group");
|
||||
};
|
||||
assert_eq!(
|
||||
models,
|
||||
IndexMap::from_iter([(
|
||||
AgentModelGroupName("Fake".into()),
|
||||
vec![AgentModelInfo {
|
||||
id: AgentModelId("fake/fake".into()),
|
||||
name: "Fake".into(),
|
||||
icon: Some(ui::IconName::ZedAssistant),
|
||||
}]
|
||||
)])
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.create_dir(paths::settings_file().parent().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"default_model": {
|
||||
"provider": "foo",
|
||||
"model": "bar"
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [], cx).await;
|
||||
|
||||
// Create the agent and connection
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Create a thread/session
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
Rc::new(connection.clone()).new_thread(
|
||||
project.clone(),
|
||||
Path::new("/a"),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
|
||||
|
||||
// Select a model
|
||||
let model_id = AgentModelId("fake/fake".into());
|
||||
cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify the thread has the selected model
|
||||
agent.read_with(cx, |agent, _| {
|
||||
let session = agent.sessions.get(&session_id).unwrap();
|
||||
session.thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.selected_model.id().0, "fake");
|
||||
});
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// Verify settings file was updated
|
||||
let settings_content = fs.load(paths::settings_file()).await.unwrap();
|
||||
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
|
||||
|
||||
// Check that the agent settings contain the selected model
|
||||
assert_eq!(
|
||||
settings_json["agent"]["default_model"]["model"],
|
||||
json!("fake")
|
||||
);
|
||||
assert_eq!(
|
||||
settings_json["agent"]["default_model"]["provider"],
|
||||
json!("fake")
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
env_logger::try_init().ok();
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
Project::init_settings(cx);
|
||||
agent_settings::init(cx);
|
||||
language::init(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::{path::Path, rc::Rc, sync::Arc};
|
||||
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use fs::Fs;
|
||||
use gpui::{App, Entity, Task};
|
||||
use project::Project;
|
||||
use prompt_store::PromptStore;
|
||||
@@ -10,7 +10,15 @@ use prompt_store::PromptStore;
|
||||
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NativeAgentServer;
|
||||
pub struct NativeAgentServer {
|
||||
fs: Arc<dyn Fs>,
|
||||
}
|
||||
|
||||
impl NativeAgentServer {
|
||||
pub fn new(fs: Arc<dyn Fs>) -> Self {
|
||||
Self { fs }
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentServer for NativeAgentServer {
|
||||
fn name(&self) -> &'static str {
|
||||
@@ -41,6 +49,7 @@ impl AgentServer for NativeAgentServer {
|
||||
_root_dir
|
||||
);
|
||||
let project = project.clone();
|
||||
let fs = self.fs.clone();
|
||||
let prompt_store = PromptStore::global(cx);
|
||||
cx.spawn(async move |cx| {
|
||||
log::debug!("Creating templates for native agent");
|
||||
@@ -48,7 +57,7 @@ impl AgentServer for NativeAgentServer {
|
||||
let prompt_store = prompt_store.await?;
|
||||
|
||||
log::debug!("Creating native agent entity");
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
|
||||
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
|
||||
|
||||
// Create the connection wrapper
|
||||
let connection = NativeAgentConnection(agent);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
use acp_thread::AgentConnection;
|
||||
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_settings::AgentProfileId;
|
||||
use anyhow::Result;
|
||||
use client::{Client, UserStore};
|
||||
use fs::{FakeFs, Fs};
|
||||
@@ -12,8 +13,8 @@ use gpui::{
|
||||
use indoc::indoc;
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
|
||||
StopReason, fake_provider::FakeLanguageModel,
|
||||
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
|
||||
fake_provider::FakeLanguageModel,
|
||||
};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
@@ -36,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) {
|
||||
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send("Testing: Reply with 'Hello'", cx)
|
||||
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
thread.last_message().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## Assistant
|
||||
|
||||
Hello
|
||||
"}
|
||||
)
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
}
|
||||
@@ -57,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send(
|
||||
indoc! {"
|
||||
UserMessageId::new(),
|
||||
[indoc! {"
|
||||
Testing:
|
||||
|
||||
Generate a thinking step where you just think the word 'Think',
|
||||
and have your final answer be 'Hello'
|
||||
"},
|
||||
"}],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -70,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) {
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().to_markdown(),
|
||||
thread.last_message().unwrap().to_markdown(),
|
||||
indoc! {"
|
||||
## assistant
|
||||
## Assistant
|
||||
|
||||
<think>Think</think>
|
||||
Hello
|
||||
"}
|
||||
@@ -93,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
|
||||
|
||||
project_context.borrow_mut().shell = "test-shell".into();
|
||||
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
|
||||
thread.update(cx, |thread, cx| thread.send("abc", cx));
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(
|
||||
@@ -130,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
|
||||
UserMessageId::new(),
|
||||
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -144,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
thread.remove_tool(&AgentTool::name(&EchoTool));
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
|
||||
UserMessageId::new(),
|
||||
[
|
||||
"Now call the delay tool with 200ms.",
|
||||
"When the timer goes off, then you echo the output of the tool.",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -154,18 +168,21 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert!(
|
||||
thread
|
||||
.messages()
|
||||
.last()
|
||||
.last_message()
|
||||
.unwrap()
|
||||
.as_agent_message()
|
||||
.unwrap()
|
||||
.content
|
||||
.iter()
|
||||
.any(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
if let AgentMessageContent::Text(text) = content {
|
||||
text.contains("Ding")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
}),
|
||||
"{}",
|
||||
thread.to_markdown()
|
||||
);
|
||||
});
|
||||
}
|
||||
@@ -178,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
// Test a tool call that's likely to complete *before* streaming stops.
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(WordListTool);
|
||||
thread.send("Test the word_list tool.", cx)
|
||||
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
|
||||
});
|
||||
|
||||
let mut saw_partial_tool_use = false;
|
||||
@@ -186,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
|
||||
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
|
||||
thread.update(cx, |thread, _cx| {
|
||||
// Look for a tool use in the thread's last message
|
||||
let last_content = thread.messages().last().unwrap().content.last().unwrap();
|
||||
if let MessageContent::ToolUse(last_tool_use) = last_content {
|
||||
let message = thread.last_message().unwrap();
|
||||
let agent_message = message.as_agent_message().unwrap();
|
||||
let last_content = agent_message.content.last().unwrap();
|
||||
if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
|
||||
assert_eq!(last_tool_use.name.as_ref(), "word_list");
|
||||
if tool_call.status == acp::ToolCallStatus::Pending {
|
||||
if !last_tool_use.is_input_complete
|
||||
@@ -225,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.add_tool(ToolRequiringPermission);
|
||||
thread.send("abc", cx)
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
@@ -269,14 +288,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
}),
|
||||
MessageContent::ToolResult(LanguageModelToolResult {
|
||||
language_model::MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: true,
|
||||
@@ -309,13 +328,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
let message = completion.messages.last().unwrap();
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
})]
|
||||
vec![language_model::MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
}
|
||||
)]
|
||||
);
|
||||
|
||||
// Simulate a final tool call, ensuring we don't trigger authorization.
|
||||
@@ -334,13 +355,15 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
|
||||
let message = completion.messages.last().unwrap();
|
||||
assert_eq!(
|
||||
message.content,
|
||||
vec![MessageContent::ToolResult(LanguageModelToolResult {
|
||||
tool_use_id: "tool_id_4".into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
})]
|
||||
vec![language_model::MessageContent::ToolResult(
|
||||
LanguageModelToolResult {
|
||||
tool_use_id: "tool_id_4".into(),
|
||||
tool_name: ToolRequiringPermission.name().into(),
|
||||
is_error: false,
|
||||
content: "Allowed".into(),
|
||||
output: Some("Allowed".into())
|
||||
}
|
||||
)]
|
||||
);
|
||||
}
|
||||
|
||||
@@ -349,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["abc"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
@@ -441,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
.update(cx, |thread, cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.send(
|
||||
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
|
||||
UserMessageId::new(),
|
||||
[
|
||||
"Call the delay tool twice in the same message.",
|
||||
"Once with 100ms. Once with 300ms.",
|
||||
"When both timers are complete, describe the outputs.",
|
||||
],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -452,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let last_message = thread.messages().last().unwrap();
|
||||
let text = last_message
|
||||
let last_message = thread.last_message().unwrap();
|
||||
let agent_message = last_message.as_agent_message().unwrap();
|
||||
let text = agent_message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|content| {
|
||||
if let MessageContent::Text(text) = content {
|
||||
if let AgentMessageContent::Text(text) = content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
@@ -469,6 +500,82 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_profiles(cx: &mut TestAppContext) {
|
||||
let ThreadTest {
|
||||
model, thread, fs, ..
|
||||
} = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
thread.update(cx, |thread, _cx| {
|
||||
thread.add_tool(DelayTool);
|
||||
thread.add_tool(EchoTool);
|
||||
thread.add_tool(InfiniteTool);
|
||||
});
|
||||
|
||||
// Override profiles and wait for settings to be loaded.
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"profiles": {
|
||||
"test-1": {
|
||||
"name": "Test Profile 1",
|
||||
"tools": {
|
||||
EchoTool.name(): true,
|
||||
DelayTool.name(): true,
|
||||
}
|
||||
},
|
||||
"test-2": {
|
||||
"name": "Test Profile 2",
|
||||
"tools": {
|
||||
InfiniteTool.name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
cx.run_until_parked();
|
||||
|
||||
// Test that test-1 profile (default) has echo and delay tools
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-1".into()));
|
||||
thread.send(UserMessageId::new(), ["test"], cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(pending_completions.len(), 1);
|
||||
let completion = pending_completions.pop().unwrap();
|
||||
let tool_names: Vec<String> = completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect();
|
||||
assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
|
||||
fake_model.end_last_completion_stream();
|
||||
|
||||
// Switch to test-2 profile, and verify that it has only the infinite tool.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_profile(AgentProfileId("test-2".into()));
|
||||
thread.send(UserMessageId::new(), ["test2"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
let mut pending_completions = fake_model.pending_completions();
|
||||
assert_eq!(pending_completions.len(), 1);
|
||||
let completion = pending_completions.pop().unwrap();
|
||||
let tool_names: Vec<String> = completion
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect();
|
||||
assert_eq!(tool_names, vec![InfiniteTool.name()]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
#[ignore = "can't run on CI yet"]
|
||||
async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
@@ -478,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
thread.add_tool(InfiniteTool);
|
||||
thread.add_tool(EchoTool);
|
||||
thread.send(
|
||||
"Call the echo tool and then call the infinite tool, then explain their output",
|
||||
UserMessageId::new(),
|
||||
["Call the echo tool, then call the infinite tool, then explain their output"],
|
||||
cx,
|
||||
)
|
||||
});
|
||||
@@ -523,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) {
|
||||
// Ensure we can still send a new message after cancellation.
|
||||
let events = thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.send("Testing: reply with 'Hello' then stop.", cx)
|
||||
thread.send(
|
||||
UserMessageId::new(),
|
||||
["Testing: reply with 'Hello' then stop."],
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
thread.update(cx, |thread, _cx| {
|
||||
let message = thread.last_message().unwrap();
|
||||
let agent_message = message.as_agent_message().unwrap();
|
||||
assert_eq!(
|
||||
thread.messages().last().unwrap().content,
|
||||
vec![MessageContent::Text("Hello".to_string())]
|
||||
agent_message.content,
|
||||
vec![AgentMessageContent::Text("Hello".to_string())]
|
||||
);
|
||||
});
|
||||
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
|
||||
@@ -541,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
|
||||
let events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hello"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
## User
|
||||
|
||||
Hello
|
||||
"}
|
||||
);
|
||||
@@ -559,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## user
|
||||
## User
|
||||
|
||||
Hello
|
||||
## assistant
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
);
|
||||
@@ -577,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_truncate(cx: &mut TestAppContext) {
|
||||
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let message_id = UserMessageId::new();
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(message_id.clone(), ["Hello"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hello
|
||||
|
||||
## Assistant
|
||||
|
||||
Hey!
|
||||
"}
|
||||
);
|
||||
});
|
||||
|
||||
thread
|
||||
.update(cx, |thread, _cx| thread.truncate(message_id))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(thread.to_markdown(), "");
|
||||
});
|
||||
|
||||
// Ensure we can still send a new message after truncation.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Hi"], cx)
|
||||
});
|
||||
thread.update(cx, |thread, _cx| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hi
|
||||
"}
|
||||
);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
||||
cx.run_until_parked();
|
||||
thread.read_with(cx, |thread, _| {
|
||||
assert_eq!(
|
||||
thread.to_markdown(),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Hi
|
||||
|
||||
## Assistant
|
||||
|
||||
Ahoy!
|
||||
"}
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
cx.update(settings::init);
|
||||
@@ -595,19 +794,26 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
Project::init_settings(cx);
|
||||
LanguageModelRegistry::test(cx);
|
||||
agent_settings::init(cx);
|
||||
});
|
||||
cx.executor().forbid_parking();
|
||||
|
||||
// Create a project for new_thread
|
||||
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
|
||||
fake_fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
|
||||
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
|
||||
let cwd = Path::new("/test");
|
||||
|
||||
// Create agent and connection
|
||||
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
|
||||
.await
|
||||
.unwrap();
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
templates.clone(),
|
||||
None,
|
||||
fake_fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = NativeAgentConnection(agent.clone());
|
||||
|
||||
// Test model_selector returns Some
|
||||
@@ -620,22 +826,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
|
||||
// Test list_models
|
||||
let listed_models = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.list_models(&mut async_cx)
|
||||
})
|
||||
.update(|cx| selector.list_models(cx))
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
let AgentModelList::Grouped(listed_models) = listed_models else {
|
||||
panic!("Unexpected model list type");
|
||||
};
|
||||
assert!(!listed_models.is_empty(), "should have at least one model");
|
||||
assert_eq!(listed_models[0].id().0, "fake");
|
||||
assert_eq!(
|
||||
listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
|
||||
"fake/fake"
|
||||
);
|
||||
|
||||
// Create a thread using new_thread
|
||||
let connection_rc = Rc::new(connection.clone());
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
connection_rc.new_thread(project, cwd, &mut async_cx)
|
||||
})
|
||||
.update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
|
||||
.await
|
||||
.expect("new_thread should succeed");
|
||||
|
||||
@@ -644,12 +850,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
|
||||
// Test selected_model returns the default
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
let mut async_cx = cx.to_async();
|
||||
selector.selected_model(&session_id, &mut async_cx)
|
||||
})
|
||||
.update(|cx| selector.selected_model(&session_id, cx))
|
||||
.await
|
||||
.expect("selected_model should succeed");
|
||||
let model = cx
|
||||
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
|
||||
.unwrap();
|
||||
let model = model.as_fake();
|
||||
assert_eq!(model.id().0, "fake", "should return default model");
|
||||
|
||||
@@ -683,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
connection.prompt(
|
||||
Some(acp_thread::UserMessageId::new()),
|
||||
acp::PromptRequest {
|
||||
session_id: session_id.clone(),
|
||||
prompt: vec!["ghi".into()],
|
||||
@@ -705,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
|
||||
let fake_model = model.as_fake();
|
||||
|
||||
let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
|
||||
let mut events = thread.update(cx, |thread, cx| {
|
||||
thread.send(UserMessageId::new(), ["Think"], cx)
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
// Simulate streaming partial input.
|
||||
@@ -790,6 +999,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
|
||||
id: acp::ToolCallId("1".into()),
|
||||
fields: acp::ToolCallUpdateFields {
|
||||
status: Some(acp::ToolCallStatus::Completed),
|
||||
raw_output: Some("Finished thinking.".into()),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
@@ -813,6 +1023,7 @@ struct ThreadTest {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
thread: Entity<Thread>,
|
||||
project_context: Rc<RefCell<ProjectContext>>,
|
||||
fs: Arc<FakeFs>,
|
||||
}
|
||||
|
||||
enum TestModel {
|
||||
@@ -835,30 +1046,57 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||
cx.executor().allow_parking();
|
||||
|
||||
let fs = FakeFs::new(cx.background_executor.clone());
|
||||
fs.create_dir(paths::settings_file().parent().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
fs.insert_file(
|
||||
paths::settings_file(),
|
||||
json!({
|
||||
"agent": {
|
||||
"default_profile": "test-profile",
|
||||
"profiles": {
|
||||
"test-profile": {
|
||||
"name": "Test Profile",
|
||||
"tools": {
|
||||
EchoTool.name(): true,
|
||||
DelayTool.name(): true,
|
||||
WordListTool.name(): true,
|
||||
ToolRequiringPermission.name(): true,
|
||||
InfiniteTool.name(): true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
.into_bytes(),
|
||||
)
|
||||
.await;
|
||||
|
||||
cx.update(|cx| {
|
||||
settings::init(cx);
|
||||
watch_settings(fs.clone(), cx);
|
||||
Project::init_settings(cx);
|
||||
agent_settings::init(cx);
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
watch_settings(fs.clone(), cx);
|
||||
});
|
||||
|
||||
let templates = Templates::new();
|
||||
|
||||
fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||
|
||||
let model = cx
|
||||
.update(|cx| {
|
||||
gpui_tokio::init(cx);
|
||||
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
|
||||
cx.set_http_client(Arc::new(http_client));
|
||||
|
||||
client::init_settings(cx);
|
||||
let client = Client::production(cx);
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
language_model::init(client.clone(), cx);
|
||||
language_models::init(user_store.clone(), client.clone(), cx);
|
||||
|
||||
if let TestModel::Fake = model {
|
||||
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
|
||||
} else {
|
||||
@@ -881,20 +1119,25 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
|
||||
.await;
|
||||
|
||||
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
project_context.clone(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
templates,
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
ThreadTest {
|
||||
model,
|
||||
thread,
|
||||
project_context,
|
||||
fs,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,10 @@
|
||||
mod context_server_registry;
|
||||
mod copy_path_tool;
|
||||
mod create_directory_tool;
|
||||
mod delete_path_tool;
|
||||
mod diagnostics_tool;
|
||||
mod edit_file_tool;
|
||||
mod fetch_tool;
|
||||
mod find_path_tool;
|
||||
mod grep_tool;
|
||||
mod list_directory_tool;
|
||||
@@ -13,10 +16,13 @@ mod terminal_tool;
|
||||
mod thinking_tool;
|
||||
mod web_search_tool;
|
||||
|
||||
pub use context_server_registry::*;
|
||||
pub use copy_path_tool::*;
|
||||
pub use create_directory_tool::*;
|
||||
pub use delete_path_tool::*;
|
||||
pub use diagnostics_tool::*;
|
||||
pub use edit_file_tool::*;
|
||||
pub use fetch_tool::*;
|
||||
pub use find_path_tool::*;
|
||||
pub use grep_tool::*;
|
||||
pub use list_directory_tool::*;
|
||||
|
||||
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
231
crates/agent2/src/tools/context_server_registry.rs
Normal file
@@ -0,0 +1,231 @@
|
||||
use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol::ToolKind;
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use context_server::ContextServerId;
|
||||
use gpui::{App, Context, Entity, SharedString, Task};
|
||||
use project::context_server_store::{ContextServerStatus, ContextServerStore};
|
||||
use std::sync::Arc;
|
||||
use util::ResultExt;
|
||||
|
||||
pub struct ContextServerRegistry {
|
||||
server_store: Entity<ContextServerStore>,
|
||||
registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
|
||||
_subscription: gpui::Subscription,
|
||||
}
|
||||
|
||||
struct RegisteredContextServer {
|
||||
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
load_tools: Task<Result<()>>,
|
||||
}
|
||||
|
||||
impl ContextServerRegistry {
|
||||
pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
|
||||
let mut this = Self {
|
||||
server_store: server_store.clone(),
|
||||
registered_servers: HashMap::default(),
|
||||
_subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
|
||||
};
|
||||
for server in server_store.read(cx).running_servers() {
|
||||
this.reload_tools_for_server(server.id(), cx);
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
pub fn servers(
|
||||
&self,
|
||||
) -> impl Iterator<
|
||||
Item = (
|
||||
&ContextServerId,
|
||||
&BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
|
||||
),
|
||||
> {
|
||||
self.registered_servers
|
||||
.iter()
|
||||
.map(|(id, server)| (id, &server.tools))
|
||||
}
|
||||
|
||||
fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
|
||||
let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
|
||||
return;
|
||||
};
|
||||
let Some(client) = server.client() else {
|
||||
return;
|
||||
};
|
||||
if !client.capable(context_server::protocol::ServerCapability::Tools) {
|
||||
return;
|
||||
}
|
||||
|
||||
let registered_server =
|
||||
self.registered_servers
|
||||
.entry(server_id.clone())
|
||||
.or_insert(RegisteredContextServer {
|
||||
tools: BTreeMap::default(),
|
||||
load_tools: Task::ready(Ok(())),
|
||||
});
|
||||
registered_server.load_tools = cx.spawn(async move |this, cx| {
|
||||
let response = client
|
||||
.request::<context_server::types::requests::ListTools>(())
|
||||
.await;
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
registered_server.tools.clear();
|
||||
if let Some(response) = response.log_err() {
|
||||
for tool in response.tools {
|
||||
let tool = Arc::new(ContextServerTool::new(
|
||||
this.server_store.clone(),
|
||||
server.id(),
|
||||
tool,
|
||||
));
|
||||
registered_server.tools.insert(tool.name(), tool);
|
||||
}
|
||||
cx.notify();
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_context_server_store_event(
|
||||
&mut self,
|
||||
_: Entity<ContextServerStore>,
|
||||
event: &project::context_server_store::Event,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
match event {
|
||||
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
|
||||
match status {
|
||||
ContextServerStatus::Starting => {}
|
||||
ContextServerStatus::Running => {
|
||||
self.reload_tools_for_server(server_id.clone(), cx);
|
||||
}
|
||||
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
|
||||
self.registered_servers.remove(&server_id);
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContextServerTool {
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: context_server::types::Tool,
|
||||
}
|
||||
|
||||
impl ContextServerTool {
|
||||
fn new(
|
||||
store: Entity<ContextServerStore>,
|
||||
server_id: ContextServerId,
|
||||
tool: context_server::types::Tool,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
server_id,
|
||||
tool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AnyAgentTool for ContextServerTool {
|
||||
fn name(&self) -> SharedString {
|
||||
self.tool.name.clone().into()
|
||||
}
|
||||
|
||||
fn description(&self) -> SharedString {
|
||||
self.tool.description.clone().unwrap_or_default().into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Other
|
||||
}
|
||||
|
||||
fn initial_title(&self, _input: serde_json::Value) -> SharedString {
|
||||
format!("Run MCP tool `{}`", self.tool.name).into()
|
||||
}
|
||||
|
||||
fn input_schema(
|
||||
&self,
|
||||
format: language_model::LanguageModelToolSchemaFormat,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut schema = self.tool.input_schema.clone();
|
||||
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
|
||||
Ok(match schema {
|
||||
serde_json::Value::Null => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
serde_json::Value::Object(map) if map.is_empty() => {
|
||||
serde_json::json!({ "type": "object", "properties": [] })
|
||||
}
|
||||
_ => schema,
|
||||
})
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: serde_json::Value,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<AgentToolOutput>> {
|
||||
let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
|
||||
return Task::ready(Err(anyhow!("Context server not found")));
|
||||
};
|
||||
let tool_name = self.tool.name.clone();
|
||||
let server_clone = server.clone();
|
||||
let input_clone = input.clone();
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let Some(protocol) = server_clone.client() else {
|
||||
bail!("Context server not initialized");
|
||||
};
|
||||
|
||||
let arguments = if let serde_json::Value::Object(map) = input_clone {
|
||||
Some(map.into_iter().collect())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
log::trace!(
|
||||
"Running tool: {} with arguments: {:?}",
|
||||
tool_name,
|
||||
arguments
|
||||
);
|
||||
let response = protocol
|
||||
.request::<context_server::types::requests::CallTool>(
|
||||
context_server::types::CallToolParams {
|
||||
name: tool_name,
|
||||
arguments,
|
||||
meta: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut result = String::new();
|
||||
for content in response.content {
|
||||
match content {
|
||||
context_server::types::ToolResponseContent::Text { text } => {
|
||||
result.push_str(&text);
|
||||
}
|
||||
context_server::types::ToolResponseContent::Image { .. } => {
|
||||
log::warn!("Ignoring image content from tool response");
|
||||
}
|
||||
context_server::types::ToolResponseContent::Audio { .. } => {
|
||||
log::warn!("Ignoring audio content from tool response");
|
||||
}
|
||||
context_server::types::ToolResponseContent::Resource { .. } => {
|
||||
log::warn!("Ignoring resource content from tool response");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(AgentToolOutput {
|
||||
raw_output: result.clone().into(),
|
||||
llm_output: result.into(),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
163
crates/agent2/src/tools/diagnostics_tool.rs
Normal file
163
crates/agent2/src/tools/diagnostics_tool.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use gpui::{App, Entity, Task};
|
||||
use language::{DiagnosticSeverity, OffsetRangeExt};
|
||||
use project::Project;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Write, path::Path, sync::Arc};
|
||||
use ui::SharedString;
|
||||
use util::markdown::MarkdownInlineCode;
|
||||
|
||||
/// Get errors and warnings for the project or a specific file.
|
||||
///
|
||||
/// This tool can be invoked after a series of edits to determine if further edits are necessary, or if the user asks to fix errors or warnings in their codebase.
|
||||
///
|
||||
/// When a path is provided, shows all diagnostics for that specific file.
|
||||
/// When no path is provided, shows a summary of error and warning counts for all files in the project.
|
||||
///
|
||||
/// <example>
|
||||
/// To get diagnostics for a specific file:
|
||||
/// {
|
||||
/// "path": "src/main.rs"
|
||||
/// }
|
||||
///
|
||||
/// To get a project-wide diagnostic summary:
|
||||
/// {}
|
||||
/// </example>
|
||||
///
|
||||
/// <guidelines>
|
||||
/// - If you think you can fix a diagnostic, make 1-2 attempts and then give up.
|
||||
/// - Don't remove code you've generated just because you can't fix an error. The user can help you fix it.
|
||||
/// </guidelines>
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DiagnosticsToolInput {
|
||||
/// The path to get diagnostics for. If not provided, returns a project-wide summary.
|
||||
///
|
||||
/// This path should never be absolute, and the first component
|
||||
/// of the path should always be a root directory in a project.
|
||||
///
|
||||
/// <example>
|
||||
/// If the project has the following root directories:
|
||||
///
|
||||
/// - lorem
|
||||
/// - ipsum
|
||||
///
|
||||
/// If you wanna access diagnostics for `dolor.txt` in `ipsum`, you should use the path `ipsum/dolor.txt`.
|
||||
/// </example>
|
||||
pub path: Option<String>,
|
||||
}
|
||||
|
||||
pub struct DiagnosticsTool {
|
||||
project: Entity<Project>,
|
||||
}
|
||||
|
||||
impl DiagnosticsTool {
|
||||
pub fn new(project: Entity<Project>) -> Self {
|
||||
Self { project }
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentTool for DiagnosticsTool {
|
||||
type Input = DiagnosticsToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"diagnostics".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Read
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
if let Some(path) = input.ok().and_then(|input| match input.path {
|
||||
Some(path) if !path.is_empty() => Some(path),
|
||||
_ => None,
|
||||
}) {
|
||||
format!("Check diagnostics for {}", MarkdownInlineCode(&path)).into()
|
||||
} else {
|
||||
"Check project diagnostics".into()
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
match input.path {
|
||||
Some(path) if !path.is_empty() => {
|
||||
let Some(project_path) = self.project.read(cx).find_project_path(&path, cx) else {
|
||||
return Task::ready(Err(anyhow!("Could not find path {path} in project",)));
|
||||
};
|
||||
|
||||
let buffer = self
|
||||
.project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx));
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let mut output = String::new();
|
||||
let buffer = buffer.await?;
|
||||
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
for (_, group) in snapshot.diagnostic_groups(None) {
|
||||
let entry = &group.entries[group.primary_ix];
|
||||
let range = entry.range.to_point(&snapshot);
|
||||
let severity = match entry.diagnostic.severity {
|
||||
DiagnosticSeverity::ERROR => "error",
|
||||
DiagnosticSeverity::WARNING => "warning",
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"{} at line {}: {}",
|
||||
severity,
|
||||
range.start.row + 1,
|
||||
entry.diagnostic.message
|
||||
)?;
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
Ok("File doesn't have errors or warnings!".to_string())
|
||||
} else {
|
||||
Ok(output)
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
let project = self.project.read(cx);
|
||||
let mut output = String::new();
|
||||
let mut has_diagnostics = false;
|
||||
|
||||
for (project_path, _, summary) in project.diagnostic_summaries(true, cx) {
|
||||
if summary.error_count > 0 || summary.warning_count > 0 {
|
||||
let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx)
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
has_diagnostics = true;
|
||||
output.push_str(&format!(
|
||||
"{}: {} error(s), {} warning(s)\n",
|
||||
Path::new(worktree.read(cx).root_name())
|
||||
.join(project_path.path)
|
||||
.display(),
|
||||
summary.error_count,
|
||||
summary.warning_count
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if has_diagnostics {
|
||||
Task::ready(Ok(output))
|
||||
} else {
|
||||
Task::ready(Ok("No errors or warnings found in the project.".into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
use crate::{AgentTool, Thread, ToolCallEventStream};
|
||||
use acp_thread::Diff;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use gpui::{App, AppContext, AsyncApp, Entity, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::ToPoint;
|
||||
use language::language_settings::{self, FormatOnSave};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use paths;
|
||||
@@ -225,6 +226,16 @@ impl AgentTool for EditFileTool {
|
||||
Ok(path) => path,
|
||||
Err(err) => return Task::ready(Err(anyhow!(err))),
|
||||
};
|
||||
let abs_path = project.read(cx).absolute_path(&project_path, cx);
|
||||
if let Some(abs_path) = abs_path.clone() {
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![acp::ToolCallLocation {
|
||||
path: abs_path,
|
||||
line: None,
|
||||
}]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
let request = self.thread.update(cx, |thread, cx| {
|
||||
thread.build_completion_request(CompletionIntent::ToolResults, cx)
|
||||
@@ -283,13 +294,38 @@ impl AgentTool for EditFileTool {
|
||||
|
||||
let mut hallucinated_old_text = false;
|
||||
let mut ambiguous_ranges = Vec::new();
|
||||
let mut emitted_location = false;
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited => {},
|
||||
EditAgentOutputEvent::Edited(range) => {
|
||||
if !emitted_location {
|
||||
let line = buffer.update(cx, |buffer, _cx| {
|
||||
range.start.to_point(&buffer.snapshot()).row
|
||||
}).ok();
|
||||
if let Some(abs_path) = abs_path.clone() {
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
emitted_location = true;
|
||||
}
|
||||
},
|
||||
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
|
||||
EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
|
||||
EditAgentOutputEvent::ResolvingEditRange(range) => {
|
||||
diff.update(cx, |card, cx| card.reveal_range(range, cx))?;
|
||||
diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
|
||||
// if !emitted_location {
|
||||
// let line = buffer.update(cx, |buffer, _cx| {
|
||||
// range.start.to_point(&buffer.snapshot()).row
|
||||
// }).ok();
|
||||
// if let Some(abs_path) = abs_path.clone() {
|
||||
// event_stream.update_fields(ToolCallUpdateFields {
|
||||
// locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
|
||||
// ..Default::default()
|
||||
// });
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -454,9 +490,8 @@ fn resolve_path(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::Templates;
|
||||
|
||||
use super::*;
|
||||
use crate::{ContextServerRegistry, Templates};
|
||||
use action_log::ActionLog;
|
||||
use client::TelemetrySettings;
|
||||
use fs::Fs;
|
||||
@@ -475,9 +510,20 @@ mod tests {
|
||||
fs.insert_tree("/root", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread =
|
||||
cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log,
|
||||
Templates::new(),
|
||||
model,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let result = cx
|
||||
.update(|cx| {
|
||||
let input = EditFileToolInput {
|
||||
@@ -661,14 +707,18 @@ mod tests {
|
||||
});
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
@@ -792,15 +842,19 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
@@ -914,15 +968,19 @@ mod tests {
|
||||
init_test(cx);
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1041,15 +1099,19 @@ mod tests {
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project,
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1148,14 +1210,18 @@ mod tests {
|
||||
.await;
|
||||
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1225,14 +1291,18 @@ mod tests {
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1305,14 +1375,18 @@ mod tests {
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry.clone(),
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
@@ -1382,14 +1456,18 @@ mod tests {
|
||||
let fs = project::FakeFs::new(cx.executor());
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
let action_log = cx.new(|_| ActionLog::new(project.clone()));
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread = cx.new(|_| {
|
||||
let thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
Rc::default(),
|
||||
context_server_registry,
|
||||
action_log.clone(),
|
||||
Templates::new(),
|
||||
model.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let tool = Arc::new(EditFileTool { thread });
|
||||
|
||||
155
crates/agent2/src/tools/fetch_tool.rs
Normal file
155
crates/agent2/src/tools/fetch_tool.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, cell::RefCell};
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use futures::AsyncReadExt as _;
|
||||
use gpui::{App, AppContext as _, Task};
|
||||
use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown};
|
||||
use http_client::{AsyncBody, HttpClientWithUrl};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::SharedString;
|
||||
use util::markdown::MarkdownEscaped;
|
||||
|
||||
use crate::{AgentTool, ToolCallEventStream};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
|
||||
enum ContentType {
|
||||
Html,
|
||||
Plaintext,
|
||||
Json,
|
||||
}
|
||||
|
||||
/// Fetches a URL and returns the content as Markdown.
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct FetchToolInput {
|
||||
/// The URL to fetch.
|
||||
url: String,
|
||||
}
|
||||
|
||||
pub struct FetchTool {
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
}
|
||||
|
||||
impl FetchTool {
|
||||
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
|
||||
Self { http_client }
|
||||
}
|
||||
|
||||
async fn build_message(http_client: Arc<HttpClientWithUrl>, url: &str) -> Result<String> {
|
||||
let url = if !url.starts_with("https://") && !url.starts_with("http://") {
|
||||
Cow::Owned(format!("https://{url}"))
|
||||
} else {
|
||||
Cow::Borrowed(url)
|
||||
};
|
||||
|
||||
let mut response = http_client.get(&url, AsyncBody::default(), true).await?;
|
||||
|
||||
let mut body = Vec::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_end(&mut body)
|
||||
.await
|
||||
.context("error reading response body")?;
|
||||
|
||||
if response.status().is_client_error() {
|
||||
let text = String::from_utf8_lossy(body.as_slice());
|
||||
bail!(
|
||||
"status error {}, response: {text:?}",
|
||||
response.status().as_u16()
|
||||
);
|
||||
}
|
||||
|
||||
let Some(content_type) = response.headers().get("content-type") else {
|
||||
bail!("missing Content-Type header");
|
||||
};
|
||||
let content_type = content_type
|
||||
.to_str()
|
||||
.context("invalid Content-Type header")?;
|
||||
|
||||
let content_type = if content_type.starts_with("text/plain") {
|
||||
ContentType::Plaintext
|
||||
} else if content_type.starts_with("application/json") {
|
||||
ContentType::Json
|
||||
} else {
|
||||
ContentType::Html
|
||||
};
|
||||
|
||||
match content_type {
|
||||
ContentType::Html => {
|
||||
let mut handlers: Vec<TagHandler> = vec![
|
||||
Rc::new(RefCell::new(markdown::WebpageChromeRemover)),
|
||||
Rc::new(RefCell::new(markdown::ParagraphHandler)),
|
||||
Rc::new(RefCell::new(markdown::HeadingHandler)),
|
||||
Rc::new(RefCell::new(markdown::ListHandler)),
|
||||
Rc::new(RefCell::new(markdown::TableHandler::new())),
|
||||
Rc::new(RefCell::new(markdown::StyledTextHandler)),
|
||||
];
|
||||
if url.contains("wikipedia.org") {
|
||||
use html_to_markdown::structure::wikipedia;
|
||||
|
||||
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover)));
|
||||
handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler)));
|
||||
handlers.push(Rc::new(
|
||||
RefCell::new(wikipedia::WikipediaCodeHandler::new()),
|
||||
));
|
||||
} else {
|
||||
handlers.push(Rc::new(RefCell::new(markdown::CodeHandler)));
|
||||
}
|
||||
|
||||
convert_html_to_markdown(&body[..], &mut handlers)
|
||||
}
|
||||
ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()),
|
||||
ContentType::Json => {
|
||||
let json: serde_json::Value = serde_json::from_slice(&body)?;
|
||||
|
||||
Ok(format!(
|
||||
"```json\n{}\n```",
|
||||
serde_json::to_string_pretty(&json)?
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentTool for FetchTool {
|
||||
type Input = FetchToolInput;
|
||||
type Output = String;
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"fetch".into()
|
||||
}
|
||||
|
||||
fn kind(&self) -> acp::ToolKind {
|
||||
acp::ToolKind::Fetch
|
||||
}
|
||||
|
||||
fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
|
||||
match input {
|
||||
Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)).into(),
|
||||
Err(_) => "Fetch URL".into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
let text = cx.background_spawn({
|
||||
let http_client = self.http_client.clone();
|
||||
async move { Self::build_message(http_client, &input.url).await }
|
||||
});
|
||||
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let text = text.await?;
|
||||
if text.trim().is_empty() {
|
||||
bail!("no textual content found");
|
||||
}
|
||||
Ok(text)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -139,9 +139,6 @@ impl AgentTool for FindPathTool {
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
raw_output: Some(serde_json::json!({
|
||||
"paths": &matches,
|
||||
})),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Self::Output>> {
|
||||
const CONTEXT_LINES: u32 = 2;
|
||||
@@ -282,33 +282,22 @@ impl AgentTool for GrepTool {
|
||||
}
|
||||
}
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
matches_found += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let output = if matches_found == 0 {
|
||||
"No matches found".to_string()
|
||||
if matches_found == 0 {
|
||||
Ok("No matches found".into())
|
||||
} else if has_more_matches {
|
||||
format!(
|
||||
Ok(format!(
|
||||
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
|
||||
input.offset + 1,
|
||||
input.offset + matches_found,
|
||||
input.offset + RESULTS_PER_PAGE,
|
||||
)
|
||||
))
|
||||
} else {
|
||||
format!("Found {matches_found} matches:\n{output}")
|
||||
};
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![output.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
Ok(output)
|
||||
Ok(format!("Found {matches_found} matches:\n{output}"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,20 +47,13 @@ impl AgentTool for NowTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
event_stream: ToolCallEventStream,
|
||||
_event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Task<Result<String>> {
|
||||
let now = match input.timezone {
|
||||
Timezone::Utc => Utc::now().to_rfc3339(),
|
||||
Timezone::Local => Local::now().to_rfc3339(),
|
||||
};
|
||||
let content = format!("The current datetime is {now}.");
|
||||
|
||||
event_stream.update_fields(acp::ToolCallUpdateFields {
|
||||
content: Some(vec![content.clone().into()]),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
Task::ready(Ok(content))
|
||||
Task::ready(Ok(format!("The current datetime is {now}.")))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use action_log::ActionLog;
|
||||
use agent_client_protocol::{self as acp};
|
||||
use agent_client_protocol::{self as acp, ToolCallUpdateFields};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::outline;
|
||||
use gpui::{App, Entity, SharedString, Task};
|
||||
use indoc::formatdoc;
|
||||
use language::{Anchor, Point};
|
||||
use language::Point;
|
||||
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
|
||||
use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
|
||||
use schemars::JsonSchema;
|
||||
@@ -97,7 +97,7 @@ impl AgentTool for ReadFileTool {
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
_event_stream: ToolCallEventStream,
|
||||
event_stream: ToolCallEventStream,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<LanguageModelToolResultContent>> {
|
||||
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
|
||||
@@ -166,7 +166,9 @@ impl AgentTool for ReadFileTool {
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = cx
|
||||
.update(|cx| {
|
||||
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
|
||||
project.update(cx, |project, cx| {
|
||||
project.open_buffer(project_path.clone(), cx)
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
if buffer.read_with(cx, |buffer, _| {
|
||||
@@ -178,19 +180,10 @@ impl AgentTool for ReadFileTool {
|
||||
anyhow::bail!("{file_path} not found");
|
||||
}
|
||||
|
||||
project.update(cx, |project, cx| {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: Anchor::MIN,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
let mut anchor = None;
|
||||
|
||||
// Check if specific line ranges are provided
|
||||
if input.start_line.is_some() || input.end_line.is_some() {
|
||||
let mut anchor = None;
|
||||
let result = if input.start_line.is_some() || input.end_line.is_some() {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
|
||||
@@ -214,18 +207,6 @@ impl AgentTool for ReadFileTool {
|
||||
log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
if let Some(anchor) = anchor {
|
||||
project.update(cx, |project, cx| {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: anchor,
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(result.into())
|
||||
} else {
|
||||
// No line ranges specified, so check file size to see if it's too big.
|
||||
@@ -236,7 +217,7 @@ impl AgentTool for ReadFileTool {
|
||||
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
action_log.update(cx, |log, cx| {
|
||||
log.buffer_read(buffer, cx);
|
||||
log.buffer_read(buffer.clone(), cx);
|
||||
})?;
|
||||
|
||||
Ok(result.into())
|
||||
@@ -244,7 +225,8 @@ impl AgentTool for ReadFileTool {
|
||||
// File is too big, so return the outline
|
||||
// and a suggestion to read again with line numbers.
|
||||
let outline =
|
||||
outline::file_outline(project, file_path, action_log, None, cx).await?;
|
||||
outline::file_outline(project.clone(), file_path, action_log, None, cx)
|
||||
.await?;
|
||||
Ok(formatdoc! {"
|
||||
This file was too big to read all at once.
|
||||
|
||||
@@ -261,7 +243,28 @@ impl AgentTool for ReadFileTool {
|
||||
}
|
||||
.into())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
project.update(cx, |project, cx| {
|
||||
if let Some(abs_path) = project.absolute_path(&project_path, cx) {
|
||||
project.set_agent_location(
|
||||
Some(AgentLocation {
|
||||
buffer: buffer.downgrade(),
|
||||
position: anchor.unwrap_or(text::Anchor::MIN),
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
event_stream.update_fields(ToolCallUpdateFields {
|
||||
locations: Some(vec![acp::ToolCallLocation {
|
||||
path: abs_path,
|
||||
line: input.start_line.map(|line| line.saturating_sub(1)),
|
||||
}]),
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
})?;
|
||||
|
||||
result
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ use agent_client_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::WebSearchResponse;
|
||||
use gpui::{App, AppContext, Task};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use language_model::{
|
||||
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ui::prelude::*;
|
||||
@@ -50,6 +52,11 @@ impl AgentTool for WebSearchTool {
|
||||
"Searching the Web".into()
|
||||
}
|
||||
|
||||
/// We currently only support Zed Cloud as a provider.
|
||||
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
|
||||
provider == &ZED_CLOUD_PROVIDER_ID
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
input: Self::Input,
|
||||
|
||||
@@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection {
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
|
||||
@@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection {
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
|
||||
@@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection {
|
||||
|
||||
fn prompt(
|
||||
&self,
|
||||
_id: Option<acp_thread::UserMessageId>,
|
||||
params: acp::PromptRequest,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<acp::PromptResponse>> {
|
||||
@@ -423,7 +424,7 @@ impl ClaudeAgentSession {
|
||||
if !turn_state.borrow().is_cancelled() {
|
||||
thread
|
||||
.update(cx, |thread, cx| {
|
||||
thread.push_user_content_block(text.into(), cx)
|
||||
thread.push_user_content_block(None, text.into(), cx)
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
|
||||
@@ -48,6 +48,20 @@ pub struct AgentProfileSettings {
|
||||
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
|
||||
}
|
||||
|
||||
impl AgentProfileSettings {
|
||||
pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
|
||||
self.tools.get(tool_name) == Some(&true)
|
||||
}
|
||||
|
||||
pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
|
||||
self.enable_all_context_servers
|
||||
|| self
|
||||
.context_servers
|
||||
.get(server_id)
|
||||
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ContextServerPreset {
|
||||
pub tools: IndexMap<Arc<str>, bool>,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
mod completion_provider;
|
||||
mod message_history;
|
||||
mod message_editor;
|
||||
mod model_selector;
|
||||
mod model_selector_popover;
|
||||
mod thread_view;
|
||||
|
||||
pub use message_history::MessageHistory;
|
||||
pub use model_selector::AcpModelSelector;
|
||||
pub use model_selector_popover::AcpModelSelectorPopover;
|
||||
pub use thread_view::AcpThreadView;
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use anyhow::Result;
|
||||
use acp_thread::MentionUri;
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use editor::display_map::CreaseId;
|
||||
use editor::{CompletionProvider, Editor, ExcerptId};
|
||||
use file_icons::FileIcons;
|
||||
use futures::future::try_join_all;
|
||||
use gpui::{App, Entity, Task, WeakEntity};
|
||||
use language::{Buffer, CodeLabel, HighlightId};
|
||||
use lsp::CompletionContext;
|
||||
use parking_lot::Mutex;
|
||||
use project::{Completion, CompletionIntent, CompletionResponse, ProjectPath, WorktreeId};
|
||||
use project::{Completion, CompletionIntent, CompletionResponse, Project, ProjectPath, WorktreeId};
|
||||
use rope::Point;
|
||||
use text::{Anchor, ToPoint};
|
||||
use ui::prelude::*;
|
||||
@@ -23,21 +25,63 @@ use crate::context_picker::file_context_picker::{extract_file_name_and_directory
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MentionSet {
|
||||
paths_by_crease_id: HashMap<CreaseId, ProjectPath>,
|
||||
paths_by_crease_id: HashMap<CreaseId, MentionUri>,
|
||||
}
|
||||
|
||||
impl MentionSet {
|
||||
pub fn insert(&mut self, crease_id: CreaseId, path: ProjectPath) {
|
||||
self.paths_by_crease_id.insert(crease_id, path);
|
||||
}
|
||||
|
||||
pub fn path_for_crease_id(&self, crease_id: CreaseId) -> Option<ProjectPath> {
|
||||
self.paths_by_crease_id.get(&crease_id).cloned()
|
||||
pub fn insert(&mut self, crease_id: CreaseId, path: PathBuf) {
|
||||
self.paths_by_crease_id
|
||||
.insert(crease_id, MentionUri::File(path));
|
||||
}
|
||||
|
||||
pub fn drain(&mut self) -> impl Iterator<Item = CreaseId> {
|
||||
self.paths_by_crease_id.drain().map(|(id, _)| id)
|
||||
}
|
||||
|
||||
pub fn contents(
|
||||
&self,
|
||||
project: Entity<Project>,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<HashMap<CreaseId, Mention>>> {
|
||||
let contents = self
|
||||
.paths_by_crease_id
|
||||
.iter()
|
||||
.map(|(crease_id, uri)| match uri {
|
||||
MentionUri::File(path) => {
|
||||
let crease_id = *crease_id;
|
||||
let uri = uri.clone();
|
||||
let path = path.to_path_buf();
|
||||
let buffer_task = project.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.find_project_path(path, cx)
|
||||
.context("Failed to find project path")?;
|
||||
anyhow::Ok(project.open_buffer(path, cx))
|
||||
});
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let buffer = buffer_task?.await?;
|
||||
let content = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
|
||||
|
||||
anyhow::Ok((crease_id, Mention { uri, content }))
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
// TODO
|
||||
unimplemented!()
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
let contents = try_join_all(contents).await?.into_iter().collect();
|
||||
anyhow::Ok(contents)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Mention {
|
||||
pub uri: MentionUri,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
pub struct ContextPickerCompletionProvider {
|
||||
@@ -68,6 +112,7 @@ impl ContextPickerCompletionProvider {
|
||||
source_range: Range<Anchor>,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
project: Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Completion {
|
||||
let (file_name, directory) =
|
||||
@@ -112,6 +157,7 @@ impl ContextPickerCompletionProvider {
|
||||
new_text_len - 1,
|
||||
editor,
|
||||
mention_set,
|
||||
project,
|
||||
)),
|
||||
}
|
||||
}
|
||||
@@ -159,6 +205,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
return Task::ready(Ok(Vec::new()));
|
||||
};
|
||||
|
||||
let project = workspace.read(cx).project().clone();
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
let source_range = snapshot.anchor_before(state.source_range.start)
|
||||
..snapshot.anchor_after(state.source_range.end);
|
||||
@@ -195,6 +242,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
|
||||
source_range.clone(),
|
||||
editor.clone(),
|
||||
mention_set.clone(),
|
||||
project.clone(),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
@@ -254,6 +302,7 @@ fn confirm_completion_callback(
|
||||
content_len: usize,
|
||||
editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
project: Entity<Project>,
|
||||
) -> Arc<dyn Fn(CompletionIntent, &mut Window, &mut App) -> bool + Send + Sync> {
|
||||
Arc::new(move |_, window, cx| {
|
||||
let crease_text = crease_text.clone();
|
||||
@@ -261,6 +310,7 @@ fn confirm_completion_callback(
|
||||
let editor = editor.clone();
|
||||
let project_path = project_path.clone();
|
||||
let mention_set = mention_set.clone();
|
||||
let project = project.clone();
|
||||
window.defer(cx, move |window, cx| {
|
||||
let crease_id = crate::context_picker::insert_crease_for_mention(
|
||||
excerpt_id,
|
||||
@@ -272,8 +322,13 @@ fn confirm_completion_callback(
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
let Some(path) = project.read(cx).absolute_path(&project_path, cx) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(crease_id) = crease_id {
|
||||
mention_set.lock().insert(crease_id, project_path);
|
||||
mention_set.lock().insert(crease_id, path);
|
||||
}
|
||||
});
|
||||
false
|
||||
|
||||
479
crates/agent_ui/src/acp/message_editor.rs
Normal file
479
crates/agent_ui/src/acp/message_editor.rs
Normal file
@@ -0,0 +1,479 @@
|
||||
use crate::acp::completion_provider::ContextPickerCompletionProvider;
|
||||
use crate::acp::completion_provider::MentionSet;
|
||||
use acp_thread::MentionUri;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use editor::{
|
||||
AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode,
|
||||
EditorStyle, MultiBuffer,
|
||||
};
|
||||
use file_icons::FileIcons;
|
||||
use gpui::{
|
||||
AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity,
|
||||
};
|
||||
use language::Language;
|
||||
use language::{Buffer, BufferSnapshot};
|
||||
use parking_lot::Mutex;
|
||||
use project::{CompletionIntent, Project};
|
||||
use settings::Settings;
|
||||
use std::path::Path;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use theme::ThemeSettings;
|
||||
use ui::{
|
||||
ActiveTheme, App, IconName, InteractiveElement, IntoElement, ParentElement, Render,
|
||||
SharedString, Styled, TextSize, Window, div,
|
||||
};
|
||||
use util::ResultExt;
|
||||
use workspace::Workspace;
|
||||
use zed_actions::agent::Chat;
|
||||
|
||||
pub const MIN_EDITOR_LINES: usize = 4;
|
||||
pub const MAX_EDITOR_LINES: usize = 8;
|
||||
|
||||
pub struct MessageEditor {
|
||||
editor: Entity<Editor>,
|
||||
project: Entity<Project>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
}
|
||||
|
||||
pub enum MessageEditorEvent {
|
||||
Chat,
|
||||
}
|
||||
|
||||
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
|
||||
|
||||
impl MessageEditor {
|
||||
pub fn new(
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let language = Language::new(
|
||||
language::LanguageConfig {
|
||||
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
|
||||
..Default::default()
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
|
||||
let editor = cx.new(|cx| {
|
||||
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
|
||||
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
|
||||
|
||||
let mut editor = Editor::new(
|
||||
editor::EditorMode::AutoHeight {
|
||||
min_lines: MIN_EDITOR_LINES,
|
||||
max_lines: Some(MAX_EDITOR_LINES),
|
||||
},
|
||||
buffer,
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
editor.set_placeholder_text("Message the agent - @ to include files", cx);
|
||||
editor.set_show_indent_guides(false, cx);
|
||||
editor.set_soft_wrap();
|
||||
editor.set_use_modal_editing(true);
|
||||
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
|
||||
mention_set.clone(),
|
||||
workspace,
|
||||
cx.weak_entity(),
|
||||
))));
|
||||
editor.set_context_menu_options(ContextMenuOptions {
|
||||
min_entries_visible: 12,
|
||||
max_entries_visible: 12,
|
||||
placement: Some(ContextMenuPlacement::Above),
|
||||
});
|
||||
editor
|
||||
});
|
||||
|
||||
Self {
|
||||
editor,
|
||||
project,
|
||||
mention_set,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self, cx: &App) -> bool {
|
||||
self.editor.read(cx).is_empty(cx)
|
||||
}
|
||||
|
||||
pub fn contents(&self, cx: &mut Context<Self>) -> Task<Result<Vec<acp::ContentBlock>>> {
|
||||
let contents = self.mention_set.lock().contents(self.project.clone(), cx);
|
||||
let editor = self.editor.clone();
|
||||
|
||||
cx.spawn(async move |_, cx| {
|
||||
let contents = contents.await?;
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
let mut ix = 0;
|
||||
let mut chunks: Vec<acp::ContentBlock> = Vec::new();
|
||||
let text = editor.text(cx);
|
||||
editor.display_map.update(cx, |map, cx| {
|
||||
let snapshot = map.snapshot(cx);
|
||||
for (crease_id, crease) in snapshot.crease_snapshot.creases() {
|
||||
// Skip creases that have been edited out of the message buffer.
|
||||
if !crease.range().start.is_valid(&snapshot.buffer_snapshot) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(mention) = contents.get(&crease_id) {
|
||||
let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot);
|
||||
if crease_range.start > ix {
|
||||
chunks.push(text[ix..crease_range.start].into());
|
||||
}
|
||||
chunks.push(acp::ContentBlock::Resource(acp::EmbeddedResource {
|
||||
annotations: None,
|
||||
resource: acp::EmbeddedResourceResource::TextResourceContents(
|
||||
acp::TextResourceContents {
|
||||
mime_type: None,
|
||||
text: mention.content.clone(),
|
||||
uri: mention.uri.to_uri(),
|
||||
},
|
||||
),
|
||||
}));
|
||||
ix = crease_range.end;
|
||||
}
|
||||
}
|
||||
|
||||
if ix < text.len() {
|
||||
let last_chunk = text[ix..].trim_end();
|
||||
if !last_chunk.is_empty() {
|
||||
chunks.push(last_chunk.into());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
chunks
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.clear(window, cx);
|
||||
editor.remove_creases(self.mention_set.lock().drain(), cx)
|
||||
});
|
||||
}
|
||||
|
||||
fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
|
||||
cx.emit(MessageEditorEvent::Chat)
|
||||
}
|
||||
|
||||
pub fn insert_dragged_files(
|
||||
&self,
|
||||
paths: Vec<project::ProjectPath>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let buffer = self.editor.read(cx).buffer().clone();
|
||||
let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
let Some(buffer) = buffer.read(cx).as_singleton() else {
|
||||
return;
|
||||
};
|
||||
for path in paths {
|
||||
let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else {
|
||||
continue;
|
||||
};
|
||||
let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
|
||||
let path_prefix = abs_path
|
||||
.file_name()
|
||||
.unwrap_or(path.path.as_os_str())
|
||||
.display()
|
||||
.to_string();
|
||||
let completion = ContextPickerCompletionProvider::completion_for_path(
|
||||
path,
|
||||
&path_prefix,
|
||||
false,
|
||||
entry.is_dir(),
|
||||
excerpt_id,
|
||||
anchor..anchor,
|
||||
self.editor.clone(),
|
||||
self.mention_set.clone(),
|
||||
self.project.clone(),
|
||||
cx,
|
||||
);
|
||||
|
||||
self.editor.update(cx, |message_editor, cx| {
|
||||
message_editor.edit(
|
||||
[(
|
||||
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
|
||||
completion.new_text,
|
||||
)],
|
||||
cx,
|
||||
);
|
||||
});
|
||||
if let Some(confirm) = completion.confirm.clone() {
|
||||
confirm(CompletionIntent::Complete, window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_expanded(&mut self, expanded: bool, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
if expanded {
|
||||
editor.set_mode(EditorMode::Full {
|
||||
scale_ui_elements_with_buffer_font_size: false,
|
||||
show_active_line_background: false,
|
||||
sized_by_content: false,
|
||||
})
|
||||
} else {
|
||||
editor.set_mode(EditorMode::AutoHeight {
|
||||
min_lines: MIN_EDITOR_LINES,
|
||||
max_lines: Some(MAX_EDITOR_LINES),
|
||||
})
|
||||
}
|
||||
cx.notify()
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn set_draft_message(
|
||||
message_editor: Entity<Editor>,
|
||||
mention_set: Arc<Mutex<MentionSet>>,
|
||||
project: Entity<Project>,
|
||||
message: Option<&[acp::ContentBlock]>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<BufferSnapshot> {
|
||||
cx.notify();
|
||||
|
||||
let message = message?;
|
||||
|
||||
let mut text = String::new();
|
||||
let mut mentions = Vec::new();
|
||||
|
||||
for chunk in message {
|
||||
match chunk {
|
||||
acp::ContentBlock::Text(text_content) => {
|
||||
text.push_str(&text_content.text);
|
||||
}
|
||||
acp::ContentBlock::Resource(acp::EmbeddedResource {
|
||||
resource: acp::EmbeddedResourceResource::TextResourceContents(resource),
|
||||
..
|
||||
}) => {
|
||||
if let Some(ref mention @ MentionUri::File(ref abs_path)) =
|
||||
MentionUri::parse(&resource.uri).log_err()
|
||||
{
|
||||
let project_path = project
|
||||
.read(cx)
|
||||
.project_path_for_absolute_path(&abs_path, cx);
|
||||
let start = text.len();
|
||||
let content = mention.to_uri();
|
||||
text.push_str(&content);
|
||||
let end = text.len();
|
||||
if let Some(project_path) = project_path {
|
||||
let filename: SharedString = project_path
|
||||
.path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.to_string()
|
||||
.into();
|
||||
mentions.push((start..end, project_path, filename));
|
||||
}
|
||||
}
|
||||
}
|
||||
acp::ContentBlock::Image(_)
|
||||
| acp::ContentBlock::Audio(_)
|
||||
| acp::ContentBlock::Resource(_)
|
||||
| acp::ContentBlock::ResourceLink(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
let snapshot = message_editor.update(cx, |editor, cx| {
|
||||
editor.set_text(text, window, cx);
|
||||
editor.buffer().read(cx).snapshot(cx)
|
||||
});
|
||||
|
||||
for (range, project_path, filename) in mentions {
|
||||
let crease_icon_path = if project_path.path.is_dir() {
|
||||
FileIcons::get_folder_icon(false, cx)
|
||||
.unwrap_or_else(|| IconName::Folder.path().into())
|
||||
} else {
|
||||
FileIcons::get_icon(Path::new(project_path.path.as_ref()), cx)
|
||||
.unwrap_or_else(|| IconName::File.path().into())
|
||||
};
|
||||
|
||||
let anchor = snapshot.anchor_before(range.start);
|
||||
if let Some(project_path) = project.read(cx).absolute_path(&project_path, cx) {
|
||||
let crease_id = crate::context_picker::insert_crease_for_mention(
|
||||
anchor.excerpt_id,
|
||||
anchor.text_anchor,
|
||||
range.end - range.start,
|
||||
filename,
|
||||
crease_icon_path,
|
||||
message_editor.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
|
||||
if let Some(crease_id) = crease_id {
|
||||
mention_set.lock().insert(crease_id, project_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let snapshot = snapshot.as_singleton().unwrap().2.clone();
|
||||
Some(snapshot)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn set_text(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
editor.set_text(text, window, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Focusable for MessageEditor {
|
||||
fn focus_handle(&self, cx: &App) -> FocusHandle {
|
||||
self.editor.focus_handle(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for MessageEditor {
|
||||
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
div()
|
||||
.key_context("MessageEditor")
|
||||
.on_action(cx.listener(Self::chat))
|
||||
.flex_1()
|
||||
.child({
|
||||
let settings = ThemeSettings::get_global(cx);
|
||||
let font_size = TextSize::Small
|
||||
.rems(cx)
|
||||
.to_pixels(settings.agent_font_size(cx));
|
||||
let line_height = settings.buffer_line_height.value() * font_size;
|
||||
|
||||
let text_style = TextStyle {
|
||||
color: cx.theme().colors().text,
|
||||
font_family: settings.buffer_font.family.clone(),
|
||||
font_fallbacks: settings.buffer_font.fallbacks.clone(),
|
||||
font_features: settings.buffer_font.features.clone(),
|
||||
font_size: font_size.into(),
|
||||
line_height: line_height.into(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
EditorElement::new(
|
||||
&self.editor,
|
||||
EditorStyle {
|
||||
background: cx.theme().colors().editor_background,
|
||||
local_player: cx.theme().players().local(),
|
||||
text: text_style,
|
||||
syntax: cx.theme().syntax().clone(),
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use agent_client_protocol as acp;
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use lsp::{CompletionContext, CompletionTriggerKind};
|
||||
use project::{CompletionIntent, Project};
|
||||
use serde_json::json;
|
||||
use util::path;
|
||||
use workspace::Workspace;
|
||||
|
||||
use crate::acp::{message_editor::MessageEditor, thread_view::tests::init_test};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_at_mention_removal(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree("/project", json!({"file": ""})).await;
|
||||
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
|
||||
|
||||
let (workspace, cx) =
|
||||
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
|
||||
|
||||
let message_editor = cx.update(|window, cx| {
|
||||
cx.new(|cx| MessageEditor::new(workspace.downgrade(), project.clone(), window, cx))
|
||||
});
|
||||
let editor = message_editor.update(cx, |message_editor, _| message_editor.editor.clone());
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let excerpt_id = editor.update(cx, |editor, cx| {
|
||||
editor
|
||||
.buffer()
|
||||
.read(cx)
|
||||
.excerpt_ids()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
});
|
||||
let completions = editor.update_in(cx, |editor, window, cx| {
|
||||
editor.set_text("Hello @", window, cx);
|
||||
let buffer = editor.buffer().read(cx).as_singleton().unwrap();
|
||||
let completion_provider = editor.completion_provider().unwrap();
|
||||
completion_provider.completions(
|
||||
excerpt_id,
|
||||
&buffer,
|
||||
text::Anchor::MAX,
|
||||
CompletionContext {
|
||||
trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER,
|
||||
trigger_character: Some("@".into()),
|
||||
},
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let [_, completion]: [_; 2] = completions
|
||||
.await
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.flat_map(|response| response.completions)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let start = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, completion.replace_range.start)
|
||||
.unwrap();
|
||||
let end = snapshot
|
||||
.anchor_in_excerpt(excerpt_id, completion.replace_range.end)
|
||||
.unwrap();
|
||||
editor.edit([(start..end, completion.new_text)], cx);
|
||||
(completion.confirm.unwrap())(CompletionIntent::Complete, window, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
// Backspace over the inserted crease (and the following space).
|
||||
editor.update_in(cx, |editor, window, cx| {
|
||||
editor.backspace(&Default::default(), window, cx);
|
||||
editor.backspace(&Default::default(), window, cx);
|
||||
});
|
||||
|
||||
let content = message_editor
|
||||
.update_in(cx, |message_editor, _window, cx| {
|
||||
message_editor.contents(cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// We don't send a resource link for the deleted crease.
|
||||
pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]);
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
pub struct MessageHistory<T> {
|
||||
items: Vec<T>,
|
||||
current: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T> Default for MessageHistory<T> {
|
||||
fn default() -> Self {
|
||||
MessageHistory {
|
||||
items: Vec::new(),
|
||||
current: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> MessageHistory<T> {
|
||||
pub fn push(&mut self, message: T) {
|
||||
self.current.take();
|
||||
self.items.push(message);
|
||||
}
|
||||
|
||||
pub fn reset_position(&mut self) {
|
||||
self.current.take();
|
||||
}
|
||||
|
||||
pub fn prev(&mut self) -> Option<&T> {
|
||||
if self.items.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let new_ix = self
|
||||
.current
|
||||
.get_or_insert(self.items.len())
|
||||
.saturating_sub(1);
|
||||
|
||||
self.current = Some(new_ix);
|
||||
self.items.get(new_ix)
|
||||
}
|
||||
|
||||
pub fn next(&mut self) -> Option<&T> {
|
||||
let current = self.current.as_mut()?;
|
||||
*current += 1;
|
||||
|
||||
self.items.get(*current).or_else(|| {
|
||||
self.current.take();
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn items(&self) -> &[T] {
|
||||
&self.items
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prev_next() {
|
||||
let mut history = MessageHistory::default();
|
||||
|
||||
// Test empty history
|
||||
assert_eq!(history.prev(), None);
|
||||
assert_eq!(history.next(), None);
|
||||
|
||||
// Add some messages
|
||||
history.push("first");
|
||||
history.push("second");
|
||||
history.push("third");
|
||||
|
||||
// Test prev navigation
|
||||
assert_eq!(history.prev(), Some(&"third"));
|
||||
assert_eq!(history.prev(), Some(&"second"));
|
||||
assert_eq!(history.prev(), Some(&"first"));
|
||||
assert_eq!(history.prev(), Some(&"first"));
|
||||
|
||||
assert_eq!(history.next(), Some(&"second"));
|
||||
|
||||
// Test mixed navigation
|
||||
history.push("fourth");
|
||||
assert_eq!(history.prev(), Some(&"fourth"));
|
||||
assert_eq!(history.prev(), Some(&"third"));
|
||||
assert_eq!(history.next(), Some(&"fourth"));
|
||||
assert_eq!(history.next(), None);
|
||||
|
||||
// Test that push resets navigation
|
||||
history.prev();
|
||||
history.prev();
|
||||
history.push("fifth");
|
||||
assert_eq!(history.prev(), Some(&"fifth"));
|
||||
}
|
||||
}
|
||||
472
crates/agent_ui/src/acp/model_selector.rs
Normal file
472
crates/agent_ui/src/acp/model_selector.rs
Normal file
@@ -0,0 +1,472 @@
|
||||
use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use futures::FutureExt;
|
||||
use fuzzy::{StringMatchCandidate, match_strings};
|
||||
use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use ui::{
|
||||
AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
|
||||
prelude::*, rems,
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
|
||||
|
||||
pub fn acp_model_selector(
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<AcpModelSelector>,
|
||||
) -> AcpModelSelector {
|
||||
let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
|
||||
Picker::list(delegate, window, cx)
|
||||
.show_scrollbar(true)
|
||||
.width(rems(20.))
|
||||
.max_height(Some(rems(20.).into()))
|
||||
}
|
||||
|
||||
enum AcpModelPickerEntry {
|
||||
Separator(SharedString),
|
||||
Model(AgentModelInfo),
|
||||
}
|
||||
|
||||
pub struct AcpModelPickerDelegate {
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
filtered_entries: Vec<AcpModelPickerEntry>,
|
||||
models: Option<AgentModelList>,
|
||||
selected_index: usize,
|
||||
selected_model: Option<AgentModelInfo>,
|
||||
_refresh_models_task: Task<()>,
|
||||
}
|
||||
|
||||
impl AcpModelPickerDelegate {
|
||||
fn new(
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<AcpModelSelector>,
|
||||
) -> Self {
|
||||
let mut rx = selector.watch(cx);
|
||||
let refresh_models_task = cx.spawn_in(window, {
|
||||
let session_id = session_id.clone();
|
||||
async move |this, cx| {
|
||||
async fn refresh(
|
||||
this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
|
||||
session_id: &acp::SessionId,
|
||||
cx: &mut AsyncWindowContext,
|
||||
) -> Result<()> {
|
||||
let (models_task, selected_model_task) = this.update(cx, |this, cx| {
|
||||
(
|
||||
this.delegate.selector.list_models(cx),
|
||||
this.delegate.selector.selected_model(session_id, cx),
|
||||
)
|
||||
})?;
|
||||
|
||||
let (models, selected_model) = futures::join!(models_task, selected_model_task);
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate.models = models.ok();
|
||||
this.delegate.selected_model = selected_model.ok();
|
||||
this.delegate.update_matches(this.query(cx), window, cx)
|
||||
})?
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
refresh(&this, &session_id, cx).await.log_err();
|
||||
while let Ok(()) = rx.recv().await {
|
||||
refresh(&this, &session_id, cx).await.log_err();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
session_id,
|
||||
selector,
|
||||
filtered_entries: Vec::new(),
|
||||
models: None,
|
||||
selected_model: None,
|
||||
selected_index: 0,
|
||||
_refresh_models_task: refresh_models_task,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_model(&self) -> Option<&AgentModelInfo> {
|
||||
self.selected_model.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl PickerDelegate for AcpModelPickerDelegate {
|
||||
type ListItem = AnyElement;
|
||||
|
||||
fn match_count(&self) -> usize {
|
||||
self.filtered_entries.len()
|
||||
}
|
||||
|
||||
fn selected_index(&self) -> usize {
|
||||
self.selected_index
|
||||
}
|
||||
|
||||
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn can_select(
|
||||
&mut self,
|
||||
ix: usize,
|
||||
_window: &mut Window,
|
||||
_cx: &mut Context<Picker<Self>>,
|
||||
) -> bool {
|
||||
match self.filtered_entries.get(ix) {
|
||||
Some(AcpModelPickerEntry::Model(_)) => true,
|
||||
Some(AcpModelPickerEntry::Separator(_)) | None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
|
||||
"Select a model…".into()
|
||||
}
|
||||
|
||||
fn update_matches(
|
||||
&mut self,
|
||||
query: String,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Task<()> {
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let filtered_models = match this
|
||||
.read_with(cx, |this, cx| {
|
||||
this.delegate.models.clone().map(move |models| {
|
||||
fuzzy_search(models, query, cx.background_executor().clone())
|
||||
})
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
{
|
||||
Some(task) => task.await,
|
||||
None => AgentModelList::Flat(vec![]),
|
||||
};
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
this.delegate.filtered_entries =
|
||||
info_list_to_picker_entries(filtered_models).collect();
|
||||
// Finds the currently selected model in the list
|
||||
let new_index = this
|
||||
.delegate
|
||||
.selected_model
|
||||
.as_ref()
|
||||
.and_then(|selected| {
|
||||
this.delegate.filtered_entries.iter().position(|entry| {
|
||||
if let AcpModelPickerEntry::Model(model_info) = entry {
|
||||
model_info.id == selected.id
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
})
|
||||
.unwrap_or(0);
|
||||
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
})
|
||||
}
|
||||
|
||||
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
if let Some(AcpModelPickerEntry::Model(model_info)) =
|
||||
self.filtered_entries.get(self.selected_index)
|
||||
{
|
||||
self.selector
|
||||
.select_model(self.session_id.clone(), model_info.id.clone(), cx)
|
||||
.detach_and_log_err(cx);
|
||||
self.selected_model = Some(model_info.clone());
|
||||
let current_index = self.selected_index;
|
||||
self.set_selected_index(current_index, window, cx);
|
||||
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
}
|
||||
|
||||
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
|
||||
cx.emit(DismissEvent);
|
||||
}
|
||||
|
||||
fn render_match(
|
||||
&self,
|
||||
ix: usize,
|
||||
selected: bool,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<Self::ListItem> {
|
||||
match self.filtered_entries.get(ix)? {
|
||||
AcpModelPickerEntry::Separator(title) => Some(
|
||||
div()
|
||||
.px_2()
|
||||
.pb_1()
|
||||
.when(ix > 1, |this| {
|
||||
this.mt_1()
|
||||
.pt_2()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
})
|
||||
.child(
|
||||
Label::new(title)
|
||||
.size(LabelSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any_element(),
|
||||
),
|
||||
AcpModelPickerEntry::Model(model_info) => {
|
||||
let is_selected = Some(model_info) == self.selected_model.as_ref();
|
||||
|
||||
let model_icon_color = if is_selected {
|
||||
Color::Accent
|
||||
} else {
|
||||
Color::Muted
|
||||
};
|
||||
|
||||
Some(
|
||||
ListItem::new(ix)
|
||||
.inset(true)
|
||||
.spacing(ListItemSpacing::Sparse)
|
||||
.toggle_state(selected)
|
||||
.start_slot::<Icon>(model_info.icon.map(|icon| {
|
||||
Icon::new(icon)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
}))
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.pl_0p5()
|
||||
.gap_1p5()
|
||||
.w(px(240.))
|
||||
.child(Label::new(model_info.name.clone()).truncate()),
|
||||
)
|
||||
.end_slot(div().pr_3().when(is_selected, |this| {
|
||||
this.child(
|
||||
Icon::new(IconName::Check)
|
||||
.color(Color::Accent)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
}))
|
||||
.into_any_element(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render_footer(
|
||||
&self,
|
||||
_: &mut Window,
|
||||
cx: &mut Context<Picker<Self>>,
|
||||
) -> Option<gpui::AnyElement> {
|
||||
Some(
|
||||
h_flex()
|
||||
.w_full()
|
||||
.border_t_1()
|
||||
.border_color(cx.theme().colors().border_variant)
|
||||
.p_1()
|
||||
.gap_4()
|
||||
.justify_between()
|
||||
.child(
|
||||
Button::new("configure", "Configure")
|
||||
.icon(IconName::Settings)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(|_, window, cx| {
|
||||
window.dispatch_action(
|
||||
zed_actions::agent::OpenSettings.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
.into_any(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn info_list_to_picker_entries(
|
||||
model_list: AgentModelList,
|
||||
) -> impl Iterator<Item = AcpModelPickerEntry> {
|
||||
match model_list {
|
||||
AgentModelList::Flat(list) => {
|
||||
itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
|
||||
}
|
||||
AgentModelList::Grouped(index_map) => {
|
||||
itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
|
||||
std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
|
||||
.chain(models.into_iter().map(AcpModelPickerEntry::Model))
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn fuzzy_search(
|
||||
model_list: AgentModelList,
|
||||
query: String,
|
||||
executor: BackgroundExecutor,
|
||||
) -> AgentModelList {
|
||||
async fn fuzzy_search_list(
|
||||
model_list: Vec<AgentModelInfo>,
|
||||
query: &str,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Vec<AgentModelInfo> {
|
||||
let candidates = model_list
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(ix, model)| {
|
||||
StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut matches = match_strings(
|
||||
&candidates,
|
||||
&query,
|
||||
false,
|
||||
true,
|
||||
100,
|
||||
&Default::default(),
|
||||
executor,
|
||||
)
|
||||
.await;
|
||||
|
||||
matches.sort_unstable_by_key(|mat| {
|
||||
let candidate = &candidates[mat.candidate_id];
|
||||
(Reverse(OrderedFloat(mat.score)), candidate.id)
|
||||
});
|
||||
|
||||
matches
|
||||
.into_iter()
|
||||
.map(|mat| model_list[mat.candidate_id].clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
match model_list {
|
||||
AgentModelList::Flat(model_list) => {
|
||||
AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
|
||||
}
|
||||
AgentModelList::Grouped(index_map) => {
|
||||
let groups =
|
||||
futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
|
||||
fuzzy_search_list(models, &query, executor.clone())
|
||||
.map(|results| (group_name, results))
|
||||
}))
|
||||
.await;
|
||||
AgentModelList::Grouped(IndexMap::from_iter(
|
||||
groups
|
||||
.into_iter()
|
||||
.filter(|(_, results)| !results.is_empty()),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::TestAppContext;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
|
||||
AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
|
||||
|(group, models)| {
|
||||
(
|
||||
acp_thread::AgentModelGroupName(group.to_string().into()),
|
||||
models
|
||||
.into_iter()
|
||||
.map(|model| acp_thread::AgentModelInfo {
|
||||
id: acp_thread::AgentModelId(model.to_string().into()),
|
||||
name: model.to_string().into(),
|
||||
icon: None,
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
|
||||
let AgentModelList::Grouped(groups) = result else {
|
||||
panic!("Expected LanguageModelInfoList::Grouped, got {:?}", result);
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
groups.len(),
|
||||
expected.len(),
|
||||
"Number of groups doesn't match"
|
||||
);
|
||||
|
||||
for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
|
||||
let (actual_group, actual_models) = groups.get_index(i).unwrap();
|
||||
assert_eq!(
|
||||
actual_group.0.as_ref(),
|
||||
*expected_group,
|
||||
"Group at position {} doesn't match expected group",
|
||||
i
|
||||
);
|
||||
assert_eq!(
|
||||
actual_models.len(),
|
||||
expected_models.len(),
|
||||
"Number of models in group {} doesn't match",
|
||||
expected_group
|
||||
);
|
||||
|
||||
for (j, expected_model_name) in expected_models.iter().enumerate() {
|
||||
assert_eq!(
|
||||
actual_models[j].name, *expected_model_name,
|
||||
"Model at position {} in group {} doesn't match expected model",
|
||||
j, expected_group
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fuzzy_match(cx: &mut TestAppContext) {
|
||||
let models = create_model_list(vec![
|
||||
(
|
||||
"zed",
|
||||
vec![
|
||||
"Claude 3.7 Sonnet",
|
||||
"Claude 3.7 Sonnet Thinking",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
),
|
||||
("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
|
||||
("ollama", vec!["mistral", "deepseek"]),
|
||||
]);
|
||||
|
||||
// Results should preserve models order whenever possible.
|
||||
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
|
||||
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
|
||||
// so it should appear first in the results.
|
||||
let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
|
||||
// Fuzzy search
|
||||
let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
|
||||
assert_models_eq(
|
||||
results,
|
||||
vec![
|
||||
("zed", vec!["gpt-4.1-nano"]),
|
||||
("openai", vec!["gpt-4.1-nano"]),
|
||||
],
|
||||
);
|
||||
}
|
||||
}
|
||||
85
crates/agent_ui/src/acp/model_selector_popover.rs
Normal file
85
crates/agent_ui/src/acp/model_selector_popover.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use std::rc::Rc;
|
||||
|
||||
use acp_thread::AgentModelSelector;
|
||||
use agent_client_protocol as acp;
|
||||
use gpui::{Entity, FocusHandle};
|
||||
use picker::popover_menu::PickerPopoverMenu;
|
||||
use ui::{
|
||||
ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*,
|
||||
};
|
||||
use zed_actions::agent::ToggleModelSelector;
|
||||
|
||||
use crate::acp::{AcpModelSelector, model_selector::acp_model_selector};
|
||||
|
||||
pub struct AcpModelSelectorPopover {
|
||||
selector: Entity<AcpModelSelector>,
|
||||
menu_handle: PopoverMenuHandle<AcpModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
|
||||
impl AcpModelSelectorPopover {
|
||||
pub(crate) fn new(
|
||||
session_id: acp::SessionId,
|
||||
selector: Rc<dyn AgentModelSelector>,
|
||||
menu_handle: PopoverMenuHandle<AcpModelSelector>,
|
||||
focus_handle: FocusHandle,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
Self {
|
||||
selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)),
|
||||
menu_handle,
|
||||
focus_handle,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.menu_handle.toggle(window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for AcpModelSelectorPopover {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let model = self.selector.read(cx).delegate.active_model();
|
||||
let model_name = model
|
||||
.as_ref()
|
||||
.map(|model| model.name.clone())
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon);
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
PickerPopoverMenu::new(
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.when_some(model_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall))
|
||||
})
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
.color(Color::Muted)
|
||||
.size(LabelSize::Small)
|
||||
.ml_0p5(),
|
||||
)
|
||||
.child(
|
||||
Icon::new(IconName::ChevronDown)
|
||||
.color(Color::Muted)
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
move |window, cx| {
|
||||
Tooltip::for_action_in(
|
||||
"Change Model",
|
||||
&ToggleModelSelector,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
},
|
||||
gpui::Corner::BottomRight,
|
||||
cx,
|
||||
)
|
||||
.with_handle(self.menu_handle.clone())
|
||||
.render(window, cx)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1521,7 +1521,8 @@ impl AgentDiff {
|
||||
self.update_reviewing_editors(workspace, window, cx);
|
||||
}
|
||||
}
|
||||
AcpThreadEvent::Stopped
|
||||
AcpThreadEvent::EntriesRemoved(_)
|
||||
| AcpThreadEvent::Stopped
|
||||
| AcpThreadEvent::ToolAuthorizationRequired
|
||||
| AcpThreadEvent::Error
|
||||
| AcpThreadEvent::ServerExited(_) => {}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -64,6 +64,8 @@ actions!(
|
||||
NewTextThread,
|
||||
/// Toggles the context picker interface for adding files, symbols, or other context.
|
||||
ToggleContextPicker,
|
||||
/// Toggles the menu to create new agent threads.
|
||||
ToggleNewThreadMenu,
|
||||
/// Toggles the navigation menu for switching between threads and views.
|
||||
ToggleNavigationMenu,
|
||||
/// Toggles the options menu for agent settings and preferences.
|
||||
@@ -155,11 +157,11 @@ enum ExternalAgent {
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
pub fn server(&self) -> Rc<dyn agent_servers::AgentServer> {
|
||||
pub fn server(&self, fs: Arc<dyn fs::Fs>) -> Rc<dyn agent_servers::AgentServer> {
|
||||
match self {
|
||||
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
|
||||
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
|
||||
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
|
||||
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ mod agent_notification;
|
||||
mod burn_mode_tooltip;
|
||||
mod context_pill;
|
||||
mod end_trial_upsell;
|
||||
mod new_thread_button;
|
||||
// mod new_thread_button;
|
||||
mod onboarding_modal;
|
||||
pub mod preview;
|
||||
|
||||
@@ -10,5 +10,5 @@ pub use agent_notification::*;
|
||||
pub use burn_mode_tooltip::*;
|
||||
pub use context_pill::*;
|
||||
pub use end_trial_upsell::*;
|
||||
pub use new_thread_button::*;
|
||||
// pub use new_thread_button::*;
|
||||
pub use onboarding_modal::*;
|
||||
|
||||
@@ -11,7 +11,7 @@ pub struct NewThreadButton {
|
||||
}
|
||||
|
||||
impl NewThreadButton {
|
||||
pub fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
|
||||
fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
label: label.into(),
|
||||
@@ -21,12 +21,12 @@ impl NewThreadButton {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
|
||||
fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
|
||||
self.keybinding = keybinding;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn on_click<F>(mut self, handler: F) -> Self
|
||||
fn on_click<F>(mut self, handler: F) -> Self
|
||||
where
|
||||
F: Fn(&mut Window, &mut App) + 'static,
|
||||
{
|
||||
|
||||
@@ -86,7 +86,7 @@ impl Tool for DiagnosticsTool {
|
||||
input: serde_json::Value,
|
||||
_request: Arc<LanguageModelRequest>,
|
||||
project: Entity<Project>,
|
||||
action_log: Entity<ActionLog>,
|
||||
_action_log: Entity<ActionLog>,
|
||||
_model: Arc<dyn LanguageModel>,
|
||||
_window: Option<AnyWindowHandle>,
|
||||
cx: &mut App,
|
||||
@@ -159,10 +159,6 @@ impl Tool for DiagnosticsTool {
|
||||
}
|
||||
}
|
||||
|
||||
action_log.update(cx, |action_log, _cx| {
|
||||
action_log.checked_project_diagnostics();
|
||||
});
|
||||
|
||||
if has_diagnostics {
|
||||
Task::ready(Ok(output.into())).into()
|
||||
} else {
|
||||
|
||||
@@ -65,7 +65,7 @@ pub enum EditAgentOutputEvent {
|
||||
ResolvingEditRange(Range<Anchor>),
|
||||
UnresolvedEditRange,
|
||||
AmbiguousEditRange(Vec<Range<usize>>),
|
||||
Edited,
|
||||
Edited(Range<Anchor>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -178,7 +178,9 @@ impl EditAgent {
|
||||
)
|
||||
});
|
||||
output_events_tx
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.unbounded_send(EditAgentOutputEvent::Edited(
|
||||
language::Anchor::MIN..language::Anchor::MAX,
|
||||
))
|
||||
.ok();
|
||||
})?;
|
||||
|
||||
@@ -200,7 +202,9 @@ impl EditAgent {
|
||||
});
|
||||
})?;
|
||||
output_events_tx
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.unbounded_send(EditAgentOutputEvent::Edited(
|
||||
language::Anchor::MIN..language::Anchor::MAX,
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
@@ -336,8 +340,8 @@ impl EditAgent {
|
||||
// Edit the buffer and report edits to the action log as part of the
|
||||
// same effect cycle, otherwise the edit will be reported as if the
|
||||
// user made it.
|
||||
cx.update(|cx| {
|
||||
let max_edit_end = buffer.update(cx, |buffer, cx| {
|
||||
let (min_edit_start, max_edit_end) = cx.update(|cx| {
|
||||
let (min_edit_start, max_edit_end) = buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(edits.iter().cloned(), None, cx);
|
||||
let max_edit_end = buffer
|
||||
.summaries_for_anchors::<Point, _>(
|
||||
@@ -345,7 +349,16 @@ impl EditAgent {
|
||||
)
|
||||
.max()
|
||||
.unwrap();
|
||||
buffer.anchor_before(max_edit_end)
|
||||
let min_edit_start = buffer
|
||||
.summaries_for_anchors::<Point, _>(
|
||||
edits.iter().map(|(range, _)| &range.start),
|
||||
)
|
||||
.min()
|
||||
.unwrap();
|
||||
(
|
||||
buffer.anchor_after(min_edit_start),
|
||||
buffer.anchor_before(max_edit_end),
|
||||
)
|
||||
});
|
||||
self.action_log
|
||||
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
|
||||
@@ -358,9 +371,10 @@ impl EditAgent {
|
||||
cx,
|
||||
);
|
||||
});
|
||||
(min_edit_start, max_edit_end)
|
||||
})?;
|
||||
output_events
|
||||
.unbounded_send(EditAgentOutputEvent::Edited)
|
||||
.unbounded_send(EditAgentOutputEvent::Edited(min_edit_start..max_edit_end))
|
||||
.ok();
|
||||
}
|
||||
|
||||
@@ -755,6 +769,7 @@ mod tests {
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
use indoc::indoc;
|
||||
use language_model::fake_provider::FakeLanguageModel;
|
||||
use pretty_assertions::assert_matches;
|
||||
use project::{AgentLocation, Project};
|
||||
use rand::prelude::*;
|
||||
use rand::rngs::StdRng;
|
||||
@@ -992,7 +1007,10 @@ mod tests {
|
||||
|
||||
model.send_last_completion_stream_text_chunk("<new_text>abX");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited(_)]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
"abXc\ndef\nghi\njkl"
|
||||
@@ -1007,7 +1025,10 @@ mod tests {
|
||||
|
||||
model.send_last_completion_stream_text_chunk("cY");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
"abXcY\ndef\nghi\njkl"
|
||||
@@ -1118,9 +1139,9 @@ mod tests {
|
||||
|
||||
model.send_last_completion_stream_text_chunk("GHI</new_text>");
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1165,9 +1186,9 @@ mod tests {
|
||||
);
|
||||
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited(_)]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1183,9 +1204,9 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("```\njkl\n").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1201,9 +1222,9 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("mno\n").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited { .. }]
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
@@ -1219,9 +1240,9 @@ mod tests {
|
||||
|
||||
chunks_tx.unbounded_send("pqr\n```").unwrap();
|
||||
cx.run_until_parked();
|
||||
assert_eq!(
|
||||
drain_events(&mut events),
|
||||
vec![EditAgentOutputEvent::Edited]
|
||||
assert_matches!(
|
||||
drain_events(&mut events).as_slice(),
|
||||
[EditAgentOutputEvent::Edited(_)],
|
||||
);
|
||||
assert_eq!(
|
||||
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
|
||||
|
||||
@@ -307,7 +307,7 @@ impl Tool for EditFileTool {
|
||||
let mut ambiguous_ranges = Vec::new();
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
EditAgentOutputEvent::Edited => {
|
||||
EditAgentOutputEvent::Edited { .. } => {
|
||||
if let Some(card) = card_clone.as_ref() {
|
||||
card.update(cx, |card, cx| card.update_diff(cx))?;
|
||||
}
|
||||
|
||||
@@ -18,6 +18,6 @@ collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
|
||||
rodio = { workspace = true, features = ["wav", "playback", "tracing"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -59,16 +59,9 @@ pub enum VersionCheckType {
|
||||
pub enum AutoUpdateStatus {
|
||||
Idle,
|
||||
Checking,
|
||||
Downloading {
|
||||
version: VersionCheckType,
|
||||
},
|
||||
Installing {
|
||||
version: VersionCheckType,
|
||||
},
|
||||
Updated {
|
||||
binary_path: PathBuf,
|
||||
version: VersionCheckType,
|
||||
},
|
||||
Downloading { version: VersionCheckType },
|
||||
Installing { version: VersionCheckType },
|
||||
Updated { version: VersionCheckType },
|
||||
Errored,
|
||||
}
|
||||
|
||||
@@ -83,6 +76,7 @@ pub struct AutoUpdater {
|
||||
current_version: SemanticVersion,
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
pending_poll: Option<Task<Option<()>>>,
|
||||
quit_subscription: Option<gpui::Subscription>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
@@ -164,7 +158,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
AutoUpdateSetting::register(cx);
|
||||
|
||||
cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
|
||||
workspace.register_action(|_, action: &Check, window, cx| check(action, window, cx));
|
||||
workspace.register_action(|_, action, window, cx| check(action, window, cx));
|
||||
|
||||
workspace.register_action(|_, action, _, cx| {
|
||||
view_release_notes(action, cx);
|
||||
@@ -174,7 +168,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
|
||||
|
||||
let version = release_channel::AppVersion::global(cx);
|
||||
let auto_updater = cx.new(|cx| {
|
||||
let updater = AutoUpdater::new(version, http_client);
|
||||
let updater = AutoUpdater::new(version, http_client, cx);
|
||||
|
||||
let poll_for_updates = ReleaseChannel::try_global(cx)
|
||||
.map(|channel| channel.poll_for_updates())
|
||||
@@ -321,12 +315,34 @@ impl AutoUpdater {
|
||||
cx.default_global::<GlobalAutoUpdate>().0.clone()
|
||||
}
|
||||
|
||||
fn new(current_version: SemanticVersion, http_client: Arc<HttpClientWithUrl>) -> Self {
|
||||
fn new(
|
||||
current_version: SemanticVersion,
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
// On windows, executable files cannot be overwritten while they are
|
||||
// running, so we must wait to overwrite the application until quitting
|
||||
// or restarting. When quitting the app, we spawn the auto update helper
|
||||
// to finish the auto update process after Zed exits. When restarting
|
||||
// the app after an update, we use `set_restart_path` to run the auto
|
||||
// update helper instead of the app, so that it can overwrite the app
|
||||
// and then spawn the new binary.
|
||||
let quit_subscription = Some(cx.on_app_quit(|_, _| async move {
|
||||
#[cfg(target_os = "windows")]
|
||||
finalize_auto_update_on_quit();
|
||||
}));
|
||||
|
||||
cx.on_app_restart(|this, _| {
|
||||
this.quit_subscription.take();
|
||||
})
|
||||
.detach();
|
||||
|
||||
Self {
|
||||
status: AutoUpdateStatus::Idle,
|
||||
current_version,
|
||||
http_client,
|
||||
pending_poll: None,
|
||||
quit_subscription,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -536,6 +552,8 @@ impl AutoUpdater {
|
||||
)
|
||||
})?;
|
||||
|
||||
Self::check_dependencies()?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.status = AutoUpdateStatus::Checking;
|
||||
cx.notify();
|
||||
@@ -582,13 +600,15 @@ impl AutoUpdater {
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
let binary_path = Self::binary_path(installer_dir, target_path, &cx).await?;
|
||||
let new_binary_path = Self::install_release(installer_dir, target_path, &cx).await?;
|
||||
if let Some(new_binary_path) = new_binary_path {
|
||||
cx.update(|cx| cx.set_restart_path(new_binary_path))?;
|
||||
}
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.set_should_show_update_notification(true, cx)
|
||||
.detach_and_log_err(cx);
|
||||
this.status = AutoUpdateStatus::Updated {
|
||||
binary_path,
|
||||
version: newer_version,
|
||||
};
|
||||
cx.notify();
|
||||
@@ -639,6 +659,15 @@ impl AutoUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_dependencies() -> Result<()> {
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
anyhow::ensure!(
|
||||
which::which("rsync").is_ok(),
|
||||
"Aborting. Could not find rsync which is required for auto-updates."
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn target_path(installer_dir: &InstallerDir) -> Result<PathBuf> {
|
||||
let filename = match OS {
|
||||
"macos" => anyhow::Ok("Zed.dmg"),
|
||||
@@ -647,20 +676,14 @@ impl AutoUpdater {
|
||||
unsupported_os => anyhow::bail!("not supported: {unsupported_os}"),
|
||||
}?;
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
anyhow::ensure!(
|
||||
which::which("rsync").is_ok(),
|
||||
"Aborting. Could not find rsync which is required for auto-updates."
|
||||
);
|
||||
|
||||
Ok(installer_dir.path().join(filename))
|
||||
}
|
||||
|
||||
async fn binary_path(
|
||||
async fn install_release(
|
||||
installer_dir: InstallerDir,
|
||||
target_path: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
) -> Result<Option<PathBuf>> {
|
||||
match OS {
|
||||
"macos" => install_release_macos(&installer_dir, target_path, cx).await,
|
||||
"linux" => install_release_linux(&installer_dir, target_path, cx).await,
|
||||
@@ -801,7 +824,7 @@ async fn install_release_linux(
|
||||
temp_dir: &InstallerDir,
|
||||
downloaded_tar_gz: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
) -> Result<Option<PathBuf>> {
|
||||
let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name())?;
|
||||
let home_dir = PathBuf::from(env::var("HOME").context("no HOME env var set")?);
|
||||
let running_app_path = cx.update(|cx| cx.app_path())??;
|
||||
@@ -861,14 +884,14 @@ async fn install_release_linux(
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
Ok(to.join(expected_suffix))
|
||||
Ok(Some(to.join(expected_suffix)))
|
||||
}
|
||||
|
||||
async fn install_release_macos(
|
||||
temp_dir: &InstallerDir,
|
||||
downloaded_dmg: PathBuf,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<PathBuf> {
|
||||
) -> Result<Option<PathBuf>> {
|
||||
let running_app_path = cx.update(|cx| cx.app_path())??;
|
||||
let running_app_filename = running_app_path
|
||||
.file_name()
|
||||
@@ -910,10 +933,10 @@ async fn install_release_macos(
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
Ok(running_app_path)
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBuf> {
|
||||
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<Option<PathBuf>> {
|
||||
let output = Command::new(downloaded_installer)
|
||||
.arg("/verysilent")
|
||||
.arg("/update=true")
|
||||
@@ -926,29 +949,36 @@ async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBu
|
||||
"failed to start installer: {:?}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
Ok(std::env::current_exe()?)
|
||||
// We return the path to the update helper program, because it will
|
||||
// perform the final steps of the update process, copying the new binary,
|
||||
// deleting the old one, and launching the new binary.
|
||||
let helper_path = std::env::current_exe()?
|
||||
.parent()
|
||||
.context("No parent dir for Zed.exe")?
|
||||
.join("tools\\auto_update_helper.exe");
|
||||
Ok(Some(helper_path))
|
||||
}
|
||||
|
||||
pub fn check_pending_installation() -> bool {
|
||||
pub fn finalize_auto_update_on_quit() {
|
||||
let Some(installer_path) = std::env::current_exe()
|
||||
.ok()
|
||||
.and_then(|p| p.parent().map(|p| p.join("updates")))
|
||||
else {
|
||||
return false;
|
||||
return;
|
||||
};
|
||||
|
||||
// The installer will create a flag file after it finishes updating
|
||||
let flag_file = installer_path.join("versions.txt");
|
||||
if flag_file.exists() {
|
||||
if let Some(helper) = installer_path
|
||||
if flag_file.exists()
|
||||
&& let Some(helper) = installer_path
|
||||
.parent()
|
||||
.map(|p| p.join("tools\\auto_update_helper.exe"))
|
||||
{
|
||||
let _ = std::process::Command::new(helper).spawn();
|
||||
return true;
|
||||
}
|
||||
{
|
||||
let mut command = std::process::Command::new(helper);
|
||||
command.arg("--launch");
|
||||
command.arg("false");
|
||||
let _ = command.spawn();
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -1002,7 +1032,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
|
||||
};
|
||||
let fetched_version = SemanticVersion::new(1, 0, 1);
|
||||
@@ -1024,7 +1053,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
|
||||
};
|
||||
let fetched_version = SemanticVersion::new(1, 0, 2);
|
||||
@@ -1090,7 +1118,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "b".to_string();
|
||||
@@ -1112,7 +1139,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(Some("a".to_string()));
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "c".to_string();
|
||||
@@ -1160,7 +1186,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(None);
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "b".to_string();
|
||||
@@ -1183,7 +1208,6 @@ mod tests {
|
||||
let app_commit_sha = Ok(None);
|
||||
let installed_version = SemanticVersion::new(1, 0, 0);
|
||||
let status = AutoUpdateStatus::Updated {
|
||||
binary_path: PathBuf::new(),
|
||||
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
|
||||
};
|
||||
let fetched_sha = "c".to_string();
|
||||
|
||||
@@ -37,6 +37,11 @@ mod windows_impl {
|
||||
pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1;
|
||||
pub(crate) const WM_TERMINATE: u32 = WM_USER + 2;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Args {
|
||||
launch: Option<bool>,
|
||||
}
|
||||
|
||||
pub(crate) fn run() -> Result<()> {
|
||||
let helper_dir = std::env::current_exe()?
|
||||
.parent()
|
||||
@@ -51,8 +56,9 @@ mod windows_impl {
|
||||
log::info!("======= Starting Zed update =======");
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let hwnd = create_dialog_window(rx)?.0 as isize;
|
||||
let args = parse_args();
|
||||
std::thread::spawn(move || {
|
||||
let result = perform_update(app_dir.as_path(), Some(hwnd));
|
||||
let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch.unwrap_or(true));
|
||||
tx.send(result).ok();
|
||||
unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok();
|
||||
});
|
||||
@@ -77,6 +83,41 @@ mod windows_impl {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_args() -> Args {
|
||||
let mut result = Args { launch: None };
|
||||
if let Some(candidate) = std::env::args().nth(1) {
|
||||
parse_single_arg(&candidate, &mut result);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn parse_single_arg(arg: &str, result: &mut Args) {
|
||||
let Some((key, value)) = arg.strip_prefix("--").and_then(|arg| arg.split_once('=')) else {
|
||||
log::error!(
|
||||
"Invalid argument format: '{}'. Expected format: --key=value",
|
||||
arg
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
match key {
|
||||
"launch" => parse_launch_arg(value, &mut result.launch),
|
||||
_ => log::error!("Unknown argument: --{}", key),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_launch_arg(value: &str, arg: &mut Option<bool>) {
|
||||
match value {
|
||||
"true" => *arg = Some(true),
|
||||
"false" => *arg = Some(false),
|
||||
_ => log::error!(
|
||||
"Invalid value for --launch: '{}'. Expected 'true' or 'false'",
|
||||
value
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn show_error(mut content: String) {
|
||||
if content.len() > 600 {
|
||||
content.truncate(600);
|
||||
@@ -91,4 +132,47 @@ mod windows_impl {
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::windows_impl::{Args, parse_launch_arg, parse_single_arg};
|
||||
|
||||
#[test]
|
||||
fn test_parse_launch_arg() {
|
||||
let mut arg = None;
|
||||
parse_launch_arg("true", &mut arg);
|
||||
assert_eq!(arg, Some(true));
|
||||
|
||||
let mut arg = None;
|
||||
parse_launch_arg("false", &mut arg);
|
||||
assert_eq!(arg, Some(false));
|
||||
|
||||
let mut arg = None;
|
||||
parse_launch_arg("invalid", &mut arg);
|
||||
assert_eq!(arg, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_single_arg() {
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=true", &mut args);
|
||||
assert_eq!(args.launch, Some(true));
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=false", &mut args);
|
||||
assert_eq!(args.launch, Some(false));
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch=invalid", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--launch", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
|
||||
let mut args = Args { launch: None };
|
||||
parse_single_arg("--unknown", &mut args);
|
||||
assert_eq!(args.launch, None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ pub(crate) fn create_dialog_window(receiver: Receiver<Result<()>>) -> Result<HWN
|
||||
let hwnd = CreateWindowExW(
|
||||
WS_EX_TOPMOST,
|
||||
class_name,
|
||||
windows::core::w!("Zed Editor"),
|
||||
windows::core::w!("Zed"),
|
||||
WS_VISIBLE | WS_POPUP | WS_CAPTION,
|
||||
rect.right / 2 - width / 2,
|
||||
rect.bottom / 2 - height / 2,
|
||||
@@ -171,7 +171,7 @@ unsafe extern "system" fn wnd_proc(
|
||||
&HSTRING::from(font_name),
|
||||
);
|
||||
let temp = SelectObject(hdc, font.into());
|
||||
let string = HSTRING::from("Zed Editor is updating...");
|
||||
let string = HSTRING::from("Updating Zed...");
|
||||
return_if_failed!(TextOutW(hdc, 20, 15, &string).ok());
|
||||
return_if_failed!(DeleteObject(temp).ok());
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ pub(crate) const JOBS: [Job; 2] = [
|
||||
},
|
||||
];
|
||||
|
||||
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()> {
|
||||
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>, launch: bool) -> Result<()> {
|
||||
let hwnd = hwnd.map(|ptr| HWND(ptr as _));
|
||||
|
||||
for job in JOBS.iter() {
|
||||
@@ -145,9 +145,11 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()>
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
|
||||
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
|
||||
.spawn();
|
||||
if launch {
|
||||
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
|
||||
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
|
||||
.spawn();
|
||||
}
|
||||
log::info!("Update completed successfully");
|
||||
Ok(())
|
||||
}
|
||||
@@ -159,11 +161,11 @@ mod test {
|
||||
#[test]
|
||||
fn test_perform_update() {
|
||||
let app_dir = std::path::Path::new("C:/");
|
||||
assert!(perform_update(app_dir, None).is_ok());
|
||||
assert!(perform_update(app_dir, None, false).is_ok());
|
||||
|
||||
// Simulate a timeout
|
||||
unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") };
|
||||
let ret = perform_update(app_dir, None);
|
||||
let ret = perform_update(app_dir, None, false);
|
||||
assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -957,17 +957,14 @@ mod mac_os {
|
||||
) -> Result<()> {
|
||||
use anyhow::bail;
|
||||
|
||||
let app_id_prompt = format!("id of app \"{}\"", channel.display_name());
|
||||
let app_id_output = Command::new("osascript")
|
||||
let app_path_prompt = format!(
|
||||
"POSIX path of (path to application \"{}\")",
|
||||
channel.display_name()
|
||||
);
|
||||
let app_path_output = Command::new("osascript")
|
||||
.arg("-e")
|
||||
.arg(&app_id_prompt)
|
||||
.arg(&app_path_prompt)
|
||||
.output()?;
|
||||
if !app_id_output.status.success() {
|
||||
bail!("Could not determine app id for {}", channel.display_name());
|
||||
}
|
||||
let app_name = String::from_utf8(app_id_output.stdout)?.trim().to_owned();
|
||||
let app_path_prompt = format!("kMDItemCFBundleIdentifier == '{app_name}'");
|
||||
let app_path_output = Command::new("mdfind").arg(app_path_prompt).output()?;
|
||||
if !app_path_output.status.success() {
|
||||
bail!(
|
||||
"Could not determine app path for {}",
|
||||
|
||||
@@ -340,22 +340,35 @@ impl Telemetry {
|
||||
}
|
||||
|
||||
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str, is_via_ssh: bool) {
|
||||
static LAST_EVENT_TIME: Mutex<Option<Instant>> = Mutex::new(None);
|
||||
|
||||
let mut state = self.state.lock();
|
||||
let period_data = state.event_coalescer.log_event(environment);
|
||||
drop(state);
|
||||
|
||||
if let Some((start, end, environment)) = period_data {
|
||||
let duration = end
|
||||
.saturating_duration_since(start)
|
||||
.min(Duration::from_secs(60 * 60 * 24))
|
||||
.as_millis() as i64;
|
||||
if let Some(mut last_event) = LAST_EVENT_TIME.try_lock() {
|
||||
let current_time = std::time::Instant::now();
|
||||
let last_time = last_event.get_or_insert(current_time);
|
||||
|
||||
telemetry::event!(
|
||||
"Editor Edited",
|
||||
duration = duration,
|
||||
environment = environment,
|
||||
is_via_ssh = is_via_ssh
|
||||
);
|
||||
if current_time.duration_since(*last_time) > Duration::from_secs(60 * 10) {
|
||||
*last_time = current_time;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some((start, end, environment)) = period_data {
|
||||
let duration = end
|
||||
.saturating_duration_since(start)
|
||||
.min(Duration::from_secs(60 * 60 * 24))
|
||||
.as_millis() as i64;
|
||||
|
||||
telemetry::event!(
|
||||
"Editor Edited",
|
||||
duration = duration,
|
||||
environment = environment,
|
||||
is_via_ssh = is_via_ssh
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ use language::{
|
||||
point_from_lsp, point_to_lsp,
|
||||
};
|
||||
use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
|
||||
use node_runtime::NodeRuntime;
|
||||
use node_runtime::{NodeRuntime, VersionCheck};
|
||||
use parking_lot::Mutex;
|
||||
use project::DisableAiSettings;
|
||||
use request::StatusNotification;
|
||||
@@ -1169,9 +1169,8 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
|
||||
const SERVER_PATH: &str =
|
||||
"node_modules/@github/copilot-language-server/dist/language-server.js";
|
||||
|
||||
let latest_version = node_runtime
|
||||
.npm_package_latest_version(PACKAGE_NAME)
|
||||
.await?;
|
||||
// pinning it: https://github.com/zed-industries/zed/issues/36093
|
||||
const PINNED_VERSION: &str = "1.354";
|
||||
let server_path = paths::copilot_dir().join(SERVER_PATH);
|
||||
|
||||
fs.create_dir(paths::copilot_dir()).await?;
|
||||
@@ -1181,12 +1180,13 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
|
||||
PACKAGE_NAME,
|
||||
&server_path,
|
||||
paths::copilot_dir(),
|
||||
&latest_version,
|
||||
&PINNED_VERSION,
|
||||
VersionCheck::VersionMismatch,
|
||||
)
|
||||
.await;
|
||||
if should_install {
|
||||
node_runtime
|
||||
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &latest_version)])
|
||||
.npm_install_packages(paths::copilot_dir(), &[(PACKAGE_NAME, &PINNED_VERSION)])
|
||||
.await?;
|
||||
}
|
||||
|
||||
|
||||
@@ -250,6 +250,24 @@ pub type RenderDiffHunkControlsFn = Arc<
|
||||
) -> AnyElement,
|
||||
>;
|
||||
|
||||
enum ReportEditorEvent {
|
||||
Saved { auto_saved: bool },
|
||||
EditorOpened,
|
||||
ZetaTosClicked,
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl ReportEditorEvent {
|
||||
pub fn event_type(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Saved { .. } => "Editor Saved",
|
||||
Self::EditorOpened => "Editor Opened",
|
||||
Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked",
|
||||
Self::Closed => "Editor Closed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct InlineValueCache {
|
||||
enabled: bool,
|
||||
inlays: Vec<InlayId>,
|
||||
@@ -2325,7 +2343,7 @@ impl Editor {
|
||||
}
|
||||
|
||||
if editor.mode.is_full() {
|
||||
editor.report_editor_event("Editor Opened", None, cx);
|
||||
editor.report_editor_event(ReportEditorEvent::EditorOpened, None, cx);
|
||||
}
|
||||
|
||||
editor
|
||||
@@ -9124,7 +9142,7 @@ impl Editor {
|
||||
.on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default())
|
||||
.on_click(cx.listener(|this, _event, window, cx| {
|
||||
cx.stop_propagation();
|
||||
this.report_editor_event("Edit Prediction Provider ToS Clicked", None, cx);
|
||||
this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx);
|
||||
window.dispatch_action(
|
||||
zed_actions::OpenZedPredictOnboarding.boxed_clone(),
|
||||
cx,
|
||||
@@ -20547,7 +20565,7 @@ impl Editor {
|
||||
|
||||
fn report_editor_event(
|
||||
&self,
|
||||
event_type: &'static str,
|
||||
reported_event: ReportEditorEvent,
|
||||
file_extension: Option<String>,
|
||||
cx: &App,
|
||||
) {
|
||||
@@ -20581,15 +20599,30 @@ impl Editor {
|
||||
.show_edit_predictions;
|
||||
|
||||
let project = project.read(cx);
|
||||
telemetry::event!(
|
||||
event_type,
|
||||
file_extension,
|
||||
vim_mode,
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
edit_predictions_provider,
|
||||
is_via_ssh = project.is_via_ssh(),
|
||||
);
|
||||
let event_type = reported_event.event_type();
|
||||
|
||||
if let ReportEditorEvent::Saved { auto_saved } = reported_event {
|
||||
telemetry::event!(
|
||||
event_type,
|
||||
type = if auto_saved {"autosave"} else {"manual"},
|
||||
file_extension,
|
||||
vim_mode,
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
edit_predictions_provider,
|
||||
is_via_ssh = project.is_via_ssh(),
|
||||
);
|
||||
} else {
|
||||
telemetry::event!(
|
||||
event_type,
|
||||
file_extension,
|
||||
vim_mode,
|
||||
copilot_enabled,
|
||||
copilot_enabled_for_language,
|
||||
edit_predictions_provider,
|
||||
is_via_ssh = project.is_via_ssh(),
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
/// Copy the highlighted chunks to the clipboard as JSON. The format is an array of lines,
|
||||
|
||||
@@ -22456,7 +22456,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
cx.update(|_, cx| {
|
||||
workspace::reload(&workspace::Reload::default(), cx);
|
||||
workspace::reload(cx);
|
||||
});
|
||||
assert_language_servers_count(
|
||||
1,
|
||||
|
||||
@@ -3011,7 +3011,7 @@ impl EditorElement {
|
||||
.icon_color(Color::Custom(cx.theme().colors().editor_line_number))
|
||||
.selected_icon_color(Color::Custom(cx.theme().colors().editor_foreground))
|
||||
.icon_size(IconSize::Custom(rems(editor_font_size / window.rem_size())))
|
||||
.width(width.into())
|
||||
.width(width)
|
||||
.on_click(move |_, window, cx| {
|
||||
editor.update(cx, |editor, cx| {
|
||||
editor.expand_excerpt(excerpt_id, direction, window, cx);
|
||||
@@ -3627,7 +3627,7 @@ impl EditorElement {
|
||||
ButtonLike::new("toggle-buffer-fold")
|
||||
.style(ui::ButtonStyle::Transparent)
|
||||
.height(px(28.).into())
|
||||
.width(px(28.).into())
|
||||
.width(px(28.))
|
||||
.children(toggle_chevron_icon)
|
||||
.tooltip({
|
||||
let focus_handle = focus_handle.clone();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
Anchor, Autoscroll, Editor, EditorEvent, EditorSettings, ExcerptId, ExcerptRange, FormatTarget,
|
||||
MultiBuffer, MultiBufferSnapshot, NavigationData, SearchWithinRange, SelectionEffects,
|
||||
ToPoint as _,
|
||||
MultiBuffer, MultiBufferSnapshot, NavigationData, ReportEditorEvent, SearchWithinRange,
|
||||
SelectionEffects, ToPoint as _,
|
||||
display_map::HighlightKey,
|
||||
editor_settings::SeedQuerySetting,
|
||||
persistence::{DB, SerializedEditor},
|
||||
@@ -776,6 +776,10 @@ impl Item for Editor {
|
||||
}
|
||||
}
|
||||
|
||||
fn on_removed(&self, cx: &App) {
|
||||
self.report_editor_event(ReportEditorEvent::Closed, None, cx);
|
||||
}
|
||||
|
||||
fn deactivated(&mut self, _: &mut Window, cx: &mut Context<Self>) {
|
||||
let selection = self.selections.newest_anchor();
|
||||
self.push_to_nav_history(selection.head(), None, true, false, cx);
|
||||
@@ -815,9 +819,9 @@ impl Item for Editor {
|
||||
) -> Task<Result<()>> {
|
||||
// Add meta data tracking # of auto saves
|
||||
if options.autosave {
|
||||
self.report_editor_event("Editor Autosaved", None, cx);
|
||||
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: true }, None, cx);
|
||||
} else {
|
||||
self.report_editor_event("Editor Saved", None, cx);
|
||||
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: false }, None, cx);
|
||||
}
|
||||
|
||||
let buffers = self.buffer().clone().read(cx).all_buffers();
|
||||
@@ -896,7 +900,11 @@ impl Item for Editor {
|
||||
.path
|
||||
.extension()
|
||||
.map(|a| a.to_string_lossy().to_string());
|
||||
self.report_editor_event("Editor Saved", file_extension, cx);
|
||||
self.report_editor_event(
|
||||
ReportEditorEvent::Saved { auto_saved: false },
|
||||
file_extension,
|
||||
cx,
|
||||
);
|
||||
|
||||
project.update(cx, |project, cx| project.save_buffer_as(buffer, path, cx))
|
||||
}
|
||||
@@ -997,12 +1005,16 @@ impl Item for Editor {
|
||||
) {
|
||||
self.workspace = Some((workspace.weak_handle(), workspace.database_id()));
|
||||
if let Some(workspace) = &workspace.weak_handle().upgrade() {
|
||||
cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| {
|
||||
if matches!(event, workspace::Event::ModalOpened) {
|
||||
editor.mouse_context_menu.take();
|
||||
editor.inline_blame_popover.take();
|
||||
}
|
||||
})
|
||||
cx.subscribe(
|
||||
&workspace,
|
||||
|editor, _, event: &workspace::Event, _cx| match event {
|
||||
workspace::Event::ModalOpened => {
|
||||
editor.mouse_context_menu.take();
|
||||
editor.inline_blame_popover.take();
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
)
|
||||
.detach();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1118,15 +1118,17 @@ impl ExtensionStore {
|
||||
extensions_to_unload.len() - reload_count
|
||||
);
|
||||
|
||||
for extension_id in &extensions_to_load {
|
||||
if let Some(extension) = new_index.extensions.get(extension_id) {
|
||||
telemetry::event!(
|
||||
"Extension Loaded",
|
||||
extension_id,
|
||||
version = extension.manifest.version
|
||||
);
|
||||
}
|
||||
}
|
||||
let extension_ids = extensions_to_load
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
Some((
|
||||
id.clone(),
|
||||
new_index.extensions.get(id)?.manifest.version.clone(),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
telemetry::event!("Extensions Loaded", id_and_versions = extension_ids);
|
||||
|
||||
let themes_to_remove = old_index
|
||||
.themes
|
||||
|
||||
@@ -33,13 +33,23 @@ impl FileIcons {
|
||||
// TODO: Associate a type with the languages and have the file's language
|
||||
// override these associations
|
||||
|
||||
// check if file name is in suffixes
|
||||
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
|
||||
if let Some(typ) = path.file_name().and_then(|typ| typ.to_str()) {
|
||||
if let Some(mut typ) = path.file_name().and_then(|typ| typ.to_str()) {
|
||||
// check if file name is in suffixes
|
||||
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
|
||||
let maybe_path = get_icon_from_suffix(typ);
|
||||
if maybe_path.is_some() {
|
||||
return maybe_path;
|
||||
}
|
||||
|
||||
// check if suffix based on first dot is in suffixes
|
||||
// e.g. consider `module.js` as suffix to angular's module file named `auth.module.js`
|
||||
while let Some((_, suffix)) = typ.split_once('.') {
|
||||
let maybe_path = get_icon_from_suffix(suffix);
|
||||
if maybe_path.is_some() {
|
||||
return maybe_path;
|
||||
}
|
||||
typ = suffix;
|
||||
}
|
||||
}
|
||||
|
||||
// primary case: check if the files extension or the hidden file name
|
||||
|
||||
@@ -51,6 +51,7 @@ ashpd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
git = { workspace = true, features = ["test-support"] }
|
||||
|
||||
[features]
|
||||
test-support = ["gpui/test-support", "git/test-support"]
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::{FakeFs, Fs};
|
||||
use crate::{FakeFs, FakeFsEntry, Fs};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::{HashMap, HashSet};
|
||||
use futures::future::{self, BoxFuture, join_all};
|
||||
use git::{
|
||||
Oid,
|
||||
blame::Blame,
|
||||
repository::{
|
||||
AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
|
||||
@@ -10,8 +11,9 @@ use git::{
|
||||
},
|
||||
status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus},
|
||||
};
|
||||
use gpui::{AsyncApp, BackgroundExecutor, SharedString};
|
||||
use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
|
||||
use ignore::gitignore::GitignoreBuilder;
|
||||
use parking_lot::Mutex;
|
||||
use rope::Rope;
|
||||
use smol::future::FutureExt as _;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
@@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc};
|
||||
#[derive(Clone)]
|
||||
pub struct FakeGitRepository {
|
||||
pub(crate) fs: Arc<FakeFs>,
|
||||
pub(crate) checkpoints: Arc<Mutex<HashMap<Oid, FakeFsEntry>>>,
|
||||
pub(crate) executor: BackgroundExecutor,
|
||||
pub(crate) dot_git_path: PathBuf,
|
||||
pub(crate) repository_dir_path: PathBuf,
|
||||
@@ -183,7 +186,7 @@ impl GitRepository for FakeGitRepository {
|
||||
async move { None }.boxed()
|
||||
}
|
||||
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>> {
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
|
||||
let workdir_path = self.dot_git_path.parent().unwrap();
|
||||
|
||||
// Load gitignores
|
||||
@@ -311,7 +314,10 @@ impl GitRepository for FakeGitRepository {
|
||||
entries: entries.into(),
|
||||
})
|
||||
});
|
||||
async move { result? }.boxed()
|
||||
Task::ready(match result {
|
||||
Ok(result) => result,
|
||||
Err(e) => Err(e),
|
||||
})
|
||||
}
|
||||
|
||||
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {
|
||||
@@ -466,22 +472,57 @@ impl GitRepository for FakeGitRepository {
|
||||
}
|
||||
|
||||
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
|
||||
unimplemented!()
|
||||
let executor = self.executor.clone();
|
||||
let fs = self.fs.clone();
|
||||
let checkpoints = self.checkpoints.clone();
|
||||
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
|
||||
async move {
|
||||
executor.simulate_random_delay().await;
|
||||
let oid = Oid::random(&mut executor.rng());
|
||||
let entry = fs.entry(&repository_dir_path)?;
|
||||
checkpoints.lock().insert(oid, entry);
|
||||
Ok(GitRepositoryCheckpoint { commit_sha: oid })
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn restore_checkpoint(
|
||||
&self,
|
||||
_checkpoint: GitRepositoryCheckpoint,
|
||||
) -> BoxFuture<'_, Result<()>> {
|
||||
unimplemented!()
|
||||
fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
|
||||
let executor = self.executor.clone();
|
||||
let fs = self.fs.clone();
|
||||
let checkpoints = self.checkpoints.clone();
|
||||
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
|
||||
async move {
|
||||
executor.simulate_random_delay().await;
|
||||
let checkpoints = checkpoints.lock();
|
||||
let entry = checkpoints
|
||||
.get(&checkpoint.commit_sha)
|
||||
.context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?;
|
||||
fs.insert_entry(&repository_dir_path, entry.clone())?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn compare_checkpoints(
|
||||
&self,
|
||||
_left: GitRepositoryCheckpoint,
|
||||
_right: GitRepositoryCheckpoint,
|
||||
left: GitRepositoryCheckpoint,
|
||||
right: GitRepositoryCheckpoint,
|
||||
) -> BoxFuture<'_, Result<bool>> {
|
||||
unimplemented!()
|
||||
let executor = self.executor.clone();
|
||||
let checkpoints = self.checkpoints.clone();
|
||||
async move {
|
||||
executor.simulate_random_delay().await;
|
||||
let checkpoints = checkpoints.lock();
|
||||
let left = checkpoints
|
||||
.get(&left.commit_sha)
|
||||
.context(format!("invalid left checkpoint: {}", left.commit_sha))?;
|
||||
let right = checkpoints
|
||||
.get(&right.commit_sha)
|
||||
.context(format!("invalid right checkpoint: {}", right.commit_sha))?;
|
||||
|
||||
Ok(left == right)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn diff_checkpoints(
|
||||
@@ -496,3 +537,63 @@ impl GitRepository for FakeGitRepository {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{FakeFs, Fs};
|
||||
use gpui::BackgroundExecutor;
|
||||
use serde_json::json;
|
||||
use std::path::Path;
|
||||
use util::path;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_checkpoints(executor: BackgroundExecutor) {
|
||||
let fs = FakeFs::new(executor);
|
||||
fs.insert_tree(
|
||||
path!("/"),
|
||||
json!({
|
||||
"bar": {
|
||||
"baz": "qux"
|
||||
},
|
||||
"foo": {
|
||||
".git": {},
|
||||
"a": "lorem",
|
||||
"b": "ipsum",
|
||||
},
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
fs.with_git_state(Path::new("/foo/.git"), true, |_git| {})
|
||||
.unwrap();
|
||||
let repository = fs.open_repo(Path::new("/foo/.git")).unwrap();
|
||||
|
||||
let checkpoint_1 = repository.checkpoint().await.unwrap();
|
||||
fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap();
|
||||
fs.write(Path::new("/foo/c"), b"dolor").await.unwrap();
|
||||
let checkpoint_2 = repository.checkpoint().await.unwrap();
|
||||
let checkpoint_3 = repository.checkpoint().await.unwrap();
|
||||
|
||||
assert!(
|
||||
repository
|
||||
.compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
|
||||
.await
|
||||
.unwrap()
|
||||
);
|
||||
assert!(
|
||||
!repository
|
||||
.compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
|
||||
.await
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
repository.restore_checkpoint(checkpoint_1).await.unwrap();
|
||||
assert_eq!(
|
||||
fs.files_with_contents(Path::new("")),
|
||||
[
|
||||
(Path::new("/bar/baz").into(), b"qux".into()),
|
||||
(Path::new("/foo/a").into(), b"lorem".into()),
|
||||
(Path::new("/foo/b").into(), b"ipsum".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -924,7 +924,7 @@ pub struct FakeFs {
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
struct FakeFsState {
|
||||
root: Arc<Mutex<FakeFsEntry>>,
|
||||
root: FakeFsEntry,
|
||||
next_inode: u64,
|
||||
next_mtime: SystemTime,
|
||||
git_event_tx: smol::channel::Sender<PathBuf>,
|
||||
@@ -939,7 +939,7 @@ struct FakeFsState {
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
enum FakeFsEntry {
|
||||
File {
|
||||
inode: u64,
|
||||
@@ -953,7 +953,7 @@ enum FakeFsEntry {
|
||||
inode: u64,
|
||||
mtime: MTime,
|
||||
len: u64,
|
||||
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
|
||||
entries: BTreeMap<String, FakeFsEntry>,
|
||||
git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
|
||||
},
|
||||
Symlink {
|
||||
@@ -961,6 +961,67 @@ enum FakeFsEntry {
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl PartialEq for FakeFsEntry {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(
|
||||
Self::File {
|
||||
inode: l_inode,
|
||||
mtime: l_mtime,
|
||||
len: l_len,
|
||||
content: l_content,
|
||||
git_dir_path: l_git_dir_path,
|
||||
},
|
||||
Self::File {
|
||||
inode: r_inode,
|
||||
mtime: r_mtime,
|
||||
len: r_len,
|
||||
content: r_content,
|
||||
git_dir_path: r_git_dir_path,
|
||||
},
|
||||
) => {
|
||||
l_inode == r_inode
|
||||
&& l_mtime == r_mtime
|
||||
&& l_len == r_len
|
||||
&& l_content == r_content
|
||||
&& l_git_dir_path == r_git_dir_path
|
||||
}
|
||||
(
|
||||
Self::Dir {
|
||||
inode: l_inode,
|
||||
mtime: l_mtime,
|
||||
len: l_len,
|
||||
entries: l_entries,
|
||||
git_repo_state: l_git_repo_state,
|
||||
},
|
||||
Self::Dir {
|
||||
inode: r_inode,
|
||||
mtime: r_mtime,
|
||||
len: r_len,
|
||||
entries: r_entries,
|
||||
git_repo_state: r_git_repo_state,
|
||||
},
|
||||
) => {
|
||||
let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) {
|
||||
(Some(l), Some(r)) => Arc::ptr_eq(l, r),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
};
|
||||
l_inode == r_inode
|
||||
&& l_mtime == r_mtime
|
||||
&& l_len == r_len
|
||||
&& l_entries == r_entries
|
||||
&& same_repo_state
|
||||
}
|
||||
(Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => {
|
||||
l_target == r_target
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
impl FakeFsState {
|
||||
fn get_and_increment_mtime(&mut self) -> MTime {
|
||||
@@ -975,25 +1036,9 @@ impl FakeFsState {
|
||||
inode
|
||||
}
|
||||
|
||||
fn read_path(&self, target: &Path) -> Result<Arc<Mutex<FakeFsEntry>>> {
|
||||
Ok(self
|
||||
.try_read_path(target, true)
|
||||
.ok_or_else(|| {
|
||||
anyhow!(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
format!("not found: {target:?}")
|
||||
))
|
||||
})?
|
||||
.0)
|
||||
}
|
||||
|
||||
fn try_read_path(
|
||||
&self,
|
||||
target: &Path,
|
||||
follow_symlink: bool,
|
||||
) -> Option<(Arc<Mutex<FakeFsEntry>>, PathBuf)> {
|
||||
let mut path = target.to_path_buf();
|
||||
fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option<PathBuf> {
|
||||
let mut canonical_path = PathBuf::new();
|
||||
let mut path = target.to_path_buf();
|
||||
let mut entry_stack = Vec::new();
|
||||
'outer: loop {
|
||||
let mut path_components = path.components().peekable();
|
||||
@@ -1003,7 +1048,7 @@ impl FakeFsState {
|
||||
Component::Prefix(prefix_component) => prefix = Some(prefix_component),
|
||||
Component::RootDir => {
|
||||
entry_stack.clear();
|
||||
entry_stack.push(self.root.clone());
|
||||
entry_stack.push(&self.root);
|
||||
canonical_path.clear();
|
||||
match prefix {
|
||||
Some(prefix_component) => {
|
||||
@@ -1020,20 +1065,18 @@ impl FakeFsState {
|
||||
canonical_path.pop();
|
||||
}
|
||||
Component::Normal(name) => {
|
||||
let current_entry = entry_stack.last().cloned()?;
|
||||
let current_entry = current_entry.lock();
|
||||
if let FakeFsEntry::Dir { entries, .. } = &*current_entry {
|
||||
let entry = entries.get(name.to_str().unwrap()).cloned()?;
|
||||
let current_entry = *entry_stack.last()?;
|
||||
if let FakeFsEntry::Dir { entries, .. } = current_entry {
|
||||
let entry = entries.get(name.to_str().unwrap())?;
|
||||
if path_components.peek().is_some() || follow_symlink {
|
||||
let entry = entry.lock();
|
||||
if let FakeFsEntry::Symlink { target, .. } = &*entry {
|
||||
if let FakeFsEntry::Symlink { target, .. } = entry {
|
||||
let mut target = target.clone();
|
||||
target.extend(path_components);
|
||||
path = target;
|
||||
continue 'outer;
|
||||
}
|
||||
}
|
||||
entry_stack.push(entry.clone());
|
||||
entry_stack.push(entry);
|
||||
canonical_path = canonical_path.join(name);
|
||||
} else {
|
||||
return None;
|
||||
@@ -1043,19 +1086,72 @@ impl FakeFsState {
|
||||
}
|
||||
break;
|
||||
}
|
||||
Some((entry_stack.pop()?, canonical_path))
|
||||
|
||||
if entry_stack.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(canonical_path)
|
||||
}
|
||||
}
|
||||
|
||||
fn write_path<Fn, T>(&self, path: &Path, callback: Fn) -> Result<T>
|
||||
fn try_entry(
|
||||
&mut self,
|
||||
target: &Path,
|
||||
follow_symlink: bool,
|
||||
) -> Option<(&mut FakeFsEntry, PathBuf)> {
|
||||
let canonical_path = self.canonicalize(target, follow_symlink)?;
|
||||
|
||||
let mut components = canonical_path.components();
|
||||
let Some(Component::RootDir) = components.next() else {
|
||||
panic!(
|
||||
"the path {:?} was not canonicalized properly {:?}",
|
||||
target, canonical_path
|
||||
)
|
||||
};
|
||||
|
||||
let mut entry = &mut self.root;
|
||||
for component in components {
|
||||
match component {
|
||||
Component::Normal(name) => {
|
||||
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||
entry = entries.get_mut(name.to_str().unwrap())?;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!(
|
||||
"the path {:?} was not canonicalized properly {:?}",
|
||||
target, canonical_path
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some((entry, canonical_path))
|
||||
}
|
||||
|
||||
fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> {
|
||||
Ok(self
|
||||
.try_entry(target, true)
|
||||
.ok_or_else(|| {
|
||||
anyhow!(io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
format!("not found: {target:?}")
|
||||
))
|
||||
})?
|
||||
.0)
|
||||
}
|
||||
|
||||
fn write_path<Fn, T>(&mut self, path: &Path, callback: Fn) -> Result<T>
|
||||
where
|
||||
Fn: FnOnce(btree_map::Entry<String, Arc<Mutex<FakeFsEntry>>>) -> Result<T>,
|
||||
Fn: FnOnce(btree_map::Entry<String, FakeFsEntry>) -> Result<T>,
|
||||
{
|
||||
let path = normalize_path(path);
|
||||
let filename = path.file_name().context("cannot overwrite the root")?;
|
||||
let parent_path = path.parent().unwrap();
|
||||
|
||||
let parent = self.read_path(parent_path)?;
|
||||
let mut parent = parent.lock();
|
||||
let parent = self.entry(parent_path)?;
|
||||
let new_entry = parent
|
||||
.dir_entries(parent_path)?
|
||||
.entry(filename.to_str().unwrap().into());
|
||||
@@ -1105,13 +1201,13 @@ impl FakeFs {
|
||||
this: this.clone(),
|
||||
executor: executor.clone(),
|
||||
state: Arc::new(Mutex::new(FakeFsState {
|
||||
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
|
||||
root: FakeFsEntry::Dir {
|
||||
inode: 0,
|
||||
mtime: MTime(UNIX_EPOCH),
|
||||
len: 0,
|
||||
entries: Default::default(),
|
||||
git_repo_state: None,
|
||||
})),
|
||||
},
|
||||
git_event_tx: tx,
|
||||
next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
|
||||
next_inode: 1,
|
||||
@@ -1161,15 +1257,15 @@ impl FakeFs {
|
||||
.write_path(path, move |entry| {
|
||||
match entry {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
e.insert(FakeFsEntry::File {
|
||||
inode: new_inode,
|
||||
mtime: new_mtime,
|
||||
content: Vec::new(),
|
||||
len: 0,
|
||||
git_dir_path: None,
|
||||
})));
|
||||
});
|
||||
}
|
||||
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
|
||||
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() {
|
||||
FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
|
||||
FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
|
||||
FakeFsEntry::Symlink { .. } => {}
|
||||
@@ -1188,7 +1284,7 @@ impl FakeFs {
|
||||
pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
|
||||
let mut state = self.state.lock();
|
||||
let path = path.as_ref();
|
||||
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
|
||||
let file = FakeFsEntry::Symlink { target };
|
||||
state
|
||||
.write_path(path.as_ref(), move |e| match e {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
@@ -1221,13 +1317,13 @@ impl FakeFs {
|
||||
match entry {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
kind = Some(PathEventKind::Created);
|
||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
e.insert(FakeFsEntry::File {
|
||||
inode: new_inode,
|
||||
mtime: new_mtime,
|
||||
len: new_len,
|
||||
content: new_content,
|
||||
git_dir_path: None,
|
||||
})));
|
||||
});
|
||||
}
|
||||
btree_map::Entry::Occupied(mut e) => {
|
||||
kind = Some(PathEventKind::Changed);
|
||||
@@ -1237,7 +1333,7 @@ impl FakeFs {
|
||||
len,
|
||||
content,
|
||||
..
|
||||
} = &mut *e.get_mut().lock()
|
||||
} = e.get_mut()
|
||||
{
|
||||
*mtime = new_mtime;
|
||||
*content = new_content;
|
||||
@@ -1259,9 +1355,8 @@ impl FakeFs {
|
||||
pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
|
||||
let path = path.as_ref();
|
||||
let path = normalize_path(path);
|
||||
let state = self.state.lock();
|
||||
let entry = state.read_path(&path)?;
|
||||
let entry = entry.lock();
|
||||
let mut state = self.state.lock();
|
||||
let entry = state.entry(&path)?;
|
||||
entry.file_content(&path).cloned()
|
||||
}
|
||||
|
||||
@@ -1269,9 +1364,8 @@ impl FakeFs {
|
||||
let path = path.as_ref();
|
||||
let path = normalize_path(path);
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
let entry = state.read_path(&path)?;
|
||||
let entry = entry.lock();
|
||||
let mut state = self.state.lock();
|
||||
let entry = state.entry(&path)?;
|
||||
entry.file_content(&path).cloned()
|
||||
}
|
||||
|
||||
@@ -1292,6 +1386,25 @@ impl FakeFs {
|
||||
self.state.lock().flush_events(count);
|
||||
}
|
||||
|
||||
pub(crate) fn entry(&self, target: &Path) -> Result<FakeFsEntry> {
|
||||
self.state.lock().entry(target).cloned()
|
||||
}
|
||||
|
||||
pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> {
|
||||
let mut state = self.state.lock();
|
||||
state.write_path(target, |entry| {
|
||||
match entry {
|
||||
btree_map::Entry::Vacant(vacant_entry) => {
|
||||
vacant_entry.insert(new_entry);
|
||||
}
|
||||
btree_map::Entry::Occupied(mut occupied_entry) => {
|
||||
occupied_entry.insert(new_entry);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn insert_tree<'a>(
|
||||
&'a self,
|
||||
@@ -1361,20 +1474,19 @@ impl FakeFs {
|
||||
F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
|
||||
{
|
||||
let mut state = self.state.lock();
|
||||
let entry = state.read_path(dot_git).context("open .git")?;
|
||||
let mut entry = entry.lock();
|
||||
let git_event_tx = state.git_event_tx.clone();
|
||||
let entry = state.entry(dot_git).context("open .git")?;
|
||||
|
||||
if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry {
|
||||
if let FakeFsEntry::Dir { git_repo_state, .. } = entry {
|
||||
let repo_state = git_repo_state.get_or_insert_with(|| {
|
||||
log::debug!("insert git state for {dot_git:?}");
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(
|
||||
state.git_event_tx.clone(),
|
||||
)))
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
|
||||
});
|
||||
let mut repo_state = repo_state.lock();
|
||||
|
||||
let result = f(&mut repo_state, dot_git, dot_git);
|
||||
|
||||
drop(repo_state);
|
||||
if emit_git_event {
|
||||
state.emit_event([(dot_git, None)]);
|
||||
}
|
||||
@@ -1398,21 +1510,20 @@ impl FakeFs {
|
||||
}
|
||||
}
|
||||
.clone();
|
||||
drop(entry);
|
||||
let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
|
||||
let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else {
|
||||
anyhow::bail!("pointed-to git dir {path:?} not found")
|
||||
};
|
||||
let FakeFsEntry::Dir {
|
||||
git_repo_state,
|
||||
entries,
|
||||
..
|
||||
} = &mut *git_dir_entry.lock()
|
||||
} = git_dir_entry
|
||||
else {
|
||||
anyhow::bail!("gitfile points to a non-directory")
|
||||
};
|
||||
let common_dir = if let Some(child) = entries.get("commondir") {
|
||||
Path::new(
|
||||
std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
|
||||
std::str::from_utf8(child.file_content("commondir".as_ref())?)
|
||||
.context("commondir content")?,
|
||||
)
|
||||
.to_owned()
|
||||
@@ -1420,15 +1531,14 @@ impl FakeFs {
|
||||
canonical_path.clone()
|
||||
};
|
||||
let repo_state = git_repo_state.get_or_insert_with(|| {
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(
|
||||
state.git_event_tx.clone(),
|
||||
)))
|
||||
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
|
||||
});
|
||||
let mut repo_state = repo_state.lock();
|
||||
|
||||
let result = f(&mut repo_state, &canonical_path, &common_dir);
|
||||
|
||||
if emit_git_event {
|
||||
drop(repo_state);
|
||||
state.emit_event([(canonical_path, None)]);
|
||||
}
|
||||
|
||||
@@ -1655,14 +1765,12 @@ impl FakeFs {
|
||||
pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
}
|
||||
if include_dot_git
|
||||
@@ -1679,14 +1787,12 @@ impl FakeFs {
|
||||
pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
|
||||
if let FakeFsEntry::Dir { entries, .. } = entry {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
if include_dot_git
|
||||
|| !path
|
||||
@@ -1703,17 +1809,14 @@ impl FakeFs {
|
||||
pub fn files(&self) -> Vec<PathBuf> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
let e = entry.lock();
|
||||
match &*e {
|
||||
match entry {
|
||||
FakeFsEntry::File { .. } => result.push(path),
|
||||
FakeFsEntry::Dir { entries, .. } => {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
}
|
||||
FakeFsEntry::Symlink { .. } => {}
|
||||
@@ -1725,13 +1828,10 @@ impl FakeFs {
|
||||
pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
|
||||
let mut result = Vec::new();
|
||||
let mut queue = collections::VecDeque::new();
|
||||
queue.push_back((
|
||||
PathBuf::from(util::path!("/")),
|
||||
self.state.lock().root.clone(),
|
||||
));
|
||||
let state = &*self.state.lock();
|
||||
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
|
||||
while let Some((path, entry)) = queue.pop_front() {
|
||||
let e = entry.lock();
|
||||
match &*e {
|
||||
match entry {
|
||||
FakeFsEntry::File { content, .. } => {
|
||||
if path.starts_with(prefix) {
|
||||
result.push((path, content.clone()));
|
||||
@@ -1739,7 +1839,7 @@ impl FakeFs {
|
||||
}
|
||||
FakeFsEntry::Dir { entries, .. } => {
|
||||
for (name, entry) in entries {
|
||||
queue.push_back((path.join(name), entry.clone()));
|
||||
queue.push_back((path.join(name), entry));
|
||||
}
|
||||
}
|
||||
FakeFsEntry::Symlink { .. } => {}
|
||||
@@ -1805,10 +1905,7 @@ impl FakeFsEntry {
|
||||
}
|
||||
}
|
||||
|
||||
fn dir_entries(
|
||||
&mut self,
|
||||
path: &Path,
|
||||
) -> Result<&mut BTreeMap<String, Arc<Mutex<FakeFsEntry>>>> {
|
||||
fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap<String, FakeFsEntry>> {
|
||||
if let Self::Dir { entries, .. } = self {
|
||||
Ok(entries)
|
||||
} else {
|
||||
@@ -1855,12 +1952,12 @@ struct FakeHandle {
|
||||
impl FileHandle for FakeHandle {
|
||||
fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
|
||||
let fs = fs.as_fake();
|
||||
let state = fs.state.lock();
|
||||
let Some(target) = state.moves.get(&self.inode) else {
|
||||
let mut state = fs.state.lock();
|
||||
let Some(target) = state.moves.get(&self.inode).cloned() else {
|
||||
anyhow::bail!("fake fd not moved")
|
||||
};
|
||||
|
||||
if state.try_read_path(&target, false).is_some() {
|
||||
if state.try_entry(&target, false).is_some() {
|
||||
return Ok(target.clone());
|
||||
}
|
||||
anyhow::bail!("fake fd target not found")
|
||||
@@ -1888,13 +1985,13 @@ impl Fs for FakeFs {
|
||||
state.write_path(&cur_path, |entry| {
|
||||
entry.or_insert_with(|| {
|
||||
created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
|
||||
Arc::new(Mutex::new(FakeFsEntry::Dir {
|
||||
FakeFsEntry::Dir {
|
||||
inode,
|
||||
mtime,
|
||||
len: 0,
|
||||
entries: Default::default(),
|
||||
git_repo_state: None,
|
||||
}))
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
})?
|
||||
@@ -1909,13 +2006,13 @@ impl Fs for FakeFs {
|
||||
let mut state = self.state.lock();
|
||||
let inode = state.get_and_increment_inode();
|
||||
let mtime = state.get_and_increment_mtime();
|
||||
let file = Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
let file = FakeFsEntry::File {
|
||||
inode,
|
||||
mtime,
|
||||
len: 0,
|
||||
content: Vec::new(),
|
||||
git_dir_path: None,
|
||||
}));
|
||||
};
|
||||
let mut kind = Some(PathEventKind::Created);
|
||||
state.write_path(path, |entry| {
|
||||
match entry {
|
||||
@@ -1939,7 +2036,7 @@ impl Fs for FakeFs {
|
||||
|
||||
async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
|
||||
let mut state = self.state.lock();
|
||||
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
|
||||
let file = FakeFsEntry::Symlink { target };
|
||||
state
|
||||
.write_path(path.as_ref(), move |e| match e {
|
||||
btree_map::Entry::Vacant(e) => {
|
||||
@@ -2002,7 +2099,7 @@ impl Fs for FakeFs {
|
||||
}
|
||||
})?;
|
||||
|
||||
let inode = match *moved_entry.lock() {
|
||||
let inode = match moved_entry {
|
||||
FakeFsEntry::File { inode, .. } => inode,
|
||||
FakeFsEntry::Dir { inode, .. } => inode,
|
||||
_ => 0,
|
||||
@@ -2051,8 +2148,8 @@ impl Fs for FakeFs {
|
||||
let mut state = self.state.lock();
|
||||
let mtime = state.get_and_increment_mtime();
|
||||
let inode = state.get_and_increment_inode();
|
||||
let source_entry = state.read_path(&source)?;
|
||||
let content = source_entry.lock().file_content(&source)?.clone();
|
||||
let source_entry = state.entry(&source)?;
|
||||
let content = source_entry.file_content(&source)?.clone();
|
||||
let mut kind = Some(PathEventKind::Created);
|
||||
state.write_path(&target, |e| match e {
|
||||
btree_map::Entry::Occupied(e) => {
|
||||
@@ -2066,13 +2163,13 @@ impl Fs for FakeFs {
|
||||
}
|
||||
}
|
||||
btree_map::Entry::Vacant(e) => Ok(Some(
|
||||
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
|
||||
e.insert(FakeFsEntry::File {
|
||||
inode,
|
||||
mtime,
|
||||
len: content.len() as u64,
|
||||
content,
|
||||
git_dir_path: None,
|
||||
})))
|
||||
})
|
||||
.clone(),
|
||||
)),
|
||||
})?;
|
||||
@@ -2088,8 +2185,7 @@ impl Fs for FakeFs {
|
||||
let base_name = path.file_name().context("cannot remove the root")?;
|
||||
|
||||
let mut state = self.state.lock();
|
||||
let parent_entry = state.read_path(parent_path)?;
|
||||
let mut parent_entry = parent_entry.lock();
|
||||
let parent_entry = state.entry(parent_path)?;
|
||||
let entry = parent_entry
|
||||
.dir_entries(parent_path)?
|
||||
.entry(base_name.to_str().unwrap().into());
|
||||
@@ -2100,15 +2196,14 @@ impl Fs for FakeFs {
|
||||
anyhow::bail!("{path:?} does not exist");
|
||||
}
|
||||
}
|
||||
btree_map::Entry::Occupied(e) => {
|
||||
btree_map::Entry::Occupied(mut entry) => {
|
||||
{
|
||||
let mut entry = e.get().lock();
|
||||
let children = entry.dir_entries(&path)?;
|
||||
let children = entry.get_mut().dir_entries(&path)?;
|
||||
if !options.recursive && !children.is_empty() {
|
||||
anyhow::bail!("{path:?} is not empty");
|
||||
}
|
||||
}
|
||||
e.remove();
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
||||
@@ -2122,8 +2217,7 @@ impl Fs for FakeFs {
|
||||
let parent_path = path.parent().context("cannot remove the root")?;
|
||||
let base_name = path.file_name().unwrap();
|
||||
let mut state = self.state.lock();
|
||||
let parent_entry = state.read_path(parent_path)?;
|
||||
let mut parent_entry = parent_entry.lock();
|
||||
let parent_entry = state.entry(parent_path)?;
|
||||
let entry = parent_entry
|
||||
.dir_entries(parent_path)?
|
||||
.entry(base_name.to_str().unwrap().into());
|
||||
@@ -2133,9 +2227,9 @@ impl Fs for FakeFs {
|
||||
anyhow::bail!("{path:?} does not exist");
|
||||
}
|
||||
}
|
||||
btree_map::Entry::Occupied(e) => {
|
||||
e.get().lock().file_content(&path)?;
|
||||
e.remove();
|
||||
btree_map::Entry::Occupied(mut entry) => {
|
||||
entry.get_mut().file_content(&path)?;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
state.emit_event([(path, Some(PathEventKind::Removed))]);
|
||||
@@ -2149,12 +2243,10 @@ impl Fs for FakeFs {
|
||||
|
||||
async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
let entry = state.read_path(&path)?;
|
||||
let entry = entry.lock();
|
||||
let inode = match *entry {
|
||||
FakeFsEntry::File { inode, .. } => inode,
|
||||
FakeFsEntry::Dir { inode, .. } => inode,
|
||||
let mut state = self.state.lock();
|
||||
let inode = match state.entry(&path)? {
|
||||
FakeFsEntry::File { inode, .. } => *inode,
|
||||
FakeFsEntry::Dir { inode, .. } => *inode,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Arc::new(FakeHandle { inode }))
|
||||
@@ -2204,8 +2296,8 @@ impl Fs for FakeFs {
|
||||
let path = normalize_path(path);
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
let (_, canonical_path) = state
|
||||
.try_read_path(&path, true)
|
||||
let canonical_path = state
|
||||
.canonicalize(&path, true)
|
||||
.with_context(|| format!("path does not exist: {path:?}"))?;
|
||||
Ok(canonical_path)
|
||||
}
|
||||
@@ -2213,9 +2305,9 @@ impl Fs for FakeFs {
|
||||
async fn is_file(&self, path: &Path) -> bool {
|
||||
let path = normalize_path(path);
|
||||
self.simulate_random_delay().await;
|
||||
let state = self.state.lock();
|
||||
if let Some((entry, _)) = state.try_read_path(&path, true) {
|
||||
entry.lock().is_file()
|
||||
let mut state = self.state.lock();
|
||||
if let Some((entry, _)) = state.try_entry(&path, true) {
|
||||
entry.is_file()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -2232,17 +2324,16 @@ impl Fs for FakeFs {
|
||||
let path = normalize_path(path);
|
||||
let mut state = self.state.lock();
|
||||
state.metadata_call_count += 1;
|
||||
if let Some((mut entry, _)) = state.try_read_path(&path, false) {
|
||||
let is_symlink = entry.lock().is_symlink();
|
||||
if let Some((mut entry, _)) = state.try_entry(&path, false) {
|
||||
let is_symlink = entry.is_symlink();
|
||||
if is_symlink {
|
||||
if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) {
|
||||
if let Some(e) = state.try_entry(&path, true).map(|e| e.0) {
|
||||
entry = e;
|
||||
} else {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
let entry = entry.lock();
|
||||
Ok(Some(match &*entry {
|
||||
FakeFsEntry::File {
|
||||
inode, mtime, len, ..
|
||||
@@ -2274,12 +2365,11 @@ impl Fs for FakeFs {
|
||||
async fn read_link(&self, path: &Path) -> Result<PathBuf> {
|
||||
self.simulate_random_delay().await;
|
||||
let path = normalize_path(path);
|
||||
let state = self.state.lock();
|
||||
let mut state = self.state.lock();
|
||||
let (entry, _) = state
|
||||
.try_read_path(&path, false)
|
||||
.try_entry(&path, false)
|
||||
.with_context(|| format!("path does not exist: {path:?}"))?;
|
||||
let entry = entry.lock();
|
||||
if let FakeFsEntry::Symlink { target } = &*entry {
|
||||
if let FakeFsEntry::Symlink { target } = entry {
|
||||
Ok(target.clone())
|
||||
} else {
|
||||
anyhow::bail!("not a symlink: {path:?}")
|
||||
@@ -2294,8 +2384,7 @@ impl Fs for FakeFs {
|
||||
let path = normalize_path(path);
|
||||
let mut state = self.state.lock();
|
||||
state.read_dir_call_count += 1;
|
||||
let entry = state.read_path(&path)?;
|
||||
let mut entry = entry.lock();
|
||||
let entry = state.entry(&path)?;
|
||||
let children = entry.dir_entries(&path)?;
|
||||
let paths = children
|
||||
.keys()
|
||||
@@ -2359,6 +2448,7 @@ impl Fs for FakeFs {
|
||||
dot_git_path: abs_dot_git.to_path_buf(),
|
||||
repository_dir_path: repository_dir_path.to_owned(),
|
||||
common_dir_path: common_dir_path.to_owned(),
|
||||
checkpoints: Arc::default(),
|
||||
}) as _
|
||||
},
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ workspace = true
|
||||
path = "src/git.rs"
|
||||
|
||||
[features]
|
||||
test-support = []
|
||||
test-support = ["rand"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
@@ -26,6 +26,7 @@ http_client.workspace = true
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
regex.workspace = true
|
||||
rand = { workspace = true, optional = true }
|
||||
rope.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] }
|
||||
unindent.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
tempfile.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -73,6 +73,7 @@ async fn run_git_blame(
|
||||
.current_dir(working_directory)
|
||||
.arg("blame")
|
||||
.arg("--incremental")
|
||||
.arg("-w")
|
||||
.arg("--contents")
|
||||
.arg("-")
|
||||
.arg(path.as_os_str())
|
||||
|
||||
@@ -119,6 +119,13 @@ impl Oid {
|
||||
Ok(Self(oid))
|
||||
}
|
||||
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub fn random(rng: &mut impl rand::Rng) -> Self {
|
||||
let mut bytes = [0; 20];
|
||||
rng.fill(&mut bytes);
|
||||
Self::from_bytes(&bytes).unwrap()
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
self.0.as_bytes()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use collections::HashMap;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::{AsyncWriteExt, FutureExt as _, select_biased};
|
||||
use git2::BranchType;
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString};
|
||||
use gpui::{AppContext as _, AsyncApp, BackgroundExecutor, SharedString, Task};
|
||||
use parking_lot::Mutex;
|
||||
use rope::Rope;
|
||||
use schemars::JsonSchema;
|
||||
@@ -338,7 +338,7 @@ pub trait GitRepository: Send + Sync {
|
||||
|
||||
fn merge_message(&self) -> BoxFuture<'_, Option<String>>;
|
||||
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>>;
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>>;
|
||||
|
||||
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>>;
|
||||
|
||||
@@ -953,25 +953,27 @@ impl GitRepository for RealGitRepository {
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>> {
|
||||
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
|
||||
let git_binary_path = self.git_binary_path.clone();
|
||||
let working_directory = self.working_directory();
|
||||
let path_prefixes = path_prefixes.to_owned();
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let output = new_std_command(&git_binary_path)
|
||||
.current_dir(working_directory?)
|
||||
.args(git_status_args(&path_prefixes))
|
||||
.output()?;
|
||||
if output.status.success() {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
stdout.parse()
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("git status failed: {stderr}");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
let working_directory = match self.working_directory() {
|
||||
Ok(working_directory) => working_directory,
|
||||
Err(e) => return Task::ready(Err(e)),
|
||||
};
|
||||
let args = git_status_args(&path_prefixes);
|
||||
log::debug!("Checking for git status in {path_prefixes:?}");
|
||||
self.executor.spawn(async move {
|
||||
let output = new_std_command(&git_binary_path)
|
||||
.current_dir(working_directory)
|
||||
.args(args)
|
||||
.output()?;
|
||||
if output.status.success() {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
stdout.parse()
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("git status failed: {stderr}");
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {
|
||||
|
||||
@@ -2105,7 +2105,7 @@ impl GitPanel {
|
||||
Ok(_) => cx.update(|window, cx| {
|
||||
window.prompt(
|
||||
PromptLevel::Info,
|
||||
"Git Clone",
|
||||
&format!("Git Clone: {}", repo_name),
|
||||
None,
|
||||
&["Add repo to project", "Open repo in new project"],
|
||||
cx,
|
||||
|
||||
@@ -181,10 +181,6 @@ pub fn init(cx: &mut App) {
|
||||
workspace.toggle_modal(window, cx, |window, cx| {
|
||||
GitCloneModal::show(panel, window, cx)
|
||||
});
|
||||
|
||||
// panel.update(cx, |panel, cx| {
|
||||
// panel.git_clone(window, cx);
|
||||
// });
|
||||
});
|
||||
workspace.register_action(|workspace, _: &git::OpenModifiedFiles, window, cx| {
|
||||
open_modified_files(workspace, window, cx);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use gpui::{
|
||||
App, Application, Context, Menu, MenuItem, Window, WindowOptions, actions, div, prelude::*, rgb,
|
||||
App, Application, Context, Menu, MenuItem, SystemMenuType, Window, WindowOptions, actions, div,
|
||||
prelude::*, rgb,
|
||||
};
|
||||
|
||||
struct SetMenus;
|
||||
@@ -27,7 +28,11 @@ fn main() {
|
||||
// Add menu items
|
||||
cx.set_menus(vec![Menu {
|
||||
name: "set_menus".into(),
|
||||
items: vec![MenuItem::action("Quit", Quit)],
|
||||
items: vec![
|
||||
MenuItem::os_submenu("Services", SystemMenuType::Services),
|
||||
MenuItem::separator(),
|
||||
MenuItem::action("Quit", Quit),
|
||||
],
|
||||
}]);
|
||||
cx.open_window(WindowOptions::default(), |_, cx| cx.new(|_| SetMenus {}))
|
||||
.unwrap();
|
||||
|
||||
@@ -277,6 +277,8 @@ pub struct App {
|
||||
pub(crate) release_listeners: SubscriberSet<EntityId, ReleaseListener>,
|
||||
pub(crate) global_observers: SubscriberSet<TypeId, Handler>,
|
||||
pub(crate) quit_observers: SubscriberSet<(), QuitHandler>,
|
||||
pub(crate) restart_observers: SubscriberSet<(), Handler>,
|
||||
pub(crate) restart_path: Option<PathBuf>,
|
||||
pub(crate) window_closed_observers: SubscriberSet<(), WindowClosedHandler>,
|
||||
pub(crate) layout_id_buffer: Vec<LayoutId>, // We recycle this memory across layout requests.
|
||||
pub(crate) propagate_event: bool,
|
||||
@@ -349,6 +351,8 @@ impl App {
|
||||
keyboard_layout_observers: SubscriberSet::new(),
|
||||
global_observers: SubscriberSet::new(),
|
||||
quit_observers: SubscriberSet::new(),
|
||||
restart_observers: SubscriberSet::new(),
|
||||
restart_path: None,
|
||||
window_closed_observers: SubscriberSet::new(),
|
||||
layout_id_buffer: Default::default(),
|
||||
propagate_event: true,
|
||||
@@ -832,8 +836,16 @@ impl App {
|
||||
}
|
||||
|
||||
/// Restarts the application.
|
||||
pub fn restart(&self, binary_path: Option<PathBuf>) {
|
||||
self.platform.restart(binary_path)
|
||||
pub fn restart(&mut self) {
|
||||
self.restart_observers
|
||||
.clone()
|
||||
.retain(&(), |observer| observer(self));
|
||||
self.platform.restart(self.restart_path.take())
|
||||
}
|
||||
|
||||
/// Sets the path to use when restarting the application.
|
||||
pub fn set_restart_path(&mut self, path: PathBuf) {
|
||||
self.restart_path = Some(path);
|
||||
}
|
||||
|
||||
/// Returns the HTTP client for the application.
|
||||
@@ -1466,6 +1478,21 @@ impl App {
|
||||
subscription
|
||||
}
|
||||
|
||||
/// Register a callback to be invoked when the application is about to restart.
|
||||
///
|
||||
/// These callbacks are called before any `on_app_quit` callbacks.
|
||||
pub fn on_app_restart(&self, mut on_restart: impl 'static + FnMut(&mut App)) -> Subscription {
|
||||
let (subscription, activate) = self.restart_observers.insert(
|
||||
(),
|
||||
Box::new(move |cx| {
|
||||
on_restart(cx);
|
||||
true
|
||||
}),
|
||||
);
|
||||
activate();
|
||||
subscription
|
||||
}
|
||||
|
||||
/// Register a callback to be invoked when a window is closed
|
||||
/// The window is no longer accessible at the point this callback is invoked.
|
||||
pub fn on_window_closed(&self, mut on_closed: impl FnMut(&mut App) + 'static) -> Subscription {
|
||||
|
||||
@@ -164,6 +164,20 @@ impl<'a, T: 'static> Context<'a, T> {
|
||||
subscription
|
||||
}
|
||||
|
||||
/// Register a callback to be invoked when the application is about to restart.
|
||||
pub fn on_app_restart(
|
||||
&self,
|
||||
mut on_restart: impl FnMut(&mut T, &mut App) + 'static,
|
||||
) -> Subscription
|
||||
where
|
||||
T: 'static,
|
||||
{
|
||||
let handle = self.weak_entity();
|
||||
self.app.on_app_restart(move |cx| {
|
||||
handle.update(cx, |entity, cx| on_restart(entity, cx)).ok();
|
||||
})
|
||||
}
|
||||
|
||||
/// Arrange for the given function to be invoked whenever the application is quit.
|
||||
/// The future returned from this callback will be polled for up to [crate::SHUTDOWN_TIMEOUT] until the app fully quits.
|
||||
pub fn on_app_quit<Fut>(
|
||||
@@ -175,20 +189,15 @@ impl<'a, T: 'static> Context<'a, T> {
|
||||
T: 'static,
|
||||
{
|
||||
let handle = self.weak_entity();
|
||||
let (subscription, activate) = self.app.quit_observers.insert(
|
||||
(),
|
||||
Box::new(move |cx| {
|
||||
let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok();
|
||||
async move {
|
||||
if let Some(future) = future {
|
||||
future.await;
|
||||
}
|
||||
self.app.on_app_quit(move |cx| {
|
||||
let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok();
|
||||
async move {
|
||||
if let Some(future) = future {
|
||||
future.await;
|
||||
}
|
||||
.boxed_local()
|
||||
}),
|
||||
);
|
||||
activate();
|
||||
subscription
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
}
|
||||
|
||||
/// Tell GPUI that this entity has changed and observers of it should be notified.
|
||||
|
||||
@@ -20,6 +20,34 @@ impl Menu {
|
||||
}
|
||||
}
|
||||
|
||||
/// OS menus are menus that are recognized by the operating system
|
||||
/// This allows the operating system to provide specialized items for
|
||||
/// these menus
|
||||
pub struct OsMenu {
|
||||
/// The name of the menu
|
||||
pub name: SharedString,
|
||||
|
||||
/// The type of menu
|
||||
pub menu_type: SystemMenuType,
|
||||
}
|
||||
|
||||
impl OsMenu {
|
||||
/// Create an OwnedOsMenu from this OsMenu
|
||||
pub fn owned(self) -> OwnedOsMenu {
|
||||
OwnedOsMenu {
|
||||
name: self.name.to_string().into(),
|
||||
menu_type: self.menu_type,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The type of system menu
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
pub enum SystemMenuType {
|
||||
/// The 'Services' menu in the Application menu on macOS
|
||||
Services,
|
||||
}
|
||||
|
||||
/// The different kinds of items that can be in a menu
|
||||
pub enum MenuItem {
|
||||
/// A separator between items
|
||||
@@ -28,6 +56,9 @@ pub enum MenuItem {
|
||||
/// A submenu
|
||||
Submenu(Menu),
|
||||
|
||||
/// A menu, managed by the system (for example, the Services menu on macOS)
|
||||
SystemMenu(OsMenu),
|
||||
|
||||
/// An action that can be performed
|
||||
Action {
|
||||
/// The name of this menu item
|
||||
@@ -53,6 +84,14 @@ impl MenuItem {
|
||||
Self::Submenu(menu)
|
||||
}
|
||||
|
||||
/// Creates a new submenu that is populated by the OS
|
||||
pub fn os_submenu(name: impl Into<SharedString>, menu_type: SystemMenuType) -> Self {
|
||||
Self::SystemMenu(OsMenu {
|
||||
name: name.into(),
|
||||
menu_type,
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a new menu item that invokes an action
|
||||
pub fn action(name: impl Into<SharedString>, action: impl Action) -> Self {
|
||||
Self::Action {
|
||||
@@ -89,10 +128,23 @@ impl MenuItem {
|
||||
action,
|
||||
os_action,
|
||||
},
|
||||
MenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.owned()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// OS menus are menus that are recognized by the operating system
|
||||
/// This allows the operating system to provide specialized items for
|
||||
/// these menus
|
||||
#[derive(Clone)]
|
||||
pub struct OwnedOsMenu {
|
||||
/// The name of the menu
|
||||
pub name: SharedString,
|
||||
|
||||
/// The type of menu
|
||||
pub menu_type: SystemMenuType,
|
||||
}
|
||||
|
||||
/// A menu of the application, either a main menu or a submenu
|
||||
#[derive(Clone)]
|
||||
pub struct OwnedMenu {
|
||||
@@ -111,6 +163,9 @@ pub enum OwnedMenuItem {
|
||||
/// A submenu
|
||||
Submenu(OwnedMenu),
|
||||
|
||||
/// A menu, managed by the system (for example, the Services menu on macOS)
|
||||
SystemMenu(OwnedOsMenu),
|
||||
|
||||
/// An action that can be performed
|
||||
Action {
|
||||
/// The name of this menu item
|
||||
@@ -139,6 +194,7 @@ impl Clone for OwnedMenuItem {
|
||||
action: action.boxed_clone(),
|
||||
os_action: *os_action,
|
||||
},
|
||||
OwnedMenuItem::SystemMenu(os_menu) => OwnedMenuItem::SystemMenu(os_menu.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,9 @@ use super::{
|
||||
use crate::{
|
||||
Action, AnyWindowHandle, BackgroundExecutor, ClipboardEntry, ClipboardItem, ClipboardString,
|
||||
CursorStyle, ForegroundExecutor, Image, ImageFormat, KeyContext, Keymap, MacDispatcher,
|
||||
MacDisplay, MacWindow, Menu, MenuItem, OwnedMenu, PathPromptOptions, Platform, PlatformDisplay,
|
||||
PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result, SemanticVersion, Task,
|
||||
WindowAppearance, WindowParams, hash,
|
||||
MacDisplay, MacWindow, Menu, MenuItem, OsMenu, OwnedMenu, PathPromptOptions, Platform,
|
||||
PlatformDisplay, PlatformKeyboardLayout, PlatformTextSystem, PlatformWindow, Result,
|
||||
SemanticVersion, SystemMenuType, Task, WindowAppearance, WindowParams, hash,
|
||||
};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use block::ConcreteBlock;
|
||||
@@ -413,9 +413,20 @@ impl MacPlatform {
|
||||
}
|
||||
item.setSubmenu_(submenu);
|
||||
item.setTitle_(ns_string(&name));
|
||||
if name == "Services" {
|
||||
let app: id = msg_send![APP_CLASS, sharedApplication];
|
||||
app.setServicesMenu_(item);
|
||||
item
|
||||
}
|
||||
MenuItem::SystemMenu(OsMenu { name, menu_type }) => {
|
||||
let item = NSMenuItem::new(nil).autorelease();
|
||||
let submenu = NSMenu::new(nil).autorelease();
|
||||
submenu.setDelegate_(delegate);
|
||||
item.setSubmenu_(submenu);
|
||||
item.setTitle_(ns_string(&name));
|
||||
|
||||
match menu_type {
|
||||
SystemMenuType::Services => {
|
||||
let app: id = msg_send![APP_CLASS, sharedApplication];
|
||||
app.setServicesMenu_(item);
|
||||
}
|
||||
}
|
||||
|
||||
item
|
||||
|
||||
@@ -10,6 +10,7 @@ mod keyboard;
|
||||
mod platform;
|
||||
mod system_settings;
|
||||
mod util;
|
||||
mod vsync;
|
||||
mod window;
|
||||
mod wrapper;
|
||||
|
||||
@@ -25,6 +26,7 @@ pub(crate) use keyboard::*;
|
||||
pub(crate) use platform::*;
|
||||
pub(crate) use system_settings::*;
|
||||
pub(crate) use util::*;
|
||||
pub(crate) use vsync::*;
|
||||
pub(crate) use window::*;
|
||||
pub(crate) use wrapper::*;
|
||||
|
||||
|
||||
@@ -4,16 +4,15 @@ use ::util::ResultExt;
|
||||
use anyhow::{Context, Result};
|
||||
use windows::{
|
||||
Win32::{
|
||||
Foundation::{FreeLibrary, HMODULE, HWND},
|
||||
Foundation::{HMODULE, HWND},
|
||||
Graphics::{
|
||||
Direct3D::*,
|
||||
Direct3D11::*,
|
||||
DirectComposition::*,
|
||||
Dxgi::{Common::*, *},
|
||||
},
|
||||
System::LibraryLoader::LoadLibraryA,
|
||||
},
|
||||
core::{Interface, PCSTR},
|
||||
core::Interface,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -208,7 +207,7 @@ impl DirectXRenderer {
|
||||
|
||||
fn present(&mut self) -> Result<()> {
|
||||
unsafe {
|
||||
let result = self.resources.swap_chain.Present(1, DXGI_PRESENT(0));
|
||||
let result = self.resources.swap_chain.Present(0, DXGI_PRESENT(0));
|
||||
// Presenting the swap chain can fail if the DirectX device was removed or reset.
|
||||
if result == DXGI_ERROR_DEVICE_REMOVED || result == DXGI_ERROR_DEVICE_RESET {
|
||||
let reason = self.devices.device.GetDeviceRemovedReason();
|
||||
@@ -1619,22 +1618,6 @@ pub(crate) mod shader_resources {
|
||||
}
|
||||
}
|
||||
|
||||
fn with_dll_library<R, F>(dll_name: PCSTR, f: F) -> Result<R>
|
||||
where
|
||||
F: FnOnce(HMODULE) -> Result<R>,
|
||||
{
|
||||
let library = unsafe {
|
||||
LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))?
|
||||
};
|
||||
let result = f(library);
|
||||
unsafe {
|
||||
FreeLibrary(library)
|
||||
.with_context(|| format!("Freeing dll: {}", dll_name.display()))
|
||||
.log_err();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
mod nvidia {
|
||||
use std::{
|
||||
ffi::CStr,
|
||||
@@ -1644,7 +1627,7 @@ mod nvidia {
|
||||
use anyhow::Result;
|
||||
use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
|
||||
|
||||
use crate::platform::windows::directx_renderer::with_dll_library;
|
||||
use crate::with_dll_library;
|
||||
|
||||
// https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180
|
||||
const NVAPI_SHORT_STRING_MAX: usize = 64;
|
||||
@@ -1711,7 +1694,7 @@ mod amd {
|
||||
use anyhow::Result;
|
||||
use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
|
||||
|
||||
use crate::platform::windows::directx_renderer::with_dll_library;
|
||||
use crate::with_dll_library;
|
||||
|
||||
// https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145
|
||||
const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12);
|
||||
|
||||
@@ -32,7 +32,7 @@ use crate::*;
|
||||
|
||||
pub(crate) struct WindowsPlatform {
|
||||
state: RefCell<WindowsPlatformState>,
|
||||
raw_window_handles: RwLock<SmallVec<[HWND; 4]>>,
|
||||
raw_window_handles: Arc<RwLock<SmallVec<[SafeHwnd; 4]>>>,
|
||||
// The below members will never change throughout the entire lifecycle of the app.
|
||||
icon: HICON,
|
||||
main_receiver: flume::Receiver<Runnable>,
|
||||
@@ -114,7 +114,7 @@ impl WindowsPlatform {
|
||||
};
|
||||
let icon = load_icon().unwrap_or_default();
|
||||
let state = RefCell::new(WindowsPlatformState::new());
|
||||
let raw_window_handles = RwLock::new(SmallVec::new());
|
||||
let raw_window_handles = Arc::new(RwLock::new(SmallVec::new()));
|
||||
let windows_version = WindowsVersion::new().context("Error retrieve windows version")?;
|
||||
|
||||
Ok(Self {
|
||||
@@ -134,22 +134,12 @@ impl WindowsPlatform {
|
||||
})
|
||||
}
|
||||
|
||||
fn redraw_all(&self) {
|
||||
for handle in self.raw_window_handles.read().iter() {
|
||||
unsafe {
|
||||
RedrawWindow(Some(*handle), None, None, RDW_INVALIDATE | RDW_UPDATENOW)
|
||||
.ok()
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn window_from_hwnd(&self, hwnd: HWND) -> Option<Rc<WindowsWindowInner>> {
|
||||
self.raw_window_handles
|
||||
.read()
|
||||
.iter()
|
||||
.find(|entry| *entry == &hwnd)
|
||||
.and_then(|hwnd| window_from_hwnd(*hwnd))
|
||||
.find(|entry| entry.as_raw() == hwnd)
|
||||
.and_then(|hwnd| window_from_hwnd(hwnd.as_raw()))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -158,7 +148,7 @@ impl WindowsPlatform {
|
||||
.read()
|
||||
.iter()
|
||||
.for_each(|handle| unsafe {
|
||||
PostMessageW(Some(*handle), message, wparam, lparam).log_err();
|
||||
PostMessageW(Some(handle.as_raw()), message, wparam, lparam).log_err();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -166,7 +156,7 @@ impl WindowsPlatform {
|
||||
let mut lock = self.raw_window_handles.write();
|
||||
let index = lock
|
||||
.iter()
|
||||
.position(|handle| *handle == target_window)
|
||||
.position(|handle| handle.as_raw() == target_window)
|
||||
.unwrap();
|
||||
lock.remove(index);
|
||||
|
||||
@@ -226,19 +216,19 @@ impl WindowsPlatform {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns true if the app should quit.
|
||||
fn handle_events(&self) -> bool {
|
||||
// Returns if the app should quit.
|
||||
fn handle_events(&self) {
|
||||
let mut msg = MSG::default();
|
||||
unsafe {
|
||||
while PeekMessageW(&mut msg, None, 0, 0, PM_REMOVE).as_bool() {
|
||||
while GetMessageW(&mut msg, None, 0, 0).as_bool() {
|
||||
match msg.message {
|
||||
WM_QUIT => return true,
|
||||
WM_QUIT => return,
|
||||
WM_INPUTLANGCHANGE
|
||||
| WM_GPUI_CLOSE_ONE_WINDOW
|
||||
| WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD
|
||||
| WM_GPUI_DOCK_MENU_ACTION => {
|
||||
if self.handle_gpui_evnets(msg.message, msg.wParam, msg.lParam, &msg) {
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
@@ -247,7 +237,6 @@ impl WindowsPlatform {
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Returns true if the app should quit.
|
||||
@@ -315,8 +304,28 @@ impl WindowsPlatform {
|
||||
self.raw_window_handles
|
||||
.read()
|
||||
.iter()
|
||||
.find(|&&hwnd| hwnd == active_window_hwnd)
|
||||
.copied()
|
||||
.find(|hwnd| hwnd.as_raw() == active_window_hwnd)
|
||||
.map(|hwnd| hwnd.as_raw())
|
||||
}
|
||||
|
||||
fn begin_vsync_thread(&self) {
|
||||
let all_windows = Arc::downgrade(&self.raw_window_handles);
|
||||
std::thread::spawn(move || {
|
||||
let vsync_provider = VSyncProvider::new();
|
||||
loop {
|
||||
vsync_provider.wait_for_vsync();
|
||||
let Some(all_windows) = all_windows.upgrade() else {
|
||||
break;
|
||||
};
|
||||
for hwnd in all_windows.read().iter() {
|
||||
unsafe {
|
||||
RedrawWindow(Some(hwnd.as_raw()), None, None, RDW_INVALIDATE)
|
||||
.ok()
|
||||
.log_err();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,12 +356,8 @@ impl Platform for WindowsPlatform {
|
||||
|
||||
fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) {
|
||||
on_finish_launching();
|
||||
loop {
|
||||
if self.handle_events() {
|
||||
break;
|
||||
}
|
||||
self.redraw_all();
|
||||
}
|
||||
self.begin_vsync_thread();
|
||||
self.handle_events();
|
||||
|
||||
if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit {
|
||||
callback();
|
||||
@@ -365,9 +370,9 @@ impl Platform for WindowsPlatform {
|
||||
.detach();
|
||||
}
|
||||
|
||||
fn restart(&self, _: Option<PathBuf>) {
|
||||
fn restart(&self, binary_path: Option<PathBuf>) {
|
||||
let pid = std::process::id();
|
||||
let Some(app_path) = self.app_path().log_err() else {
|
||||
let Some(app_path) = binary_path.or(self.app_path().log_err()) else {
|
||||
return;
|
||||
};
|
||||
let script = format!(
|
||||
@@ -445,7 +450,7 @@ impl Platform for WindowsPlatform {
|
||||
) -> Result<Box<dyn PlatformWindow>> {
|
||||
let window = WindowsWindow::new(handle, options, self.generate_creation_info())?;
|
||||
let handle = window.get_raw_handle();
|
||||
self.raw_window_handles.write().push(handle);
|
||||
self.raw_window_handles.write().push(handle.into());
|
||||
|
||||
Ok(Box::new(window))
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use ::util::ResultExt;
|
||||
use anyhow::Context;
|
||||
use windows::{
|
||||
UI::{
|
||||
Color,
|
||||
ViewManagement::{UIColorType, UISettings},
|
||||
},
|
||||
Wdk::System::SystemServices::RtlGetVersion,
|
||||
Win32::{Foundation::*, Graphics::Dwm::*, UI::WindowsAndMessaging::*},
|
||||
core::{BOOL, HSTRING},
|
||||
Win32::{
|
||||
Foundation::*, Graphics::Dwm::*, System::LibraryLoader::LoadLibraryA,
|
||||
UI::WindowsAndMessaging::*,
|
||||
},
|
||||
core::{BOOL, HSTRING, PCSTR},
|
||||
};
|
||||
|
||||
use crate::*;
|
||||
@@ -197,3 +201,19 @@ pub(crate) fn show_error(title: &str, content: String) {
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) fn with_dll_library<R, F>(dll_name: PCSTR, f: F) -> Result<R>
|
||||
where
|
||||
F: FnOnce(HMODULE) -> Result<R>,
|
||||
{
|
||||
let library = unsafe {
|
||||
LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))?
|
||||
};
|
||||
let result = f(library);
|
||||
unsafe {
|
||||
FreeLibrary(library)
|
||||
.with_context(|| format!("Freeing dll: {}", dll_name.display()))
|
||||
.log_err();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
174
crates/gpui/src/platform/windows/vsync.rs
Normal file
174
crates/gpui/src/platform/windows/vsync.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use std::{
|
||||
sync::LazyLock,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use util::ResultExt;
|
||||
use windows::{
|
||||
Win32::{
|
||||
Foundation::{HANDLE, HWND},
|
||||
Graphics::{
|
||||
DirectComposition::{
|
||||
COMPOSITION_FRAME_ID_COMPLETED, COMPOSITION_FRAME_ID_TYPE, COMPOSITION_FRAME_STATS,
|
||||
COMPOSITION_TARGET_ID,
|
||||
},
|
||||
Dwm::{DWM_TIMING_INFO, DwmFlush, DwmGetCompositionTimingInfo},
|
||||
},
|
||||
System::{
|
||||
LibraryLoader::{GetModuleHandleA, GetProcAddress},
|
||||
Performance::QueryPerformanceFrequency,
|
||||
Threading::INFINITE,
|
||||
},
|
||||
},
|
||||
core::{HRESULT, s},
|
||||
};
|
||||
|
||||
static QPC_TICKS_PER_SECOND: LazyLock<u64> = LazyLock::new(|| {
|
||||
let mut frequency = 0;
|
||||
// On systems that run Windows XP or later, the function will always succeed and
|
||||
// will thus never return zero.
|
||||
unsafe { QueryPerformanceFrequency(&mut frequency).unwrap() };
|
||||
frequency as u64
|
||||
});
|
||||
|
||||
const VSYNC_INTERVAL_THRESHOLD: Duration = Duration::from_millis(1);
|
||||
const DEFAULT_VSYNC_INTERVAL: Duration = Duration::from_micros(16_666); // ~60Hz
|
||||
|
||||
// Here we are using dynamic loading of DirectComposition functions,
|
||||
// or the app will refuse to start on windows systems that do not support DirectComposition.
|
||||
type DCompositionGetFrameId =
|
||||
unsafe extern "system" fn(frameidtype: COMPOSITION_FRAME_ID_TYPE, frameid: *mut u64) -> HRESULT;
|
||||
type DCompositionGetStatistics = unsafe extern "system" fn(
|
||||
frameid: u64,
|
||||
framestats: *mut COMPOSITION_FRAME_STATS,
|
||||
targetidcount: u32,
|
||||
targetids: *mut COMPOSITION_TARGET_ID,
|
||||
actualtargetidcount: *mut u32,
|
||||
) -> HRESULT;
|
||||
type DCompositionWaitForCompositorClock =
|
||||
unsafe extern "system" fn(count: u32, handles: *const HANDLE, timeoutinms: u32) -> u32;
|
||||
|
||||
pub(crate) struct VSyncProvider {
|
||||
interval: Duration,
|
||||
f: Box<dyn Fn() -> bool>,
|
||||
}
|
||||
|
||||
impl VSyncProvider {
|
||||
pub(crate) fn new() -> Self {
|
||||
if let Some((get_frame_id, get_statistics, wait_for_comp_clock)) =
|
||||
initialize_direct_composition()
|
||||
.context("Retrieving DirectComposition functions")
|
||||
.log_with_level(log::Level::Warn)
|
||||
{
|
||||
let interval = get_dwm_interval_from_direct_composition(get_frame_id, get_statistics)
|
||||
.context("Failed to get DWM interval from DirectComposition")
|
||||
.log_err()
|
||||
.unwrap_or(DEFAULT_VSYNC_INTERVAL);
|
||||
log::info!(
|
||||
"DirectComposition is supported for VSync, interval: {:?}",
|
||||
interval
|
||||
);
|
||||
let f = Box::new(move || unsafe {
|
||||
wait_for_comp_clock(0, std::ptr::null(), INFINITE) == 0
|
||||
});
|
||||
Self { interval, f }
|
||||
} else {
|
||||
let interval = get_dwm_interval()
|
||||
.context("Failed to get DWM interval")
|
||||
.log_err()
|
||||
.unwrap_or(DEFAULT_VSYNC_INTERVAL);
|
||||
log::info!(
|
||||
"DirectComposition is not supported for VSync, falling back to DWM, interval: {:?}",
|
||||
interval
|
||||
);
|
||||
let f = Box::new(|| unsafe { DwmFlush().is_ok() });
|
||||
Self { interval, f }
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn wait_for_vsync(&self) {
|
||||
let vsync_start = Instant::now();
|
||||
let wait_succeeded = (self.f)();
|
||||
let elapsed = vsync_start.elapsed();
|
||||
// DwmFlush and DCompositionWaitForCompositorClock returns very early
|
||||
// instead of waiting until vblank when the monitor goes to sleep or is
|
||||
// unplugged (nothing to present due to desktop occlusion). We use 1ms as
|
||||
// a threshhold for the duration of the wait functions and fallback to
|
||||
// Sleep() if it returns before that. This could happen during normal
|
||||
// operation for the first call after the vsync thread becomes non-idle,
|
||||
// but it shouldn't happen often.
|
||||
if !wait_succeeded || elapsed < VSYNC_INTERVAL_THRESHOLD {
|
||||
log::warn!("VSyncProvider::wait_for_vsync() took shorter than expected");
|
||||
std::thread::sleep(self.interval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize_direct_composition() -> Result<(
|
||||
DCompositionGetFrameId,
|
||||
DCompositionGetStatistics,
|
||||
DCompositionWaitForCompositorClock,
|
||||
)> {
|
||||
unsafe {
|
||||
// Load DLL at runtime since older Windows versions don't have dcomp.
|
||||
let hmodule = GetModuleHandleA(s!("dcomp.dll")).context("Loading dcomp.dll")?;
|
||||
let get_frame_id_addr = GetProcAddress(hmodule, s!("DCompositionGetFrameId"))
|
||||
.context("Function DCompositionGetFrameId not found")?;
|
||||
let get_statistics_addr = GetProcAddress(hmodule, s!("DCompositionGetStatistics"))
|
||||
.context("Function DCompositionGetStatistics not found")?;
|
||||
let wait_for_compositor_clock_addr =
|
||||
GetProcAddress(hmodule, s!("DCompositionWaitForCompositorClock"))
|
||||
.context("Function DCompositionWaitForCompositorClock not found")?;
|
||||
let get_frame_id: DCompositionGetFrameId = std::mem::transmute(get_frame_id_addr);
|
||||
let get_statistics: DCompositionGetStatistics = std::mem::transmute(get_statistics_addr);
|
||||
let wait_for_compositor_clock: DCompositionWaitForCompositorClock =
|
||||
std::mem::transmute(wait_for_compositor_clock_addr);
|
||||
Ok((get_frame_id, get_statistics, wait_for_compositor_clock))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_dwm_interval_from_direct_composition(
|
||||
get_frame_id: DCompositionGetFrameId,
|
||||
get_statistics: DCompositionGetStatistics,
|
||||
) -> Result<Duration> {
|
||||
let mut frame_id = 0;
|
||||
unsafe { get_frame_id(COMPOSITION_FRAME_ID_COMPLETED, &mut frame_id) }.ok()?;
|
||||
let mut stats = COMPOSITION_FRAME_STATS::default();
|
||||
unsafe {
|
||||
get_statistics(
|
||||
frame_id,
|
||||
&mut stats,
|
||||
0,
|
||||
std::ptr::null_mut(),
|
||||
std::ptr::null_mut(),
|
||||
)
|
||||
}
|
||||
.ok()?;
|
||||
Ok(retrieve_duration(stats.framePeriod, *QPC_TICKS_PER_SECOND))
|
||||
}
|
||||
|
||||
fn get_dwm_interval() -> Result<Duration> {
|
||||
let mut timing_info = DWM_TIMING_INFO {
|
||||
cbSize: std::mem::size_of::<DWM_TIMING_INFO>() as u32,
|
||||
..Default::default()
|
||||
};
|
||||
unsafe { DwmGetCompositionTimingInfo(HWND::default(), &mut timing_info) }?;
|
||||
let interval = retrieve_duration(timing_info.qpcRefreshPeriod, *QPC_TICKS_PER_SECOND);
|
||||
// Check for interval values that are impossibly low. A 29 microsecond
|
||||
// interval was seen (from a qpcRefreshPeriod of 60).
|
||||
if interval < VSYNC_INTERVAL_THRESHOLD {
|
||||
Ok(retrieve_duration(
|
||||
timing_info.rateRefresh.uiDenominator as u64,
|
||||
timing_info.rateRefresh.uiNumerator as u64,
|
||||
))
|
||||
} else {
|
||||
Ok(interval)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn retrieve_duration(counts: u64, ticks_per_second: u64) -> Duration {
|
||||
let ticks_per_microsecond = ticks_per_second / 1_000_000;
|
||||
Duration::from_micros(counts / ticks_per_microsecond)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::ops::Deref;
|
||||
|
||||
use windows::Win32::UI::WindowsAndMessaging::HCURSOR;
|
||||
use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::HCURSOR};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) struct SafeCursor {
|
||||
@@ -23,3 +23,31 @@ impl Deref for SafeCursor {
|
||||
&self.raw
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) struct SafeHwnd {
|
||||
raw: HWND,
|
||||
}
|
||||
|
||||
impl SafeHwnd {
|
||||
pub(crate) fn as_raw(&self) -> HWND {
|
||||
self.raw
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for SafeHwnd {}
|
||||
unsafe impl Sync for SafeHwnd {}
|
||||
|
||||
impl From<HWND> for SafeHwnd {
|
||||
fn from(value: HWND) -> Self {
|
||||
SafeHwnd { raw: value }
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for SafeHwnd {
|
||||
type Target = HWND;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.raw
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,6 +167,7 @@ fn generate_test_function(
|
||||
));
|
||||
cx_teardowns.extend(quote!(
|
||||
dispatcher.run_until_parked();
|
||||
#cx_varname.executor().forbid_parking();
|
||||
#cx_varname.quit();
|
||||
dispatcher.run_until_parked();
|
||||
));
|
||||
@@ -232,7 +233,7 @@ fn generate_test_function(
|
||||
cx_teardowns.extend(quote!(
|
||||
drop(#cx_varname_lock);
|
||||
dispatcher.run_until_parked();
|
||||
#cx_varname.update(|cx| { cx.quit() });
|
||||
#cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
|
||||
dispatcher.run_until_parked();
|
||||
));
|
||||
continue;
|
||||
@@ -247,6 +248,7 @@ fn generate_test_function(
|
||||
));
|
||||
cx_teardowns.extend(quote!(
|
||||
dispatcher.run_until_parked();
|
||||
#cx_varname.executor().forbid_parking();
|
||||
#cx_varname.quit();
|
||||
dispatcher.run_until_parked();
|
||||
));
|
||||
|
||||
@@ -942,6 +942,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
model.id(),
|
||||
model.supports_parallel_tool_calls(),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
|
||||
@@ -14,7 +14,7 @@ use language_model::{
|
||||
RateLimiter, Role, StopReason, TokenUsage,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
|
||||
use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
@@ -45,6 +45,7 @@ pub struct AvailableModel {
|
||||
pub max_tokens: u64,
|
||||
pub max_output_tokens: Option<u64>,
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
pub reasoning_effort: Option<ReasoningEffort>,
|
||||
}
|
||||
|
||||
pub struct OpenAiLanguageModelProvider {
|
||||
@@ -213,6 +214,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
|
||||
max_tokens: model.max_tokens,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
reasoning_effort: model.reasoning_effort.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -301,7 +303,25 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
use open_ai::Model;
|
||||
match &self.model {
|
||||
Model::FourOmni
|
||||
| Model::FourOmniMini
|
||||
| Model::FourPointOne
|
||||
| Model::FourPointOneMini
|
||||
| Model::FourPointOneNano
|
||||
| Model::Five
|
||||
| Model::FiveMini
|
||||
| Model::FiveNano
|
||||
| Model::O1
|
||||
| Model::O3
|
||||
| Model::O4Mini => true,
|
||||
Model::ThreePointFiveTurbo
|
||||
| Model::Four
|
||||
| Model::FourTurbo
|
||||
| Model::O3Mini
|
||||
| Model::Custom { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
@@ -351,6 +371,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||
self.model.id(),
|
||||
self.model.supports_parallel_tool_calls(),
|
||||
self.max_output_tokens(),
|
||||
self.model.reasoning_effort(),
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
@@ -366,6 +387,7 @@ pub fn into_open_ai(
|
||||
model_id: &str,
|
||||
supports_parallel_tool_calls: bool,
|
||||
max_output_tokens: Option<u64>,
|
||||
reasoning_effort: Option<ReasoningEffort>,
|
||||
) -> open_ai::Request {
|
||||
let stream = !model_id.starts_with("o1-");
|
||||
|
||||
@@ -455,6 +477,7 @@ pub fn into_open_ai(
|
||||
} else {
|
||||
None
|
||||
},
|
||||
prompt_cache_key: request.thread_id,
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
@@ -471,6 +494,7 @@ pub fn into_open_ai(
|
||||
LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
|
||||
LanguageModelToolChoice::None => open_ai::ToolChoice::None,
|
||||
}),
|
||||
reasoning_effort,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -355,7 +355,13 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let request = into_open_ai(request, &self.model.name, true, self.max_output_tokens());
|
||||
let request = into_open_ai(
|
||||
request,
|
||||
&self.model.name,
|
||||
true,
|
||||
self.max_output_tokens(),
|
||||
None,
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let mapper = OpenAiEventMapper::new();
|
||||
|
||||
@@ -356,6 +356,7 @@ impl LanguageModel for VercelLanguageModel {
|
||||
self.model.id(),
|
||||
self.model.supports_parallel_tool_calls(),
|
||||
self.max_output_tokens(),
|
||||
None,
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
|
||||
@@ -360,6 +360,7 @@ impl LanguageModel for XAiLanguageModel {
|
||||
self.model.id(),
|
||||
self.model.supports_parallel_tool_calls(),
|
||||
self.max_output_tokens(),
|
||||
None,
|
||||
);
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
|
||||
@@ -149,7 +149,9 @@
|
||||
parameters: (parameter_list
|
||||
"(" @context
|
||||
")" @context)))
|
||||
]
|
||||
(type_qualifier)? @context) @item
|
||||
; Fields declarations may define multiple fields, and so @item is on the
|
||||
; declarator so they each get distinct ranges.
|
||||
] @item
|
||||
(type_qualifier)? @context)
|
||||
|
||||
(comment) @annotation
|
||||
|
||||
@@ -103,7 +103,13 @@ impl LspAdapter for CssLspAdapter {
|
||||
|
||||
let should_install_language_server = self
|
||||
.node
|
||||
.should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version)
|
||||
.should_install_npm_package(
|
||||
Self::PACKAGE_NAME,
|
||||
&server_path,
|
||||
&container_dir,
|
||||
&version,
|
||||
Default::default(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if should_install_language_server {
|
||||
|
||||
@@ -487,6 +487,8 @@ const GO_MODULE_ROOT_TASK_VARIABLE: VariableName =
|
||||
VariableName::Custom(Cow::Borrowed("GO_MODULE_ROOT"));
|
||||
const GO_SUBTEST_NAME_TASK_VARIABLE: VariableName =
|
||||
VariableName::Custom(Cow::Borrowed("GO_SUBTEST_NAME"));
|
||||
const GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE: VariableName =
|
||||
VariableName::Custom(Cow::Borrowed("GO_TABLE_TEST_CASE_NAME"));
|
||||
|
||||
impl ContextProvider for GoContextProvider {
|
||||
fn build_context(
|
||||
@@ -545,10 +547,19 @@ impl ContextProvider for GoContextProvider {
|
||||
let go_subtest_variable = extract_subtest_name(_subtest_name.unwrap_or(""))
|
||||
.map(|subtest_name| (GO_SUBTEST_NAME_TASK_VARIABLE.clone(), subtest_name));
|
||||
|
||||
let table_test_case_name = variables.get(&VariableName::Custom(Cow::Borrowed(
|
||||
"_table_test_case_name",
|
||||
)));
|
||||
|
||||
let go_table_test_case_variable = table_test_case_name
|
||||
.and_then(extract_subtest_name)
|
||||
.map(|case_name| (GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.clone(), case_name));
|
||||
|
||||
Task::ready(Ok(TaskVariables::from_iter(
|
||||
[
|
||||
go_package_variable,
|
||||
go_subtest_variable,
|
||||
go_table_test_case_variable,
|
||||
go_module_root_variable,
|
||||
]
|
||||
.into_iter()
|
||||
@@ -570,6 +581,28 @@ impl ContextProvider for GoContextProvider {
|
||||
let module_cwd = Some(GO_MODULE_ROOT_TASK_VARIABLE.template_value());
|
||||
|
||||
Task::ready(Some(TaskTemplates(vec![
|
||||
TaskTemplate {
|
||||
label: format!(
|
||||
"go test {} -v -run {}/{}",
|
||||
GO_PACKAGE_TASK_VARIABLE.template_value(),
|
||||
VariableName::Symbol.template_value(),
|
||||
GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(),
|
||||
),
|
||||
command: "go".into(),
|
||||
args: vec![
|
||||
"test".into(),
|
||||
"-v".into(),
|
||||
"-run".into(),
|
||||
format!(
|
||||
"\\^{}\\$/\\^{}\\$",
|
||||
VariableName::Symbol.template_value(),
|
||||
GO_TABLE_TEST_CASE_NAME_TASK_VARIABLE.template_value(),
|
||||
),
|
||||
],
|
||||
cwd: package_cwd.clone(),
|
||||
tags: vec!["go-table-test-case".to_owned()],
|
||||
..TaskTemplate::default()
|
||||
},
|
||||
TaskTemplate {
|
||||
label: format!(
|
||||
"go test {} -run {}",
|
||||
@@ -842,10 +875,21 @@ mod tests {
|
||||
.collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
runnables.len() == 2,
|
||||
"Should find test function and subtest with double quotes, found: {}",
|
||||
runnables.len()
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-subtest".to_string()),
|
||||
"Should find go-subtest tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
|
||||
let buffer = cx.new(|cx| {
|
||||
@@ -860,10 +904,299 @@ mod tests {
|
||||
.collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
runnables.len() == 2,
|
||||
"Should find test function and subtest with backticks, found: {}",
|
||||
runnables.len()
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-subtest".to_string()),
|
||||
"Should find go-subtest tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_slice_detection(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExample(t *testing.T) {
|
||||
_ = "some random string"
|
||||
|
||||
testCases := []struct{
|
||||
name string
|
||||
anotherStr string
|
||||
}{
|
||||
{
|
||||
name: "test case 1",
|
||||
anotherStr: "foo",
|
||||
},
|
||||
{
|
||||
name: "test case 2",
|
||||
anotherStr: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
notATableTest := []struct{
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "some string",
|
||||
},
|
||||
{
|
||||
name: "some other string",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// test code here
|
||||
})
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
|
||||
let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count();
|
||||
let go_table_test_count = tag_strings
|
||||
.iter()
|
||||
.filter(|&tag| tag == "go-table-test-case")
|
||||
.count();
|
||||
|
||||
assert!(
|
||||
go_test_count == 1,
|
||||
"Should find exactly 1 go-test, found: {}",
|
||||
go_test_count
|
||||
);
|
||||
assert!(
|
||||
go_table_test_count == 2,
|
||||
"Should find exactly 2 go-table-test-case, found: {}",
|
||||
go_table_test_count
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_slice_ignored(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
func Example() {
|
||||
_ = "some random string"
|
||||
|
||||
notATableTest := []struct{
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "some string",
|
||||
},
|
||||
{
|
||||
name: "some other string",
|
||||
},
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_map_detection(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExample(t *testing.T) {
|
||||
_ = "some random string"
|
||||
|
||||
testCases := map[string]struct {
|
||||
someStr string
|
||||
fail bool
|
||||
}{
|
||||
"test failure": {
|
||||
someStr: "foo",
|
||||
fail: true,
|
||||
},
|
||||
"test success": {
|
||||
someStr: "bar",
|
||||
fail: false,
|
||||
},
|
||||
}
|
||||
|
||||
notATableTest := map[string]struct {
|
||||
someStr string
|
||||
}{
|
||||
"some string": {
|
||||
someStr: "foo",
|
||||
},
|
||||
"some other string": {
|
||||
someStr: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// test code here
|
||||
})
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
|
||||
let go_test_count = tag_strings.iter().filter(|&tag| tag == "go-test").count();
|
||||
let go_table_test_count = tag_strings
|
||||
.iter()
|
||||
.filter(|&tag| tag == "go-table-test-case")
|
||||
.count();
|
||||
|
||||
assert!(
|
||||
go_test_count == 1,
|
||||
"Should find exactly 1 go-test, found: {}",
|
||||
go_test_count
|
||||
);
|
||||
assert!(
|
||||
go_table_test_count == 2,
|
||||
"Should find exactly 2 go-table-test-case, found: {}",
|
||||
go_table_test_count
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_go_table_test_map_ignored(cx: &mut TestAppContext) {
|
||||
let language = language("go", tree_sitter_go::LANGUAGE.into());
|
||||
|
||||
let table_test = r#"
|
||||
package main
|
||||
|
||||
func Example() {
|
||||
_ = "some random string"
|
||||
|
||||
notATableTest := map[string]struct {
|
||||
someStr string
|
||||
}{
|
||||
"some string": {
|
||||
someStr: "foo",
|
||||
},
|
||||
"some other string": {
|
||||
someStr: "bar",
|
||||
},
|
||||
}
|
||||
}
|
||||
"#;
|
||||
|
||||
let buffer =
|
||||
cx.new(|cx| crate::Buffer::local(table_test, cx).with_language(language.clone(), cx));
|
||||
cx.executor().run_until_parked();
|
||||
|
||||
let runnables: Vec<_> = buffer.update(cx, |buffer, _| {
|
||||
let snapshot = buffer.snapshot();
|
||||
snapshot.runnable_ranges(0..table_test.len()).collect()
|
||||
});
|
||||
|
||||
let tag_strings: Vec<String> = runnables
|
||||
.iter()
|
||||
.flat_map(|r| &r.runnable.tags)
|
||||
.map(|tag| tag.0.to_string())
|
||||
.collect();
|
||||
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-test".to_string()),
|
||||
"Should find go-test tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
assert!(
|
||||
!tag_strings.contains(&"go-table-test-case".to_string()),
|
||||
"Should find go-table-test-case tag, found: {:?}",
|
||||
tag_strings
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
(comment) @annotation
|
||||
|
||||
(type_declaration
|
||||
"type" @context
|
||||
[
|
||||
@@ -42,13 +43,13 @@
|
||||
(var_declaration
|
||||
"var" @context
|
||||
[
|
||||
; The declaration may define multiple variables, and so @item is on
|
||||
; the identifier so they get distinct ranges.
|
||||
(var_spec
|
||||
name: (identifier) @name) @item
|
||||
name: (identifier) @name @item)
|
||||
(var_spec_list
|
||||
"("
|
||||
(var_spec
|
||||
name: (identifier) @name) @item
|
||||
")"
|
||||
name: (identifier) @name @item)
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -60,5 +61,7 @@
|
||||
"(" @context
|
||||
")" @context)) @item
|
||||
|
||||
; Fields declarations may define multiple fields, and so @item is on the
|
||||
; declarator so they each get distinct ranges.
|
||||
(field_declaration
|
||||
name: (_) @name) @item
|
||||
name: (_) @name @item)
|
||||
|
||||
@@ -91,3 +91,103 @@
|
||||
) @_
|
||||
(#set! tag go-main)
|
||||
)
|
||||
|
||||
; Table test cases - slice and map
|
||||
(
|
||||
(short_var_declaration
|
||||
left: (expression_list (identifier) @_collection_var)
|
||||
right: (expression_list
|
||||
(composite_literal
|
||||
type: [
|
||||
(slice_type)
|
||||
(map_type
|
||||
key: (type_identifier) @_key_type
|
||||
(#eq? @_key_type "string")
|
||||
)
|
||||
]
|
||||
body: (literal_value
|
||||
[
|
||||
(literal_element
|
||||
(literal_value
|
||||
(keyed_element
|
||||
(literal_element
|
||||
(identifier) @_field_name
|
||||
)
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal) @run @_table_test_case_name
|
||||
(raw_string_literal) @run @_table_test_case_name
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(keyed_element
|
||||
(literal_element
|
||||
[
|
||||
(interpreted_string_literal) @run @_table_test_case_name
|
||||
(raw_string_literal) @run @_table_test_case_name
|
||||
]
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(for_statement
|
||||
(range_clause
|
||||
left: (expression_list
|
||||
[
|
||||
(
|
||||
(identifier)
|
||||
(identifier) @_loop_var
|
||||
)
|
||||
(identifier) @_loop_var
|
||||
]
|
||||
)
|
||||
right: (identifier) @_range_var
|
||||
(#eq? @_range_var @_collection_var)
|
||||
)
|
||||
body: (block
|
||||
(expression_statement
|
||||
(call_expression
|
||||
function: (selector_expression
|
||||
operand: (identifier) @_t_var
|
||||
field: (field_identifier) @_run_method
|
||||
(#eq? @_run_method "Run")
|
||||
)
|
||||
arguments: (argument_list
|
||||
.
|
||||
[
|
||||
(selector_expression
|
||||
operand: (identifier) @_tc_var
|
||||
(#eq? @_tc_var @_loop_var)
|
||||
field: (field_identifier) @_field_check
|
||||
(#eq? @_field_check @_field_name)
|
||||
)
|
||||
(identifier) @_arg_var
|
||||
(#eq? @_arg_var @_loop_var)
|
||||
]
|
||||
.
|
||||
(func_literal
|
||||
parameters: (parameter_list
|
||||
(parameter_declaration
|
||||
type: (pointer_type
|
||||
(qualified_type
|
||||
package: (package_identifier) @_pkg
|
||||
name: (type_identifier) @_type
|
||||
(#eq? @_pkg "testing")
|
||||
(#eq? @_type "T")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
) @_
|
||||
(#set! tag go-table-test-case)
|
||||
)
|
||||
|
||||
@@ -31,12 +31,16 @@
|
||||
(export_statement
|
||||
(lexical_declaration
|
||||
["let" "const"] @context
|
||||
; Multiple names may be exported - @item is on the declarator to keep
|
||||
; ranges distinct.
|
||||
(variable_declarator
|
||||
name: (_) @name) @item)))
|
||||
|
||||
(program
|
||||
(lexical_declaration
|
||||
["let" "const"] @context
|
||||
; Multiple names may be defined - @item is on the declarator to keep
|
||||
; ranges distinct.
|
||||
(variable_declarator
|
||||
name: (_) @name) @item))
|
||||
|
||||
|
||||
@@ -340,7 +340,13 @@ impl LspAdapter for JsonLspAdapter {
|
||||
|
||||
let should_install_language_server = self
|
||||
.node
|
||||
.should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version)
|
||||
.should_install_npm_package(
|
||||
Self::PACKAGE_NAME,
|
||||
&server_path,
|
||||
&container_dir,
|
||||
&version,
|
||||
Default::default(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if should_install_language_server {
|
||||
|
||||
@@ -206,6 +206,7 @@ impl LspAdapter for PythonLspAdapter {
|
||||
&server_path,
|
||||
&container_dir,
|
||||
&version,
|
||||
Default::default(),
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -238,7 +238,7 @@ impl LspAdapter for RustLspAdapter {
|
||||
)
|
||||
.await?;
|
||||
make_file_executable(&server_path).await?;
|
||||
remove_matching(&container_dir, |path| server_path != path).await;
|
||||
remove_matching(&container_dir, |path| path != destination_path).await;
|
||||
GithubBinaryMetadata::write_to_file(
|
||||
&GithubBinaryMetadata {
|
||||
metadata_version: 1,
|
||||
@@ -1023,8 +1023,14 @@ async fn get_cached_server_binary(container_dir: PathBuf) -> Option<LanguageServ
|
||||
last = Some(path);
|
||||
}
|
||||
|
||||
let path = last.context("no cached binary")?;
|
||||
let path = match RustLspAdapter::GITHUB_ASSET_KIND {
|
||||
AssetKind::TarGz | AssetKind::Gz => path.clone(), // Tar and gzip extract in place.
|
||||
AssetKind::Zip => path.clone().join("rust-analyzer.exe"), // zip contains a .exe
|
||||
};
|
||||
|
||||
anyhow::Ok(LanguageServerBinary {
|
||||
path: last.context("no cached binary")?,
|
||||
path,
|
||||
env: None,
|
||||
arguments: Default::default(),
|
||||
})
|
||||
|
||||
@@ -108,7 +108,13 @@ impl LspAdapter for TailwindLspAdapter {
|
||||
|
||||
let should_install_language_server = self
|
||||
.node
|
||||
.should_install_npm_package(Self::PACKAGE_NAME, &server_path, &container_dir, &version)
|
||||
.should_install_npm_package(
|
||||
Self::PACKAGE_NAME,
|
||||
&server_path,
|
||||
&container_dir,
|
||||
&version,
|
||||
Default::default(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if should_install_language_server {
|
||||
|
||||
@@ -34,12 +34,16 @@
|
||||
(export_statement
|
||||
(lexical_declaration
|
||||
["let" "const"] @context
|
||||
; Multiple names may be exported - @item is on the declarator to keep
|
||||
; ranges distinct.
|
||||
(variable_declarator
|
||||
name: (_) @name) @item))
|
||||
|
||||
(program
|
||||
(lexical_declaration
|
||||
["let" "const"] @context
|
||||
; Multiple names may be defined - @item is on the declarator to keep
|
||||
; ranges distinct.
|
||||
(variable_declarator
|
||||
name: (_) @name) @item))
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user