Compare commits

..

1 Commits

Author SHA1 Message Date
Michael Sloan
b38b6ff12c agent checkpoint dbgs 2025-08-12 01:11:38 -06:00
225 changed files with 3986 additions and 11497 deletions

View File

@@ -1,15 +1,15 @@
name: Bug Report (Windows Alpha)
description: Zed Windows Alpha Related Bugs
name: Bug Report (Windows)
description: Zed Windows-Related Bugs
type: "Bug"
labels: ["windows"]
title: "Windows Alpha: <a short description of the Windows bug>"
title: "Windows: <a short description of the Windows bug>"
body:
- type: textarea
attributes:
label: Summary
description: Describe the bug with a one-line summary, and provide detailed reproduction steps
description: Describe the bug with a one line summary, and provide detailed reproduction steps
value: |
<!-- Please insert a one-line summary of the issue below -->
<!-- Please insert a one line summary of the issue below -->
SUMMARY_SENTENCE_HERE
### Description

View File

@@ -718,7 +718,7 @@ jobs:
timeout-minutes: 60
runs-on: github-8vcpu-ubuntu-2404
if: |
false && ( startsWith(github.ref, 'refs/tags/v')
( startsWith(github.ref, 'refs/tags/v')
|| contains(github.event.pull_request.labels.*.name, 'run-bundling') )
needs: [linux_tests]
name: Build Zed on FreeBSD

60
Cargo.lock generated
View File

@@ -7,23 +7,20 @@ name = "acp_thread"
version = "0.1.0"
dependencies = [
"action_log",
"agent",
"agent-client-protocol",
"anyhow",
"buffer_diff",
"collections",
"editor",
"env_logger 0.11.8",
"file_icons",
"futures 0.3.31",
"gpui",
"indoc",
"itertools 0.14.0",
"language",
"language_model",
"markdown",
"parking_lot",
"project",
"prompt_store",
"rand 0.8.5",
"serde",
"serde_json",
@@ -32,10 +29,7 @@ dependencies = [
"tempfile",
"terminal",
"ui",
"url",
"util",
"uuid",
"watch",
"workspace-hack",
]
@@ -172,9 +166,9 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
version = "0.0.24"
version = "0.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fd68bbbef8e424fb8a605c5f0b00c360f682c4528b0a5feb5ec928aaf5ce28e"
checksum = "3fad72b7b8ee4331b3a4c8d43c107e982a4725564b4ee658ae5c4e79d2b486e8"
dependencies = [
"anyhow",
"futures 0.3.31",
@@ -202,7 +196,6 @@ dependencies = [
"clock",
"cloud_llm_client",
"collections",
"context_server",
"ctor",
"editor",
"env_logger 0.11.8",
@@ -236,7 +229,6 @@ dependencies = [
"task",
"tempfile",
"terminal",
"text",
"theme",
"tree-sitter-rust",
"ui",
@@ -395,7 +387,6 @@ dependencies = [
"ui",
"ui_input",
"unindent",
"url",
"urlencoding",
"util",
"uuid",
@@ -1304,9 +1295,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
[[package]]
name = "async-trait"
version = "0.1.89"
version = "0.1.88"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
dependencies = [
"proc-macro2",
"quote",
@@ -6451,7 +6442,6 @@ dependencies = [
"log",
"parking_lot",
"pretty_assertions",
"rand 0.8.5",
"regex",
"rope",
"schemars",
@@ -11156,13 +11146,14 @@ dependencies = [
"ai_onboarding",
"anyhow",
"client",
"command_palette_hooks",
"component",
"db",
"documented",
"editor",
"feature_flags",
"fs",
"fuzzy",
"git",
"gpui",
"itertools 0.14.0",
"language",
@@ -11174,7 +11165,6 @@ dependencies = [
"schemars",
"serde",
"settings",
"telemetry",
"theme",
"ui",
"util",
@@ -11250,7 +11240,6 @@ dependencies = [
"anyhow",
"futures 0.3.31",
"http_client",
"log",
"schemars",
"serde",
"serde_json",
@@ -15054,10 +15043,8 @@ dependencies = [
"ui",
"ui_input",
"util",
"vim",
"workspace",
"workspace-hack",
"zed_actions",
]
[[package]]
@@ -18892,6 +18879,33 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "welcome"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"component",
"db",
"documented",
"editor",
"fuzzy",
"gpui",
"install_cli",
"language",
"picker",
"project",
"serde",
"settings",
"telemetry",
"ui",
"util",
"vim_mode_setting",
"workspace",
"workspace-hack",
"zed_actions",
]
[[package]]
name = "which"
version = "4.4.2"
@@ -20506,7 +20520,7 @@ dependencies = [
[[package]]
name = "zed"
version = "0.201.0"
version = "0.200.0"
dependencies = [
"activity_indicator",
"agent",
@@ -20646,6 +20660,7 @@ dependencies = [
"watch",
"web_search",
"web_search_providers",
"welcome",
"windows 0.61.1",
"winresource",
"workspace",
@@ -20669,7 +20684,7 @@ dependencies = [
[[package]]
name = "zed_emmet"
version = "0.0.6"
version = "0.0.4"
dependencies = [
"zed_extension_api 0.1.0",
]
@@ -20908,7 +20923,6 @@ dependencies = [
"menu",
"postage",
"project",
"rand 0.8.5",
"regex",
"release_channel",
"reqwest_client",

View File

@@ -185,6 +185,7 @@ members = [
"crates/watch",
"crates/web_search",
"crates/web_search_providers",
"crates/welcome",
"crates/workspace",
"crates/worktree",
"crates/x_ai",
@@ -411,6 +412,7 @@ vim_mode_setting = { path = "crates/vim_mode_setting" }
watch = { path = "crates/watch" }
web_search = { path = "crates/web_search" }
web_search_providers = { path = "crates/web_search_providers" }
welcome = { path = "crates/welcome" }
workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
x_ai = { path = "crates/x_ai" }
@@ -425,7 +427,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
agent-client-protocol = "0.0.24"
agent-client-protocol = "0.0.23"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@@ -564,7 +566,6 @@ reqwest = { git = "https://github.com/zed-industries/reqwest.git", rev = "951c77
"socks",
"stream",
] }
rodio = { version = "0.21.1", default-features = false }
rsa = "0.9.6"
runtimelib = { git = "https://github.com/ConradIrwin/runtimed", rev = "7130c804216b6914355d15d0b91ea91f6babd734", default-features = false, features = [
"async-dispatcher-runtime",
@@ -713,7 +714,6 @@ features = [
"Win32_System_LibraryLoader",
"Win32_System_Memory",
"Win32_System_Ole",
"Win32_System_Performance",
"Win32_System_Pipes",
"Win32_System_SystemInformation",
"Win32_System_SystemServices",

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,9 +1,8 @@
Copyright 2019 The Lilex Project Authors (https://github.com/mishamyrt/Lilex)
Copyright © 2017 IBM Corp. with Reserved Font Name "Plex"
This Font Software is licensed under the SIL Open Font License, Version 1.1.
This license is copied below, and is also available with a FAQ at:
https://scripts.sil.org/OFL
http://scripts.sil.org/OFL
-----------------------------------------------------------
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
@@ -90,4 +89,4 @@ COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
OTHER DEALINGS IN THE FONT SOFTWARE.
OTHER DEALINGS IN THE FONT SOFTWARE.

View File

@@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.78125 3C3.90625 3 3.90625 4.5 3.90625 5.5C3.90625 6.5 3.40625 7.50106 2.40625 8C3.40625 8.50106 3.90625 9.5 3.90625 10.5C3.90625 11.5 3.90625 13 5.78125 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M10.2422 3C12.1172 3 12.1172 4.5 12.1172 5.5C12.1172 6.5 12.6172 7.50106 13.6172 8C12.6172 8.50106 12.1172 9.5 12.1172 10.5C12.1172 11.5 12.1172 13 10.2422 13" stroke="black" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 607 B

View File

@@ -239,7 +239,6 @@
"ctrl-shift-a": "agent::ToggleContextPicker",
"ctrl-shift-j": "agent::ToggleNavigationMenu",
"ctrl-shift-i": "agent::ToggleOptionsMenu",
"ctrl-alt-shift-n": "agent::ToggleNewThreadMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"ctrl->": "assistant::QuoteSelection",
"ctrl-alt-e": "agent::RemoveAllContext",
@@ -331,6 +330,8 @@
"use_key_equivalents": true,
"bindings": {
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff",
"ctrl-shift-y": "agent::KeepAll",
"ctrl-shift-n": "agent::RejectAll"

View File

@@ -279,7 +279,6 @@
"cmd-shift-a": "agent::ToggleContextPicker",
"cmd-shift-j": "agent::ToggleNavigationMenu",
"cmd-shift-i": "agent::ToggleOptionsMenu",
"cmd-alt-shift-n": "agent::ToggleNewThreadMenu",
"shift-alt-escape": "agent::ExpandMessageEditor",
"cmd->": "assistant::QuoteSelection",
"cmd-alt-e": "agent::RemoveAllContext",
@@ -383,6 +382,8 @@
"use_key_equivalents": true,
"bindings": {
"enter": "agent::Chat",
"up": "agent::PreviousHistoryMessage",
"down": "agent::NextHistoryMessage",
"shift-ctrl-r": "agent::OpenAgentDiff",
"cmd-shift-y": "agent::KeepAll",
"cmd-shift-n": "agent::RejectAll"

View File

@@ -58,8 +58,6 @@
"[ space": "vim::InsertEmptyLineAbove",
"[ e": "editor::MoveLineUp",
"] e": "editor::MoveLineDown",
"[ f": "workspace::FollowNextCollaborator",
"] f": "workspace::FollowNextCollaborator",
// Word motions
"w": "vim::NextWordStart",
@@ -392,7 +390,7 @@
"right": "vim::WrappingRight",
"h": "vim::WrappingLeft",
"l": "vim::WrappingRight",
"y": "vim::HelixYank",
"y": "editor::Copy",
"alt-;": "vim::OtherEnd",
"ctrl-r": "vim::Redo",
"f": ["vim::PushFindForward", { "before": false, "multiline": true }],
@@ -409,7 +407,6 @@
"g w": "vim::PushRewrap",
"insert": "vim::InsertBefore",
"alt-.": "vim::RepeatFind",
"alt-s": ["editor::SplitSelectionIntoLines", { "keep_selections": true }],
// tree-sitter related commands
"[ x": "editor::SelectLargerSyntaxNode",
"] x": "editor::SelectSmallerSyntaxNode",

View File

@@ -28,9 +28,7 @@
"edit_prediction_provider": "zed"
},
// The name of a font to use for rendering text in the editor
// ".ZedMono" currently aliases to Lilex
// but this may change in the future.
"buffer_font_family": ".ZedMono",
"buffer_font_family": "Zed Plex Mono",
// Set the buffer text's font fallbacks, this will be merged with
// the platform's default fallbacks.
"buffer_font_fallbacks": null,
@@ -56,9 +54,7 @@
"buffer_line_height": "comfortable",
// The name of a font to use for rendering text in the UI
// You can set this to ".SystemUIFont" to use the system font
// ".ZedSans" currently aliases to "IBM Plex Sans", but this may
// change in the future
"ui_font_family": ".ZedSans",
"ui_font_family": "Zed Plex Sans",
// Set the UI's font fallbacks, this will be merged with the platform's
// default font fallbacks.
"ui_font_fallbacks": null,
@@ -86,10 +82,10 @@
// Layout mode of the bottom dock. Defaults to "contained"
// choices: contained, full, left_aligned, right_aligned
"bottom_dock_layout": "contained",
// The direction that you want to split panes horizontally. Defaults to "down"
"pane_split_direction_horizontal": "down",
// The direction that you want to split panes vertically. Defaults to "right"
"pane_split_direction_vertical": "right",
// The direction that you want to split panes horizontally. Defaults to "up"
"pane_split_direction_horizontal": "up",
// The direction that you want to split panes vertically. Defaults to "left"
"pane_split_direction_vertical": "left",
// Centered layout related settings.
"centered_layout": {
// The relative width of the left padding of the central pane from the
@@ -1406,7 +1402,7 @@
// "font_size": 15,
// Set the terminal's font family. If this option is not included,
// the terminal will default to matching the buffer's font family.
// "font_family": ".ZedMono",
// "font_family": "Zed Plex Mono",
// Set the terminal's font fallbacks. If this option is not included,
// the terminal will default to matching the buffer's font fallbacks.
// This will be merged with the platform's default font fallbacks

View File

@@ -13,35 +13,28 @@ path = "src/acp_thread.rs"
doctest = false
[features]
test-support = ["gpui/test-support", "project/test-support", "dep:parking_lot"]
test-support = ["gpui/test-support", "project/test-support"]
[dependencies]
action_log.workspace = true
agent-client-protocol.workspace = true
agent.workspace = true
anyhow.workspace = true
buffer_diff.workspace = true
collections.workspace = true
editor.workspace = true
file_icons.workspace = true
futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
markdown.workspace = true
parking_lot = { workspace = true, optional = true }
project.workspace = true
prompt_store.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
terminal.workspace = true
ui.workspace = true
url.workspace = true
util.workspace = true
uuid.workspace = true
watch.workspace = true
workspace-hack.workspace = true
[dev-dependencies]

File diff suppressed because it is too large Load Diff

View File

@@ -1,78 +1,18 @@
use crate::AcpThread;
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use gpui::{AsyncApp, Entity, Task};
use language_model::LanguageModel;
use project::Project;
use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
use ui::App;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UserMessageId(Arc<str>);
impl UserMessageId {
pub fn new() -> Self {
Self(Uuid::new_v4().to_string().into())
}
}
pub trait AgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>>;
fn auth_methods(&self) -> &[acp::AuthMethod];
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(
&self,
user_message_id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
fn session_editor(
&self,
_session_id: &acp::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn AgentSessionEditor>> {
None
}
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
/// This allows sharing the selector in UI components.
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
None
}
}
pub trait AgentSessionEditor {
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "AuthRequired")
}
}
use crate::AcpThread;
/// Trait for agents that support listing, selecting, and querying language models.
///
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
pub trait AgentModelSelector: 'static {
pub trait ModelSelector: 'static {
/// Lists all available language models for this agent.
///
/// # Parameters
@@ -80,7 +20,7 @@ pub trait AgentModelSelector: 'static {
///
/// # Returns
/// A task resolving to the list of models or an error (e.g., if no models are configured).
fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
/// Selects a model for a specific session (thread).
///
@@ -97,8 +37,8 @@ pub trait AgentModelSelector: 'static {
fn select_model(
&self,
session_id: acp::SessionId,
model_id: AgentModelId,
cx: &mut App,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>>;
/// Retrieves the currently selected model for a specific session (thread).
@@ -112,203 +52,42 @@ pub trait AgentModelSelector: 'static {
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut App,
) -> Task<Result<AgentModelInfo>>;
/// Whenever the model list is updated the receiver will be notified.
fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>>;
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelId(pub SharedString);
pub trait AgentConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
impl std::ops::Deref for AgentModelId {
type Target = SharedString;
fn auth_methods(&self) -> &[acp::AuthMethod];
fn deref(&self) -> &Self::Target {
&self.0
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
-> Task<Result<acp::PromptResponse>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
/// This allows sharing the selector in UI components.
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
None // Default impl for agents that don't support it
}
}
impl fmt::Display for AgentModelId {
#[derive(Debug)]
pub struct AuthRequired;
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
write!(f, "AuthRequired")
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
pub id: AgentModelId,
pub name: SharedString,
pub icon: Option<IconName>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
#[derive(Debug, Clone)]
pub enum AgentModelList {
Flat(Vec<AgentModelInfo>),
Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
}
impl AgentModelList {
pub fn is_empty(&self) -> bool {
match self {
AgentModelList::Flat(models) => models.is_empty(),
AgentModelList::Grouped(groups) => groups.is_empty(),
}
}
}
#[cfg(feature = "test-support")]
mod test_support {
use std::sync::Arc;
use collections::HashMap;
use futures::future::try_join_all;
use gpui::{AppContext as _, WeakEntity};
use parking_lot::Mutex;
use super::*;
#[derive(Clone, Default)]
pub struct StubAgentConnection {
sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
}
impl StubAgentConnection {
pub fn new() -> Self {
Self {
next_prompt_updates: Default::default(),
permission_requests: HashMap::default(),
sessions: Arc::default(),
}
}
pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
*self.next_prompt_updates.lock() = updates;
}
pub fn with_permission_requests(
mut self,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
) -> Self {
self.permission_requests = permission_requests;
self
}
pub fn send_update(
&self,
session_id: acp::SessionId,
update: acp::SessionUpdate,
cx: &mut App,
) {
self.sessions
.lock()
.get(&session_id)
.unwrap()
.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})
.unwrap();
}
}
impl AgentConnection for StubAgentConnection {
fn auth_methods(&self) -> &[acp::AuthMethod] {
&[]
}
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
fn authenticate(
&self,
_method_id: acp::AuthMethodId,
_cx: &mut App,
) -> Task<gpui::Result<()>> {
unimplemented!()
}
fn prompt(
&self,
_id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
let sessions = self.sessions.lock();
let thread = sessions.get(&params.session_id).unwrap();
let mut tasks = vec![];
for update in self.next_prompt_updates.lock().drain(..) {
let thread = thread.clone();
let update = update.clone();
let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
&& let Some(options) = self.permission_requests.get(&tool_call.id)
{
Some((tool_call.clone(), options.clone()))
} else {
None
};
let task = cx.spawn(async move |cx| {
if let Some((tool_call, options)) = permission_request {
let permission = thread.update(cx, |thread, cx| {
thread.request_tool_call_authorization(
tool_call.clone(),
options.clone(),
cx,
)
})?;
permission.await?;
}
thread.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})?;
anyhow::Ok(())
});
tasks.push(task);
}
cx.spawn(async move |_| {
try_join_all(tasks).await?;
Ok(acp::PromptResponse {
stop_reason: acp::StopReason::EndTurn,
})
})
}
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!()
}
fn session_editor(
&self,
_session_id: &agent_client_protocol::SessionId,
_cx: &mut App,
) -> Option<Rc<dyn AgentSessionEditor>> {
Some(Rc::new(StubAgentSessionEditor))
}
}
struct StubAgentSessionEditor;
impl AgentSessionEditor for StubAgentSessionEditor {
fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
Task::ready(Ok(()))
}
}
}
#[cfg(feature = "test-support")]
pub use test_support::*;

View File

@@ -1,460 +0,0 @@
use agent::ThreadId;
use anyhow::{Context as _, Result, bail};
use file_icons::FileIcons;
use prompt_store::{PromptId, UserPromptId};
use std::{
fmt,
ops::Range,
path::{Path, PathBuf},
str::FromStr,
};
use ui::{App, IconName, SharedString};
use url::Url;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MentionUri {
File {
abs_path: PathBuf,
is_directory: bool,
},
Symbol {
path: PathBuf,
name: String,
line_range: Range<u32>,
},
Thread {
id: ThreadId,
name: String,
},
TextThread {
path: PathBuf,
name: String,
},
Rule {
id: PromptId,
name: String,
},
Selection {
path: PathBuf,
line_range: Range<u32>,
},
Fetch {
url: Url,
},
}
impl MentionUri {
pub fn parse(input: &str) -> Result<Self> {
let url = url::Url::parse(input)?;
let path = url.path();
match url.scheme() {
"file" => {
if let Some(fragment) = url.fragment() {
let range = fragment
.strip_prefix("L")
.context("Line range must start with \"L\"")?;
let (start, end) = range
.split_once(":")
.context("Line range must use colon as separator")?;
let line_range = start
.parse::<u32>()
.context("Parsing line range start")?
.checked_sub(1)
.context("Line numbers should be 1-based")?
..end
.parse::<u32>()
.context("Parsing line range end")?
.checked_sub(1)
.context("Line numbers should be 1-based")?;
if let Some(name) = single_query_param(&url, "symbol")? {
Ok(Self::Symbol {
name,
path: path.into(),
line_range,
})
} else {
Ok(Self::Selection {
path: path.into(),
line_range,
})
}
} else {
let file_path =
PathBuf::from(format!("{}{}", url.host_str().unwrap_or(""), path));
let is_directory = input.ends_with("/");
Ok(Self::File {
abs_path: file_path,
is_directory,
})
}
}
"zed" => {
if let Some(thread_id) = path.strip_prefix("/agent/thread/") {
let name = single_query_param(&url, "name")?.context("Missing thread name")?;
Ok(Self::Thread {
id: thread_id.into(),
name,
})
} else if let Some(path) = path.strip_prefix("/agent/text-thread/") {
let name = single_query_param(&url, "name")?.context("Missing thread name")?;
Ok(Self::TextThread {
path: path.into(),
name,
})
} else if let Some(rule_id) = path.strip_prefix("/agent/rule/") {
let name = single_query_param(&url, "name")?.context("Missing rule name")?;
let rule_id = UserPromptId(rule_id.parse()?);
Ok(Self::Rule {
id: rule_id.into(),
name,
})
} else {
bail!("invalid zed url: {:?}", input);
}
}
"http" | "https" => Ok(MentionUri::Fetch { url }),
other => bail!("unrecognized scheme {:?}", other),
}
}
pub fn name(&self) -> String {
match self {
MentionUri::File { abs_path, .. } => abs_path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.into_owned(),
MentionUri::Symbol { name, .. } => name.clone(),
MentionUri::Thread { name, .. } => name.clone(),
MentionUri::TextThread { name, .. } => name.clone(),
MentionUri::Rule { name, .. } => name.clone(),
MentionUri::Selection {
path, line_range, ..
} => selection_name(path, line_range),
MentionUri::Fetch { url } => url.to_string(),
}
}
pub fn icon_path(&self, cx: &mut App) -> SharedString {
match self {
MentionUri::File {
abs_path,
is_directory,
} => {
if *is_directory {
FileIcons::get_folder_icon(false, cx)
.unwrap_or_else(|| IconName::Folder.path().into())
} else {
FileIcons::get_icon(&abs_path, cx)
.unwrap_or_else(|| IconName::File.path().into())
}
}
MentionUri::Symbol { .. } => IconName::Code.path().into(),
MentionUri::Thread { .. } => IconName::Thread.path().into(),
MentionUri::TextThread { .. } => IconName::Thread.path().into(),
MentionUri::Rule { .. } => IconName::Reader.path().into(),
MentionUri::Selection { .. } => IconName::Reader.path().into(),
MentionUri::Fetch { .. } => IconName::ToolWeb.path().into(),
}
}
pub fn as_link<'a>(&'a self) -> MentionLink<'a> {
MentionLink(self)
}
pub fn to_uri(&self) -> Url {
match self {
MentionUri::File {
abs_path,
is_directory,
} => {
let mut url = Url::parse("file:///").unwrap();
let mut path = abs_path.to_string_lossy().to_string();
if *is_directory && !path.ends_with("/") {
path.push_str("/");
}
url.set_path(&path);
url
}
MentionUri::Symbol {
path,
name,
line_range,
} => {
let mut url = Url::parse("file:///").unwrap();
url.set_path(&path.to_string_lossy());
url.query_pairs_mut().append_pair("symbol", name);
url.set_fragment(Some(&format!(
"L{}:{}",
line_range.start + 1,
line_range.end + 1
)));
url
}
MentionUri::Selection { path, line_range } => {
let mut url = Url::parse("file:///").unwrap();
url.set_path(&path.to_string_lossy());
url.set_fragment(Some(&format!(
"L{}:{}",
line_range.start + 1,
line_range.end + 1
)));
url
}
MentionUri::Thread { name, id } => {
let mut url = Url::parse("zed:///").unwrap();
url.set_path(&format!("/agent/thread/{id}"));
url.query_pairs_mut().append_pair("name", name);
url
}
MentionUri::TextThread { path, name } => {
let mut url = Url::parse("zed:///").unwrap();
url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy()));
url.query_pairs_mut().append_pair("name", name);
url
}
MentionUri::Rule { name, id } => {
let mut url = Url::parse("zed:///").unwrap();
url.set_path(&format!("/agent/rule/{id}"));
url.query_pairs_mut().append_pair("name", name);
url
}
MentionUri::Fetch { url } => url.clone(),
}
}
}
impl FromStr for MentionUri {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
Self::parse(s)
}
}
pub struct MentionLink<'a>(&'a MentionUri);
impl fmt::Display for MentionLink<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "[@{}]({})", self.0.name(), self.0.to_uri())
}
}
fn single_query_param(url: &Url, name: &'static str) -> Result<Option<String>> {
let pairs = url.query_pairs().collect::<Vec<_>>();
match pairs.as_slice() {
[] => Ok(None),
[(k, v)] => {
if k != name {
bail!("invalid query parameter")
}
Ok(Some(v.to_string()))
}
_ => bail!("too many query pairs"),
}
}
pub fn selection_name(path: &Path, line_range: &Range<u32>) -> String {
format!(
"{} ({}:{})",
path.file_name().unwrap_or_default().display(),
line_range.start + 1,
line_range.end + 1
)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_file_uri() {
let file_uri = "file:///path/to/file.rs";
let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed {
MentionUri::File {
abs_path,
is_directory,
} => {
assert_eq!(abs_path.to_str().unwrap(), "/path/to/file.rs");
assert!(!is_directory);
}
_ => panic!("Expected File variant"),
}
assert_eq!(parsed.to_uri().to_string(), file_uri);
}
#[test]
fn test_parse_directory_uri() {
let file_uri = "file:///path/to/dir/";
let parsed = MentionUri::parse(file_uri).unwrap();
match &parsed {
MentionUri::File {
abs_path,
is_directory,
} => {
assert_eq!(abs_path.to_str().unwrap(), "/path/to/dir/");
assert!(is_directory);
}
_ => panic!("Expected File variant"),
}
assert_eq!(parsed.to_uri().to_string(), file_uri);
}
#[test]
fn test_to_directory_uri_with_slash() {
let uri = MentionUri::File {
abs_path: PathBuf::from("/path/to/dir/"),
is_directory: true,
};
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
}
#[test]
fn test_to_directory_uri_without_slash() {
let uri = MentionUri::File {
abs_path: PathBuf::from("/path/to/dir"),
is_directory: true,
};
assert_eq!(uri.to_uri().to_string(), "file:///path/to/dir/");
}
#[test]
fn test_parse_symbol_uri() {
let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20";
let parsed = MentionUri::parse(symbol_uri).unwrap();
match &parsed {
MentionUri::Symbol {
path,
name,
line_range,
} => {
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
assert_eq!(name, "MySymbol");
assert_eq!(line_range.start, 9);
assert_eq!(line_range.end, 19);
}
_ => panic!("Expected Symbol variant"),
}
assert_eq!(parsed.to_uri().to_string(), symbol_uri);
}
#[test]
fn test_parse_selection_uri() {
let selection_uri = "file:///path/to/file.rs#L5:15";
let parsed = MentionUri::parse(selection_uri).unwrap();
match &parsed {
MentionUri::Selection { path, line_range } => {
assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
assert_eq!(line_range.start, 4);
assert_eq!(line_range.end, 14);
}
_ => panic!("Expected Selection variant"),
}
assert_eq!(parsed.to_uri().to_string(), selection_uri);
}
#[test]
fn test_parse_thread_uri() {
let thread_uri = "zed:///agent/thread/session123?name=Thread+name";
let parsed = MentionUri::parse(thread_uri).unwrap();
match &parsed {
MentionUri::Thread {
id: thread_id,
name,
} => {
assert_eq!(thread_id.to_string(), "session123");
assert_eq!(name, "Thread name");
}
_ => panic!("Expected Thread variant"),
}
assert_eq!(parsed.to_uri().to_string(), thread_uri);
}
#[test]
fn test_parse_rule_uri() {
let rule_uri = "zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52?name=Some+rule";
let parsed = MentionUri::parse(rule_uri).unwrap();
match &parsed {
MentionUri::Rule { id, name } => {
assert_eq!(id.to_string(), "d8694ff2-90d5-4b6f-be33-33c1763acd52");
assert_eq!(name, "Some rule");
}
_ => panic!("Expected Rule variant"),
}
assert_eq!(parsed.to_uri().to_string(), rule_uri);
}
#[test]
fn test_parse_fetch_http_uri() {
let http_uri = "http://example.com/path?query=value#fragment";
let parsed = MentionUri::parse(http_uri).unwrap();
match &parsed {
MentionUri::Fetch { url } => {
assert_eq!(url.to_string(), http_uri);
}
_ => panic!("Expected Fetch variant"),
}
assert_eq!(parsed.to_uri().to_string(), http_uri);
}
#[test]
fn test_parse_fetch_https_uri() {
let https_uri = "https://example.com/api/endpoint";
let parsed = MentionUri::parse(https_uri).unwrap();
match &parsed {
MentionUri::Fetch { url } => {
assert_eq!(url.to_string(), https_uri);
}
_ => panic!("Expected Fetch variant"),
}
assert_eq!(parsed.to_uri().to_string(), https_uri);
}
#[test]
fn test_invalid_scheme() {
assert!(MentionUri::parse("ftp://example.com").is_err());
assert!(MentionUri::parse("ssh://example.com").is_err());
assert!(MentionUri::parse("unknown://example.com").is_err());
}
#[test]
fn test_invalid_zed_path() {
assert!(MentionUri::parse("zed:///invalid/path").is_err());
assert!(MentionUri::parse("zed:///agent/unknown/test").is_err());
}
#[test]
fn test_invalid_line_range_format() {
// Missing L prefix
assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err());
// Missing colon separator
assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err());
// Invalid numbers
assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err());
}
#[test]
fn test_invalid_query_parameters() {
// Invalid query parameter name
assert!(MentionUri::parse("file:///path/to/file.rs#L10:20?invalid=test").is_err());
// Too many query parameters
assert!(
MentionUri::parse("file:///path/to/file.rs#L10:20?symbol=test&another=param").is_err()
);
}
#[test]
fn test_zero_based_line_numbers() {
// Test that 0-based line numbers are rejected (should be 1-based)
assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err());
assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err());
}
}

View File

@@ -716,10 +716,18 @@ impl ActivityIndicator {
})),
tooltip_message: Some(Self::version_tooltip_message(&version)),
}),
AutoUpdateStatus::Updated { version } => Some(Content {
AutoUpdateStatus::Updated {
binary_path,
version,
} => Some(Content {
icon: None,
message: "Click to restart and update Zed".to_string(),
on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))),
on_click: Some(Arc::new({
let reload = workspace::Reload {
binary_path: Some(binary_path.clone()),
};
move |_, _, cx| workspace::reload(&reload, cx)
})),
tooltip_message: Some(Self::version_tooltip_message(&version)),
}),
AutoUpdateStatus::Errored => Some(Content {

View File

@@ -813,6 +813,7 @@ impl Thread {
}
fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
dbg!("finalize_pending_checkpoint");
let pending_checkpoint = if self.is_generating() {
return;
} else if let Some(checkpoint) = self.pending_checkpoint.take() {
@@ -829,10 +830,13 @@ impl Thread {
pending_checkpoint: ThreadCheckpoint,
cx: &mut Context<Self>,
) {
dbg!("finalize_checkpoint");
let git_store = self.project.read(cx).git_store().clone();
let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
cx.spawn(async move |this, cx| match final_checkpoint.await {
Ok(final_checkpoint) => {
dbg!(&pending_checkpoint.git_checkpoint);
dbg!(&final_checkpoint);
let equal = git_store
.update(cx, |store, cx| {
store.compare_checkpoints(
@@ -844,7 +848,7 @@ impl Thread {
.await
.unwrap_or(false);
if !equal {
if dbg!(!equal) {
this.update(cx, |this, cx| {
this.insert_checkpoint(pending_checkpoint, cx)
})?;
@@ -860,6 +864,7 @@ impl Thread {
}
fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
dbg!("insert_checkpoint");
self.checkpoints_by_message
.insert(checkpoint.message_id, checkpoint);
cx.emit(ThreadEvent::CheckpointChanged);
@@ -867,6 +872,7 @@ impl Thread {
}
pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
dbg!();
self.last_restore_checkpoint.as_ref()
}
@@ -2268,15 +2274,6 @@ impl Thread {
max_attempts: 3,
})
}
Other(err)
if err.is::<PaymentRequiredError>()
|| err.is::<ModelRequestLimitReachedError>() =>
{
// Retrying won't help for Payment Required or Model Request Limit errors (where
// the user must upgrade to usage-based billing to get more requests, or else wait
// for a significant amount of time for the request limit to reset).
None
}
// Conservatively assume that any other errors are non-retryable
HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed {
delay: BASE_RETRY_DELAY,

View File

@@ -205,22 +205,6 @@ impl ThreadStore {
(this, ready_rx)
}
#[cfg(any(test, feature = "test-support"))]
pub fn fake(project: Entity<Project>, cx: &mut App) -> Self {
Self {
project,
tools: cx.new(|_| ToolWorkingSet::default()),
prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
prompt_store: None,
context_server_tool_ids: HashMap::default(),
threads: Vec::new(),
project_context: SharedProjectContext::default(),
reload_system_prompt_tx: mpsc::channel(0).0,
_reload_system_prompt_task: Task::ready(()),
_subscriptions: vec![],
}
}
fn handle_project_event(
&mut self,
_project: Entity<Project>,

View File

@@ -23,7 +23,6 @@ assistant_tools.workspace = true
chrono.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
context_server.workspace = true
fs.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -49,7 +48,6 @@ settings.workspace = true
smol.workspace = true
task.workspace = true
terminal.workspace = true
text.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
@@ -62,7 +60,6 @@ workspace-hack.workspace = true
ctor.workspace = true
client = { workspace = true, "features" = ["test-support"] }
clock = { workspace = true, "features" = ["test-support"] }
context_server = { workspace = true, "features" = ["test-support"] }
editor = { workspace = true, "features" = ["test-support"] }
env_logger.workspace = true
fs = { workspace = true, "features" = ["test-support"] }

View File

@@ -1,26 +1,21 @@
use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
WebSearchTool,
CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool,
GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool,
ThinkingTool, ToolCallAuthorization, WebSearchTool,
};
use acp_thread::AgentModelSelector;
use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
};
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
use language_model::{LanguageModel, LanguageModelRegistry};
use project::{Project, ProjectItem, ProjectPath, Worktree};
use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
@@ -53,104 +48,6 @@ struct Session {
_subscription: Subscription,
}
pub struct LanguageModels {
/// Access language model by ID
models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
/// Cached list for returning language model information
model_list: acp_thread::AgentModelList,
refresh_models_rx: watch::Receiver<()>,
refresh_models_tx: watch::Sender<()>,
}
impl LanguageModels {
fn new(cx: &App) -> Self {
let (refresh_models_tx, refresh_models_rx) = watch::channel(());
let mut this = Self {
models: HashMap::default(),
model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
refresh_models_rx,
refresh_models_tx,
};
this.refresh_list(cx);
this
}
fn refresh_list(&mut self, cx: &App) {
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
let mut language_model_list = IndexMap::default();
let mut recommended_models = HashSet::default();
let mut recommended = Vec::new();
for provider in &providers {
for model in provider.recommended_models(cx) {
recommended_models.insert(model.id());
recommended.push(Self::map_language_model_to_info(&model, &provider));
}
}
if !recommended.is_empty() {
language_model_list.insert(
acp_thread::AgentModelGroupName("Recommended".into()),
recommended,
);
}
let mut models = HashMap::default();
for provider in providers {
let mut provider_models = Vec::new();
for model in provider.provided_models(cx) {
let model_info = Self::map_language_model_to_info(&model, &provider);
let model_id = model_info.id.clone();
if !recommended_models.contains(&model.id()) {
provider_models.push(model_info);
}
models.insert(model_id, model);
}
if !provider_models.is_empty() {
language_model_list.insert(
acp_thread::AgentModelGroupName(provider.name().0.clone()),
provider_models,
);
}
}
self.models = models;
self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
self.refresh_models_tx.send(()).ok();
}
fn watch(&self) -> watch::Receiver<()> {
self.refresh_models_rx.clone()
}
pub fn model_from_id(
&self,
model_id: &acp_thread::AgentModelId,
) -> Option<Arc<dyn LanguageModel>> {
self.models.get(model_id).cloned()
}
fn map_language_model_to_info(
model: &Arc<dyn LanguageModel>,
provider: &Arc<dyn LanguageModelProvider>,
) -> acp_thread::AgentModelInfo {
acp_thread::AgentModelInfo {
id: Self::model_id(model),
name: model.name().0,
icon: Some(provider.icon()),
}
}
fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
}
}
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
@@ -158,14 +55,10 @@ pub struct NativeAgent {
project_context: Rc<RefCell<ProjectContext>>,
project_context_needs_refresh: watch::Sender<()>,
_maintain_project_context: Task<Result<()>>,
context_server_registry: Entity<ContextServerRegistry>,
/// Shared templates for all threads
templates: Arc<Templates>,
/// Cached model information
models: LanguageModels,
project: Entity<Project>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
_subscriptions: Vec<Subscription>,
}
@@ -174,7 +67,6 @@ impl NativeAgent {
project: Entity<Project>,
templates: Arc<Templates>,
prompt_store: Option<Entity<PromptStore>>,
fs: Arc<dyn Fs>,
cx: &mut AsyncApp,
) -> Result<Entity<NativeAgent>> {
log::info!("Creating new NativeAgent");
@@ -184,13 +76,7 @@ impl NativeAgent {
.await;
cx.new(|cx| {
let mut subscriptions = vec![
cx.subscribe(&project, Self::handle_project_event),
cx.subscribe(
&LanguageModelRegistry::global(cx),
Self::handle_models_updated_event,
),
];
let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
if let Some(prompt_store) = prompt_store.as_ref() {
subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
}
@@ -204,23 +90,14 @@ impl NativeAgent {
_maintain_project_context: cx.spawn(async move |this, cx| {
Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
}),
context_server_registry: cx.new(|cx| {
ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
}),
templates,
models: LanguageModels::new(cx),
project,
prompt_store,
fs,
_subscriptions: subscriptions,
}
})
}
pub fn models(&self) -> &LanguageModels {
&self.models
}
async fn maintain_project_context(
this: WeakEntity<Self>,
mut needs_refresh: watch::Receiver<()>,
@@ -416,104 +293,75 @@ impl NativeAgent {
) {
self.project_context_needs_refresh.send(()).ok();
}
fn handle_models_updated_event(
&mut self,
_registry: Entity<LanguageModelRegistry>,
_event: &language_model::Event,
cx: &mut Context<Self>,
) {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
let model_id = LanguageModels::model_id(&thread.selected_model);
if let Some(model) = self.models.model_from_id(&model_id) {
thread.selected_model = model.clone();
}
});
}
}
}
/// Wrapper struct that implements the AgentConnection trait
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
impl AgentModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
impl ModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
log::debug!("NativeAgentConnection::list_models called");
let list = self.0.read(cx).models.model_list.clone();
Task::ready(if list.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(list)
cx.spawn(async move |cx| {
cx.update(|cx| {
let registry = LanguageModelRegistry::read_global(cx);
let models = registry.available_models(cx).collect::<Vec<_>>();
log::info!("Found {} available models", models.len());
if models.is_empty() {
Err(anyhow::anyhow!("No models available"))
} else {
Ok(models)
}
})?
})
}
fn select_model(
&self,
session_id: acp::SessionId,
model_id: acp_thread::AgentModelId,
cx: &mut App,
model: Arc<dyn LanguageModel>,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
log::info!("Setting model for session {}: {}", session_id, model_id);
let Some(thread) = self
.0
.read(cx)
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
};
thread.update(cx, |thread, _cx| {
thread.selected_model = model.clone();
});
update_settings_file::<AgentSettings>(
self.0.read(cx).fs.clone(),
cx,
move |settings, _cx| {
settings.set_model(model);
},
log::info!(
"Setting model for session {}: {:?}",
session_id,
model.name()
);
let agent = self.0.clone();
Task::ready(Ok(()))
cx.spawn(async move |cx| {
agent.update(cx, |agent, cx| {
if let Some(session) = agent.sessions.get(&session_id) {
session.thread.update(cx, |thread, _cx| {
thread.selected_model = model;
});
Ok(())
} else {
Err(anyhow!("Session not found"))
}
})?
})
}
fn selected_model(
&self,
session_id: &acp::SessionId,
cx: &mut App,
) -> Task<Result<acp_thread::AgentModelInfo>> {
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
let session_id = session_id.clone();
let Some(thread) = self
.0
.read(cx)
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
else {
return Task::ready(Err(anyhow!("Session not found")));
};
let model = thread.read(cx).selected_model.clone();
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
else {
return Task::ready(Err(anyhow!("Provider not found")));
};
Task::ready(Ok(LanguageModels::map_language_model_to_info(
&model, &provider,
)))
}
fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
self.0.read(cx).models.watch()
cx.spawn(async move |cx| {
let thread = agent
.read_with(cx, |agent, _| {
agent
.sessions
.get(&session_id)
.map(|session| session.thread.clone())
})?
.ok_or_else(|| anyhow::anyhow!("Session not found"))?;
let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
Ok(selected)
})
}
}
@@ -522,7 +370,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut App,
cx: &mut AsyncApp,
) -> Task<Result<Entity<acp_thread::AcpThread>>> {
let agent = self.0.clone();
log::info!("Creating new thread for project at: {:?}", cwd);
@@ -537,13 +385,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Create AcpThread
let acp_thread = cx.update(|cx| {
cx.new(|cx| {
acp_thread::AcpThread::new(
"agent2",
self.clone(),
project.clone(),
session_id.clone(),
cx,
)
acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
})
})?;
let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
@@ -561,44 +403,37 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
let default_model = registry
.default_model()
.and_then(|default_model| {
agent
.models
.model_from_id(&LanguageModels::model_id(&default_model.model))
.map(|configured| {
log::info!(
"Using configured default model: {:?} from provider: {:?}",
configured.model.name(),
configured.provider.name()
);
configured.model
})
.ok_or_else(|| {
log::warn!("No default model configured in settings");
anyhow!(
"No default model. Please configure a default model in settings."
)
anyhow!("No default model configured. Please configure a default model in settings.")
})?;
let thread = cx.new(|cx| {
let mut thread = Thread::new(
project.clone(),
agent.project_context.clone(),
agent.context_server_registry.clone(),
action_log.clone(),
agent.templates.clone(),
default_model,
cx,
);
thread.add_tool(CopyPathTool::new(project.clone()));
let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
thread.add_tool(CreateDirectoryTool::new(project.clone()));
thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
thread.add_tool(CopyPathTool::new(project.clone()));
thread.add_tool(DiagnosticsTool::new(project.clone()));
thread.add_tool(EditFileTool::new(cx.entity()));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(GrepTool::new(project.clone()));
thread.add_tool(ListDirectoryTool::new(project.clone()));
thread.add_tool(MovePathTool::new(project.clone()));
thread.add_tool(NowTool);
thread.add_tool(ListDirectoryTool::new(project.clone()));
thread.add_tool(OpenTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(TerminalTool::new(project.clone(), cx));
thread.add_tool(ThinkingTool);
thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
thread.add_tool(FindPathTool::new(project.clone()));
thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
thread.add_tool(GrepTool::new(project.clone()));
thread.add_tool(ReadFileTool::new(project.clone(), action_log));
thread.add_tool(EditFileTool::new(cx.entity()));
thread.add_tool(NowTool);
thread.add_tool(TerminalTool::new(project.clone(), cx));
// TODO: Needs to be conditional based on zed model or not
thread.add_tool(WebSearchTool);
thread
});
@@ -615,7 +450,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
acp_thread: acp_thread.downgrade(),
_subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
this.sessions.remove(acp_thread.session_id());
}),
})
},
);
})?;
@@ -632,17 +467,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
Task::ready(Ok(()))
}
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
}
fn prompt(
&self,
id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
@@ -663,14 +496,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})?;
log::debug!("Found session for: {}", session_id);
let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
.map(Into::into)
.collect::<Vec<_>>();
log::info!("Converted prompt to message: {} chars", content.len());
log::debug!("Message id: {:?}", id);
log::debug!("Message content: {:?}", content);
// Convert prompt to message
let message = convert_prompt_to_message(params.prompt);
log::info!("Converted prompt to message: {} chars", message.len());
log::debug!("Message content: {}", message);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
@@ -678,8 +507,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
let mut response_stream =
thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
@@ -773,33 +601,44 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}
});
}
fn session_editor(
&self,
session_id: &agent_client_protocol::SessionId,
cx: &mut App,
) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
self.0.update(cx, |agent, _cx| {
agent
.sessions
.get(session_id)
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
})
}
}
struct NativeAgentSessionEditor(Entity<Thread>);
/// Convert ACP content blocks to a message string
fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
log::debug!("Converting {} content blocks to message", blocks.len());
let mut message = String::new();
impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
for block in blocks {
match block {
acp::ContentBlock::Text(text) => {
log::trace!("Processing text block: {} chars", text.text.len());
message.push_str(&text.text);
}
acp::ContentBlock::ResourceLink(link) => {
log::trace!("Processing resource link: {}", link.uri);
message.push_str(&format!(" @{} ", link.uri));
}
acp::ContentBlock::Image(_) => {
log::trace!("Processing image block");
message.push_str(" [image] ");
}
acp::ContentBlock::Audio(_) => {
log::trace!("Processing audio block");
message.push_str(" [audio] ");
}
acp::ContentBlock::Resource(resource) => {
log::trace!("Processing resource block: {:?}", resource.resource);
message.push_str(&format!(" [resource: {:?}] ", resource.resource));
}
}
}
message
}
#[cfg(test)]
mod tests {
use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
use fs::FakeFs;
use gpui::TestAppContext;
use serde_json::json;
@@ -817,15 +656,9 @@ mod tests {
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
.await
.unwrap();
agent.read_with(cx, |agent, _| {
assert_eq!(agent.project_context.borrow().worktrees, vec![])
});
@@ -866,127 +699,13 @@ mod tests {
});
}
#[gpui::test]
async fn test_listing_models(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/", json!({ "a": {} })).await;
let project = Project::test(fs.clone(), [], cx).await;
let connection = NativeAgentConnection(
NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap(),
);
let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
let acp_thread::AgentModelList::Grouped(models) = models else {
panic!("Unexpected model group");
};
assert_eq!(
models,
IndexMap::from_iter([(
AgentModelGroupName("Fake".into()),
vec![AgentModelInfo {
id: AgentModelId("fake/fake".into()),
name: "Fake".into(),
icon: Some(ui::IconName::ZedAssistant),
}]
)])
);
}
#[gpui::test]
async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"default_model": {
"provider": "foo",
"model": "bar"
}
}
})
.to_string()
.into_bytes(),
)
.await;
let project = Project::test(fs.clone(), [], cx).await;
// Create the agent and connection
let agent = NativeAgent::new(
project.clone(),
Templates::new(),
None,
fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Create a thread/session
let acp_thread = cx
.update(|cx| {
Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
})
.await
.unwrap();
let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
// Select a model
let model_id = AgentModelId("fake/fake".into());
cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
.await
.unwrap();
// Verify the thread has the selected model
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
assert_eq!(thread.selected_model.id().0, "fake");
});
});
cx.run_until_parked();
// Verify settings file was updated
let settings_content = fs.load(paths::settings_file()).await.unwrap();
let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
// Check that the agent settings contain the selected model
assert_eq!(
settings_json["agent"]["default_model"]["model"],
json!("fake")
);
assert_eq!(
settings_json["agent"]["default_model"]["provider"],
json!("fake")
);
}
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
Project::init_settings(cx);
agent_settings::init(cx);
language::init(cx);
LanguageModelRegistry::test(cx);
});
}
}

View File

@@ -1,8 +1,8 @@
use std::{path::Path, rc::Rc, sync::Arc};
use std::path::Path;
use std::rc::Rc;
use agent_servers::AgentServer;
use anyhow::Result;
use fs::Fs;
use gpui::{App, Entity, Task};
use project::Project;
use prompt_store::PromptStore;
@@ -10,15 +10,7 @@ use prompt_store::PromptStore;
use crate::{NativeAgent, NativeAgentConnection, templates::Templates};
#[derive(Clone)]
pub struct NativeAgentServer {
fs: Arc<dyn Fs>,
}
impl NativeAgentServer {
pub fn new(fs: Arc<dyn Fs>) -> Self {
Self { fs }
}
}
pub struct NativeAgentServer;
impl AgentServer for NativeAgentServer {
fn name(&self) -> &'static str {
@@ -49,7 +41,6 @@ impl AgentServer for NativeAgentServer {
_root_dir
);
let project = project.clone();
let fs = self.fs.clone();
let prompt_store = PromptStore::global(cx);
cx.spawn(async move |cx| {
log::debug!("Creating templates for native agent");
@@ -57,7 +48,7 @@ impl AgentServer for NativeAgentServer {
let prompt_store = prompt_store.await?;
log::debug!("Creating native agent entity");
let agent = NativeAgent::new(project, templates, Some(prompt_store), fs, cx).await?;
let agent = NativeAgent::new(project, templates, Some(prompt_store), cx).await?;
// Create the connection wrapper
let connection = NativeAgentConnection(agent);

View File

@@ -1,8 +1,7 @@
use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use acp_thread::AgentConnection;
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
@@ -13,8 +12,8 @@ use gpui::{
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
fake_provider::FakeLanguageModel,
LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
StopReason, fake_provider::FakeLanguageModel,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -37,19 +36,15 @@ async fn test_echo(cx: &mut TestAppContext) {
let events = thread
.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
thread.send("Testing: Reply with 'Hello'", cx)
})
.collect()
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.last_message().unwrap().to_markdown(),
indoc! {"
## Assistant
Hello
"}
)
thread.messages().last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
@@ -62,13 +57,12 @@ async fn test_thinking(cx: &mut TestAppContext) {
let events = thread
.update(cx, |thread, cx| {
thread.send(
UserMessageId::new(),
[indoc! {"
indoc! {"
Testing:
Generate a thinking step where you just think the word 'Think',
and have your final answer be 'Hello'
"}],
"},
cx,
)
})
@@ -76,10 +70,9 @@ async fn test_thinking(cx: &mut TestAppContext) {
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.last_message().unwrap().to_markdown(),
thread.messages().last().unwrap().to_markdown(),
indoc! {"
## Assistant
## assistant
<think>Think</think>
Hello
"}
@@ -100,9 +93,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["abc"], cx)
});
thread.update(cx, |thread, cx| thread.send("abc", cx));
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
@@ -139,8 +130,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(
UserMessageId::new(),
["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
"Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
cx,
)
})
@@ -154,11 +144,7 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
thread.remove_tool(&AgentTool::name(&EchoTool));
thread.add_tool(DelayTool);
thread.send(
UserMessageId::new(),
[
"Now call the delay tool with 200ms.",
"When the timer goes off, then you echo the output of the tool.",
],
"Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
cx,
)
})
@@ -168,21 +154,18 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
thread.update(cx, |thread, _cx| {
assert!(
thread
.last_message()
.unwrap()
.as_agent_message()
.messages()
.last()
.unwrap()
.content
.iter()
.any(|content| {
if let AgentMessageContent::Text(text) = content {
if let MessageContent::Text(text) = content {
text.contains("Ding")
} else {
false
}
}),
"{}",
thread.to_markdown()
})
);
});
}
@@ -195,7 +178,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
// Test a tool call that's likely to complete *before* streaming stops.
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(WordListTool);
thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
thread.send("Test the word_list tool.", cx)
});
let mut saw_partial_tool_use = false;
@@ -203,10 +186,8 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
let message = thread.last_message().unwrap();
let agent_message = message.as_agent_message().unwrap();
let last_content = agent_message.content.last().unwrap();
if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
let last_content = thread.messages().last().unwrap().content.last().unwrap();
if let MessageContent::ToolUse(last_tool_use) = last_content {
assert_eq!(last_tool_use.name.as_ref(), "word_list");
if tool_call.status == acp::ToolCallStatus::Pending {
if !last_tool_use.is_input_complete
@@ -244,7 +225,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission);
thread.send(UserMessageId::new(), ["abc"], cx)
thread.send("abc", cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -288,14 +269,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
assert_eq!(
message.content,
vec![
language_model::MessageContent::ToolResult(LanguageModelToolResult {
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
}),
language_model::MessageContent::ToolResult(LanguageModelToolResult {
MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: true,
@@ -328,15 +309,13 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![language_model::MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
}
)]
vec![MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
})]
);
// Simulate a final tool call, ensuring we don't trigger authorization.
@@ -355,15 +334,13 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let message = completion.messages.last().unwrap();
assert_eq!(
message.content,
vec![language_model::MessageContent::ToolResult(
LanguageModelToolResult {
tool_use_id: "tool_id_4".into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
}
)]
vec![MessageContent::ToolResult(LanguageModelToolResult {
tool_use_id: "tool_id_4".into(),
tool_name: ToolRequiringPermission.name().into(),
is_error: false,
content: "Allowed".into(),
output: Some("Allowed".into())
})]
);
}
@@ -372,9 +349,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["abc"], cx)
});
let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -466,12 +441,7 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| {
thread.add_tool(DelayTool);
thread.send(
UserMessageId::new(),
[
"Call the delay tool twice in the same message.",
"Once with 100ms. Once with 300ms.",
"When both timers are complete, describe the outputs.",
],
"Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
cx,
)
})
@@ -482,13 +452,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
thread.update(cx, |thread, _cx| {
let last_message = thread.last_message().unwrap();
let agent_message = last_message.as_agent_message().unwrap();
let text = agent_message
let last_message = thread.messages().last().unwrap();
let text = last_message
.content
.iter()
.filter_map(|content| {
if let AgentMessageContent::Text(text) = content {
if let MessageContent::Text(text) = content {
Some(text.as_str())
} else {
None
@@ -500,82 +469,6 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
let ThreadTest {
model, thread, fs, ..
} = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
thread.update(cx, |thread, _cx| {
thread.add_tool(DelayTool);
thread.add_tool(EchoTool);
thread.add_tool(InfiniteTool);
});
// Override profiles and wait for settings to be loaded.
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"profiles": {
"test-1": {
"name": "Test Profile 1",
"tools": {
EchoTool.name(): true,
DelayTool.name(): true,
}
},
"test-2": {
"name": "Test Profile 2",
"tools": {
InfiniteTool.name(): true,
}
}
}
}
})
.to_string()
.into_bytes(),
)
.await;
cx.run_until_parked();
// Test that test-1 profile (default) has echo and delay tools
thread.update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-1".into()));
thread.send(UserMessageId::new(), ["test"], cx);
});
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(pending_completions.len(), 1);
let completion = pending_completions.pop().unwrap();
let tool_names: Vec<String> = completion
.tools
.iter()
.map(|tool| tool.name.clone())
.collect();
assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
fake_model.end_last_completion_stream();
// Switch to test-2 profile, and verify that it has only the infinite tool.
thread.update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-2".into()));
thread.send(UserMessageId::new(), ["test2"], cx)
});
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(pending_completions.len(), 1);
let completion = pending_completions.pop().unwrap();
let tool_names: Vec<String> = completion
.tools
.iter()
.map(|tool| tool.name.clone())
.collect();
assert_eq!(tool_names, vec![InfiniteTool.name()]);
}
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
@@ -585,8 +478,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
thread.add_tool(InfiniteTool);
thread.add_tool(EchoTool);
thread.send(
UserMessageId::new(),
["Call the echo tool, then call the infinite tool, then explain their output"],
"Call the echo tool and then call the infinite tool, then explain their output",
cx,
)
});
@@ -631,20 +523,14 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Ensure we can still send a new message after cancellation.
let events = thread
.update(cx, |thread, cx| {
thread.send(
UserMessageId::new(),
["Testing: reply with 'Hello' then stop."],
cx,
)
thread.send("Testing: reply with 'Hello' then stop.", cx)
})
.collect::<Vec<_>>()
.await;
thread.update(cx, |thread, _cx| {
let message = thread.last_message().unwrap();
let agent_message = message.as_agent_message().unwrap();
assert_eq!(
agent_message.content,
vec![AgentMessageContent::Text("Hello".to_string())]
thread.messages().last().unwrap().content,
vec![MessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
@@ -655,16 +541,13 @@ async fn test_refusal(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hello"], cx)
});
let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
## user
Hello
"}
);
@@ -676,12 +559,9 @@ async fn test_refusal(cx: &mut TestAppContext) {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
## user
Hello
## Assistant
## assistant
Hey!
"}
);
@@ -697,85 +577,6 @@ async fn test_refusal(cx: &mut TestAppContext) {
});
}
#[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
let message_id = UserMessageId::new();
thread.update(cx, |thread, cx| {
thread.send(message_id.clone(), ["Hello"], cx)
});
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello
"}
);
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hello
## Assistant
Hey!
"}
);
});
thread
.update(cx, |thread, _cx| thread.truncate(message_id))
.unwrap();
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(thread.to_markdown(), "");
});
// Ensure we can still send a new message after truncation.
thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Hi"], cx)
});
thread.update(cx, |thread, _cx| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hi
"}
);
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
## User
Hi
## Assistant
Ahoy!
"}
);
});
}
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
@@ -794,26 +595,19 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
language_models::init(user_store.clone(), client.clone(), cx);
Project::init_settings(cx);
LanguageModelRegistry::test(cx);
agent_settings::init(cx);
});
cx.executor().forbid_parking();
// Create a project for new_thread
let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
fake_fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
let cwd = Path::new("/test");
// Create agent and connection
let agent = NativeAgent::new(
project.clone(),
templates.clone(),
None,
fake_fs.clone(),
&mut cx.to_async(),
)
.await
.unwrap();
let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
.await
.unwrap();
let connection = NativeAgentConnection(agent.clone());
// Test model_selector returns Some
@@ -826,22 +620,22 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Test list_models
let listed_models = cx
.update(|cx| selector.list_models(cx))
.update(|cx| {
let mut async_cx = cx.to_async();
selector.list_models(&mut async_cx)
})
.await
.expect("list_models should succeed");
let AgentModelList::Grouped(listed_models) = listed_models else {
panic!("Unexpected model list type");
};
assert!(!listed_models.is_empty(), "should have at least one model");
assert_eq!(
listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
"fake/fake"
);
assert_eq!(listed_models[0].id().0, "fake");
// Create a thread using new_thread
let connection_rc = Rc::new(connection.clone());
let acp_thread = cx
.update(|cx| connection_rc.new_thread(project, cwd, cx))
.update(|cx| {
let mut async_cx = cx.to_async();
connection_rc.new_thread(project, cwd, &mut async_cx)
})
.await
.expect("new_thread should succeed");
@@ -850,12 +644,12 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
// Test selected_model returns the default
let model = cx
.update(|cx| selector.selected_model(&session_id, cx))
.update(|cx| {
let mut async_cx = cx.to_async();
selector.selected_model(&session_id, &mut async_cx)
})
.await
.expect("selected_model should succeed");
let model = cx
.update(|cx| agent.read(cx).models().model_from_id(&model.id))
.unwrap();
let model = model.as_fake();
assert_eq!(model.id().0, "fake", "should return default model");
@@ -889,7 +683,6 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let result = cx
.update(|cx| {
connection.prompt(
Some(acp_thread::UserMessageId::new()),
acp::PromptRequest {
session_id: session_id.clone(),
prompt: vec!["ghi".into()],
@@ -912,9 +705,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
let fake_model = model.as_fake();
let mut events = thread.update(cx, |thread, cx| {
thread.send(UserMessageId::new(), ["Think"], cx)
});
let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
cx.run_until_parked();
// Simulate streaming partial input.
@@ -999,7 +790,6 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
id: acp::ToolCallId("1".into()),
fields: acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::Completed),
raw_output: Some("Finished thinking.".into()),
..Default::default()
},
}
@@ -1023,7 +813,6 @@ struct ThreadTest {
model: Arc<dyn LanguageModel>,
thread: Entity<Thread>,
project_context: Rc<RefCell<ProjectContext>>,
fs: Arc<FakeFs>,
}
enum TestModel {
@@ -1046,57 +835,30 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.executor().allow_parking();
let fs = FakeFs::new(cx.background_executor.clone());
fs.create_dir(paths::settings_file().parent().unwrap())
.await
.unwrap();
fs.insert_file(
paths::settings_file(),
json!({
"agent": {
"default_profile": "test-profile",
"profiles": {
"test-profile": {
"name": "Test Profile",
"tools": {
EchoTool.name(): true,
DelayTool.name(): true,
WordListTool.name(): true,
ToolRequiringPermission.name(): true,
InfiniteTool.name(): true,
}
}
}
}
})
.to_string()
.into_bytes(),
)
.await;
cx.update(|cx| {
settings::init(cx);
watch_settings(fs.clone(), cx);
Project::init_settings(cx);
agent_settings::init(cx);
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
watch_settings(fs.clone(), cx);
});
let templates = Templates::new();
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
let model = cx
.update(|cx| {
gpui_tokio::init(cx);
let http_client = ReqwestClient::user_agent("agent tests").unwrap();
cx.set_http_client(Arc::new(http_client));
client::init_settings(cx);
let client = Client::production(cx);
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
language_model::init(client.clone(), cx);
language_models::init(user_store.clone(), client.clone(), cx);
if let TestModel::Fake = model {
Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
} else {
@@ -1119,25 +881,20 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
.await;
let project_context = Rc::new(RefCell::new(ProjectContext::default()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project,
project_context.clone(),
context_server_registry,
action_log,
templates,
model.clone(),
cx,
)
});
ThreadTest {
model,
thread,
project_context,
fs,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,3 @@
mod context_server_registry;
mod copy_path_tool;
mod create_directory_tool;
mod delete_path_tool;
@@ -16,7 +15,6 @@ mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
pub use context_server_registry::*;
pub use copy_path_tool::*;
pub use create_directory_tool::*;
pub use delete_path_tool::*;

View File

@@ -1,231 +0,0 @@
use crate::{AgentToolOutput, AnyAgentTool, ToolCallEventStream};
use agent_client_protocol::ToolKind;
use anyhow::{Result, anyhow, bail};
use collections::{BTreeMap, HashMap};
use context_server::ContextServerId;
use gpui::{App, Context, Entity, SharedString, Task};
use project::context_server_store::{ContextServerStatus, ContextServerStore};
use std::sync::Arc;
use util::ResultExt;
pub struct ContextServerRegistry {
server_store: Entity<ContextServerStore>,
registered_servers: HashMap<ContextServerId, RegisteredContextServer>,
_subscription: gpui::Subscription,
}
struct RegisteredContextServer {
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
load_tools: Task<Result<()>>,
}
impl ContextServerRegistry {
pub fn new(server_store: Entity<ContextServerStore>, cx: &mut Context<Self>) -> Self {
let mut this = Self {
server_store: server_store.clone(),
registered_servers: HashMap::default(),
_subscription: cx.subscribe(&server_store, Self::handle_context_server_store_event),
};
for server in server_store.read(cx).running_servers() {
this.reload_tools_for_server(server.id(), cx);
}
this
}
pub fn servers(
&self,
) -> impl Iterator<
Item = (
&ContextServerId,
&BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
),
> {
self.registered_servers
.iter()
.map(|(id, server)| (id, &server.tools))
}
fn reload_tools_for_server(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
let Some(server) = self.server_store.read(cx).get_running_server(&server_id) else {
return;
};
let Some(client) = server.client() else {
return;
};
if !client.capable(context_server::protocol::ServerCapability::Tools) {
return;
}
let registered_server =
self.registered_servers
.entry(server_id.clone())
.or_insert(RegisteredContextServer {
tools: BTreeMap::default(),
load_tools: Task::ready(Ok(())),
});
registered_server.load_tools = cx.spawn(async move |this, cx| {
let response = client
.request::<context_server::types::requests::ListTools>(())
.await;
this.update(cx, |this, cx| {
let Some(registered_server) = this.registered_servers.get_mut(&server_id) else {
return;
};
registered_server.tools.clear();
if let Some(response) = response.log_err() {
for tool in response.tools {
let tool = Arc::new(ContextServerTool::new(
this.server_store.clone(),
server.id(),
tool,
));
registered_server.tools.insert(tool.name(), tool);
}
cx.notify();
}
})
});
}
fn handle_context_server_store_event(
&mut self,
_: Entity<ContextServerStore>,
event: &project::context_server_store::Event,
cx: &mut Context<Self>,
) {
match event {
project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
match status {
ContextServerStatus::Starting => {}
ContextServerStatus::Running => {
self.reload_tools_for_server(server_id.clone(), cx);
}
ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
self.registered_servers.remove(&server_id);
cx.notify();
}
}
}
}
}
}
struct ContextServerTool {
store: Entity<ContextServerStore>,
server_id: ContextServerId,
tool: context_server::types::Tool,
}
impl ContextServerTool {
fn new(
store: Entity<ContextServerStore>,
server_id: ContextServerId,
tool: context_server::types::Tool,
) -> Self {
Self {
store,
server_id,
tool,
}
}
}
impl AnyAgentTool for ContextServerTool {
fn name(&self) -> SharedString {
self.tool.name.clone().into()
}
fn description(&self) -> SharedString {
self.tool.description.clone().unwrap_or_default().into()
}
fn kind(&self) -> ToolKind {
ToolKind::Other
}
fn initial_title(&self, _input: serde_json::Value) -> SharedString {
format!("Run MCP tool `{}`", self.tool.name).into()
}
fn input_schema(
&self,
format: language_model::LanguageModelToolSchemaFormat,
) -> Result<serde_json::Value> {
let mut schema = self.tool.input_schema.clone();
assistant_tool::adapt_schema_to_format(&mut schema, format)?;
Ok(match schema {
serde_json::Value::Null => {
serde_json::json!({ "type": "object", "properties": [] })
}
serde_json::Value::Object(map) if map.is_empty() => {
serde_json::json!({ "type": "object", "properties": [] })
}
_ => schema,
})
}
fn run(
self: Arc<Self>,
input: serde_json::Value,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<AgentToolOutput>> {
let Some(server) = self.store.read(cx).get_running_server(&self.server_id) else {
return Task::ready(Err(anyhow!("Context server not found")));
};
let tool_name = self.tool.name.clone();
let server_clone = server.clone();
let input_clone = input.clone();
cx.spawn(async move |_cx| {
let Some(protocol) = server_clone.client() else {
bail!("Context server not initialized");
};
let arguments = if let serde_json::Value::Object(map) = input_clone {
Some(map.into_iter().collect())
} else {
None
};
log::trace!(
"Running tool: {} with arguments: {:?}",
tool_name,
arguments
);
let response = protocol
.request::<context_server::types::requests::CallTool>(
context_server::types::CallToolParams {
name: tool_name,
arguments,
meta: None,
},
)
.await?;
let mut result = String::new();
for content in response.content {
match content {
context_server::types::ToolResponseContent::Text { text } => {
result.push_str(&text);
}
context_server::types::ToolResponseContent::Image { .. } => {
log::warn!("Ignoring image content from tool response");
}
context_server::types::ToolResponseContent::Audio { .. } => {
log::warn!("Ignoring audio content from tool response");
}
context_server::types::ToolResponseContent::Resource { .. } => {
log::warn!("Ignoring resource content from tool response");
}
}
}
Ok(AgentToolOutput {
raw_output: result.clone().into(),
llm_output: result.into(),
})
})
}
}

View File

@@ -85,7 +85,7 @@ impl AgentTool for DiagnosticsTool {
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
match input.path {
@@ -119,6 +119,11 @@ impl AgentTool for DiagnosticsTool {
range.start.row + 1,
entry.diagnostic.message
)?;
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![output.clone().into()]),
..Default::default()
});
}
if output.is_empty() {
@@ -153,9 +158,18 @@ impl AgentTool for DiagnosticsTool {
}
if has_diagnostics {
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![output.clone().into()]),
..Default::default()
});
Task::ready(Ok(output))
} else {
Task::ready(Ok("No errors or warnings found in the project.".into()))
let text = "No errors or warnings found in the project.";
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![text.into()]),
..Default::default()
});
Task::ready(Ok(text.into()))
}
}
}

View File

@@ -1,13 +1,12 @@
use crate::{AgentTool, Thread, ToolCallEventStream};
use acp_thread::Diff;
use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::formatdoc;
use language::ToPoint;
use language::language_settings::{self, FormatOnSave};
use language_model::LanguageModelToolResultContent;
use paths;
@@ -226,16 +225,6 @@ impl AgentTool for EditFileTool {
Ok(path) => path,
Err(err) => return Task::ready(Err(anyhow!(err))),
};
let abs_path = project.read(cx).absolute_path(&project_path, cx);
if let Some(abs_path) = abs_path.clone() {
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: abs_path,
line: None,
}]),
..Default::default()
});
}
let request = self.thread.update(cx, |thread, cx| {
thread.build_completion_request(CompletionIntent::ToolResults, cx)
@@ -294,38 +283,13 @@ impl AgentTool for EditFileTool {
let mut hallucinated_old_text = false;
let mut ambiguous_ranges = Vec::new();
let mut emitted_location = false;
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited(range) => {
if !emitted_location {
let line = buffer.update(cx, |buffer, _cx| {
range.start.to_point(&buffer.snapshot()).row
}).ok();
if let Some(abs_path) = abs_path.clone() {
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
..Default::default()
});
}
emitted_location = true;
}
},
EditAgentOutputEvent::Edited => {},
EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
EditAgentOutputEvent::ResolvingEditRange(range) => {
diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
// if !emitted_location {
// let line = buffer.update(cx, |buffer, _cx| {
// range.start.to_point(&buffer.snapshot()).row
// }).ok();
// if let Some(abs_path) = abs_path.clone() {
// event_stream.update_fields(ToolCallUpdateFields {
// locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
// ..Default::default()
// });
// }
// }
diff.update(cx, |card, cx| card.reveal_range(range, cx))?;
}
}
}
@@ -490,8 +454,9 @@ fn resolve_path(
#[cfg(test)]
mod tests {
use crate::Templates;
use super::*;
use crate::{ContextServerRegistry, Templates};
use action_log::ActionLog;
use client::TelemetrySettings;
use fs::Fs;
@@ -510,20 +475,9 @@ mod tests {
fs.insert_tree("/root", json!({})).await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log,
Templates::new(),
model,
cx,
)
});
let thread =
cx.new(|_| Thread::new(project, Rc::default(), action_log, Templates::new(), model));
let result = cx
.update(|cx| {
let input = EditFileToolInput {
@@ -707,18 +661,14 @@ mod tests {
});
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
@@ -842,19 +792,15 @@ mod tests {
.unwrap();
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
@@ -968,19 +914,15 @@ mod tests {
init_test(cx);
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1099,19 +1041,15 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({})).await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project,
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1210,18 +1148,14 @@ mod tests {
.await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1291,18 +1225,14 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1375,18 +1305,14 @@ mod tests {
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry.clone(),
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });
@@ -1456,18 +1382,14 @@ mod tests {
let fs = project::FakeFs::new(cx.executor());
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
let model = Arc::new(FakeLanguageModel::default());
let thread = cx.new(|cx| {
let thread = cx.new(|_| {
Thread::new(
project.clone(),
Rc::default(),
context_server_registry,
action_log.clone(),
Templates::new(),
model.clone(),
cx,
)
});
let tool = Arc::new(EditFileTool { thread });

View File

@@ -136,7 +136,7 @@ impl AgentTool for FetchTool {
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
let text = cx.background_spawn({
@@ -149,6 +149,12 @@ impl AgentTool for FetchTool {
if text.trim().is_empty() {
bail!("no textual content found");
}
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![text.clone().into()]),
..Default::default()
});
Ok(text)
})
}

View File

@@ -139,6 +139,9 @@ impl AgentTool for FindPathTool {
})
.collect(),
),
raw_output: Some(serde_json::json!({
"paths": &matches,
})),
..Default::default()
});

View File

@@ -101,7 +101,7 @@ impl AgentTool for GrepTool {
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<Self::Output>> {
const CONTEXT_LINES: u32 = 2;
@@ -282,22 +282,33 @@ impl AgentTool for GrepTool {
}
}
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![output.clone().into()]),
..Default::default()
});
matches_found += 1;
}
}
if matches_found == 0 {
Ok("No matches found".into())
let output = if matches_found == 0 {
"No matches found".to_string()
} else if has_more_matches {
Ok(format!(
format!(
"Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
input.offset + 1,
input.offset + matches_found,
input.offset + RESULTS_PER_PAGE,
))
)
} else {
Ok(format!("Found {matches_found} matches:\n{output}"))
}
format!("Found {matches_found} matches:\n{output}")
};
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![output.clone().into()]),
..Default::default()
});
Ok(output)
})
}
}

View File

@@ -47,13 +47,20 @@ impl AgentTool for NowTool {
fn run(
self: Arc<Self>,
input: Self::Input,
_event_stream: ToolCallEventStream,
event_stream: ToolCallEventStream,
_cx: &mut App,
) -> Task<Result<String>> {
let now = match input.timezone {
Timezone::Utc => Utc::now().to_rfc3339(),
Timezone::Local => Local::now().to_rfc3339(),
};
Task::ready(Ok(format!("The current datetime is {now}.")))
let content = format!("The current datetime is {now}.");
event_stream.update_fields(acp::ToolCallUpdateFields {
content: Some(vec![content.clone().into()]),
..Default::default()
});
Task::ready(Ok(content))
}
}

View File

@@ -1,10 +1,10 @@
use action_log::ActionLog;
use agent_client_protocol::{self as acp, ToolCallUpdateFields};
use agent_client_protocol::{self as acp};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::outline;
use gpui::{App, Entity, SharedString, Task};
use indoc::formatdoc;
use language::Point;
use language::{Anchor, Point};
use language_model::{LanguageModelImage, LanguageModelToolResultContent};
use project::{AgentLocation, ImageItem, Project, WorktreeSettings, image_store};
use schemars::JsonSchema;
@@ -97,7 +97,7 @@ impl AgentTool for ReadFileTool {
fn run(
self: Arc<Self>,
input: Self::Input,
event_stream: ToolCallEventStream,
_event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<LanguageModelToolResultContent>> {
let Some(project_path) = self.project.read(cx).find_project_path(&input.path, cx) else {
@@ -166,9 +166,7 @@ impl AgentTool for ReadFileTool {
cx.spawn(async move |cx| {
let buffer = cx
.update(|cx| {
project.update(cx, |project, cx| {
project.open_buffer(project_path.clone(), cx)
})
project.update(cx, |project, cx| project.open_buffer(project_path, cx))
})?
.await?;
if buffer.read_with(cx, |buffer, _| {
@@ -180,10 +178,19 @@ impl AgentTool for ReadFileTool {
anyhow::bail!("{file_path} not found");
}
let mut anchor = None;
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: Anchor::MIN,
}),
cx,
);
})?;
// Check if specific line ranges are provided
let result = if input.start_line.is_some() || input.end_line.is_some() {
if input.start_line.is_some() || input.end_line.is_some() {
let mut anchor = None;
let result = buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
// .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0.
@@ -207,6 +214,18 @@ impl AgentTool for ReadFileTool {
log.buffer_read(buffer.clone(), cx);
})?;
if let Some(anchor) = anchor {
project.update(cx, |project, cx| {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor,
}),
cx,
);
})?;
}
Ok(result.into())
} else {
// No line ranges specified, so check file size to see if it's too big.
@@ -217,7 +236,7 @@ impl AgentTool for ReadFileTool {
let result = buffer.read_with(cx, |buffer, _cx| buffer.text())?;
action_log.update(cx, |log, cx| {
log.buffer_read(buffer.clone(), cx);
log.buffer_read(buffer, cx);
})?;
Ok(result.into())
@@ -225,8 +244,7 @@ impl AgentTool for ReadFileTool {
// File is too big, so return the outline
// and a suggestion to read again with line numbers.
let outline =
outline::file_outline(project.clone(), file_path, action_log, None, cx)
.await?;
outline::file_outline(project, file_path, action_log, None, cx).await?;
Ok(formatdoc! {"
This file was too big to read all at once.
@@ -243,28 +261,7 @@ impl AgentTool for ReadFileTool {
}
.into())
}
};
project.update(cx, |project, cx| {
if let Some(abs_path) = project.absolute_path(&project_path, cx) {
project.set_agent_location(
Some(AgentLocation {
buffer: buffer.downgrade(),
position: anchor.unwrap_or(text::Anchor::MIN),
}),
cx,
);
event_stream.update_fields(ToolCallUpdateFields {
locations: Some(vec![acp::ToolCallLocation {
path: abs_path,
line: input.start_line.map(|line| line.saturating_sub(1)),
}]),
..Default::default()
});
}
})?;
result
}
})
}
}

View File

@@ -5,9 +5,7 @@ use agent_client_protocol as acp;
use anyhow::{Result, anyhow};
use cloud_llm_client::WebSearchResponse;
use gpui::{App, AppContext, Task};
use language_model::{
LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
};
use language_model::LanguageModelToolResultContent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use ui::prelude::*;
@@ -52,11 +50,6 @@ impl AgentTool for WebSearchTool {
"Searching the Web".into()
}
/// We currently only support Zed Cloud as a provider.
fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
provider == &ZED_CLOUD_PROVIDER_ID
}
fn run(
self: Arc<Self>,
input: Self::Input,

View File

@@ -423,7 +423,7 @@ impl AgentConnection for AcpConnection {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
cx: &mut App,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let task = self.connection.request_any(
acp_old::InitializeParams {
@@ -467,7 +467,6 @@ impl AgentConnection for AcpConnection {
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {

View File

@@ -111,7 +111,7 @@ impl AgentConnection for AcpConnection {
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut App,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let conn = self.connection.clone();
let sessions = self.sessions.clone();
@@ -171,7 +171,6 @@ impl AgentConnection for AcpConnection {
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {

View File

@@ -74,7 +74,7 @@ impl AgentConnection for ClaudeAgentConnection {
self: Rc<Self>,
project: Entity<Project>,
cwd: &Path,
cx: &mut App,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
let cwd = cwd.to_owned();
cx.spawn(async move |cx| {
@@ -210,7 +210,6 @@ impl AgentConnection for ClaudeAgentConnection {
fn prompt(
&self,
_id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
@@ -424,7 +423,7 @@ impl ClaudeAgentSession {
if !turn_state.borrow().is_cancelled() {
thread
.update(cx, |thread, cx| {
thread.push_user_content_block(None, text.into(), cx)
thread.push_user_content_block(text.into(), cx)
})
.log_err();
}

View File

@@ -422,8 +422,8 @@ pub async fn new_test_thread(
.await
.unwrap();
let thread = cx
.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
let thread = connection
.new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
.await
.unwrap();

View File

@@ -48,20 +48,6 @@ pub struct AgentProfileSettings {
pub context_servers: IndexMap<Arc<str>, ContextServerPreset>,
}
impl AgentProfileSettings {
pub fn is_tool_enabled(&self, tool_name: &str) -> bool {
self.tools.get(tool_name) == Some(&true)
}
pub fn is_context_server_tool_enabled(&self, server_id: &str, tool_name: &str) -> bool {
self.enable_all_context_servers
|| self
.context_servers
.get(server_id)
.map_or(false, |preset| preset.tools.get(tool_name) == Some(&true))
}
}
#[derive(Debug, Clone, Default)]
pub struct ContextServerPreset {
pub tools: IndexMap<Arc<str>, bool>,

View File

@@ -93,7 +93,6 @@ time.workspace = true
time_format.workspace = true
ui.workspace = true
ui_input.workspace = true
url.workspace = true
urlencoding.workspace = true
util.workspace = true
uuid.workspace = true
@@ -103,9 +102,6 @@ workspace.workspace = true
zed_actions.workspace = true
[dev-dependencies]
acp_thread = { workspace = true, features = ["test-support"] }
agent = { workspace = true, features = ["test-support"] }
assistant_context = { workspace = true, features = ["test-support"] }
assistant_tools.workspace = true
buffer_diff = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }

View File

@@ -1,10 +1,6 @@
mod completion_provider;
mod entry_view_state;
mod message_editor;
mod model_selector;
mod model_selector_popover;
mod message_history;
mod thread_view;
pub use model_selector::AcpModelSelector;
pub use model_selector_popover::AcpModelSelectorPopover;
pub use message_history::MessageHistory;
pub use thread_view::AcpThreadView;

File diff suppressed because it is too large Load Diff

View File

@@ -1,351 +0,0 @@
use std::{collections::HashMap, ops::Range};
use acp_thread::AcpThread;
use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
use gpui::{
AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
};
use language::language_settings::SoftWrap;
use settings::Settings as _;
use terminal_view::TerminalView;
use theme::ThemeSettings;
use ui::TextSize;
use workspace::Workspace;
#[derive(Default)]
pub struct EntryViewState {
entries: Vec<Entry>,
}
impl EntryViewState {
pub fn entry(&self, index: usize) -> Option<&Entry> {
self.entries.get(index)
}
pub fn sync_entry(
&mut self,
workspace: WeakEntity<Workspace>,
thread: Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
debug_assert!(index <= self.entries.len());
let entry = if let Some(entry) = self.entries.get_mut(index) {
entry
} else {
self.entries.push(Entry::default());
self.entries.last_mut().unwrap()
};
entry.sync_diff_multibuffers(&thread, index, window, cx);
entry.sync_terminals(&workspace, &thread, index, window, cx);
}
pub fn remove(&mut self, range: Range<usize>) {
self.entries.drain(range);
}
pub fn settings_changed(&mut self, cx: &mut App) {
for entry in self.entries.iter() {
for view in entry.views.values() {
if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
diff_editor.update(cx, |diff_editor, cx| {
diff_editor
.set_text_style_refinement(diff_editor_text_style_refinement(cx));
cx.notify();
})
}
}
}
}
}
pub struct Entry {
views: HashMap<EntityId, AnyEntity>,
}
impl Entry {
pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
self.views
.get(&diff.entity_id())
.cloned()
.map(|entity| entity.downcast::<Editor>().unwrap())
}
pub fn terminal(
&self,
terminal: &Entity<acp_thread::Terminal>,
) -> Option<Entity<TerminalView>> {
self.views
.get(&terminal.entity_id())
.cloned()
.map(|entity| entity.downcast::<TerminalView>().unwrap())
}
fn sync_diff_multibuffers(
&mut self,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let multibuffers = entry
.diffs()
.map(|diff| diff.read(cx).multibuffer().clone());
let multibuffers = multibuffers.collect::<Vec<_>>();
for multibuffer in multibuffers {
if self.views.contains_key(&multibuffer.entity_id()) {
return;
}
let editor = cx.new(|cx| {
let mut editor = Editor::new(
EditorMode::Full {
scale_ui_elements_with_buffer_font_size: false,
show_active_line_background: false,
sized_by_content: true,
},
multibuffer.clone(),
None,
window,
cx,
);
editor.set_show_gutter(false, cx);
editor.disable_inline_diagnostics();
editor.disable_expand_excerpt_buttons(cx);
editor.set_show_vertical_scrollbar(false, cx);
editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
editor.set_soft_wrap_mode(SoftWrap::None, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_show_indent_guides(false, cx);
editor.set_read_only(true);
editor.set_show_breakpoints(false, cx);
editor.set_show_code_actions(false, cx);
editor.set_show_git_diff_gutter(false, cx);
editor.set_expand_all_diff_hunks(cx);
editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
editor
});
let entity_id = multibuffer.entity_id();
self.views.insert(entity_id, editor.into_any());
}
}
fn sync_terminals(
&mut self,
workspace: &WeakEntity<Workspace>,
thread: &Entity<AcpThread>,
index: usize,
window: &mut Window,
cx: &mut App,
) {
let Some(entry) = thread.read(cx).entries().get(index) else {
return;
};
let terminals = entry
.terminals()
.map(|terminal| terminal.clone())
.collect::<Vec<_>>();
for terminal in terminals {
if self.views.contains_key(&terminal.entity_id()) {
return;
}
let Some(strong_workspace) = workspace.upgrade() else {
return;
};
let terminal_view = cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
workspace.clone(),
None,
strong_workspace.read(cx).project().downgrade(),
window,
cx,
);
view.set_embedded_mode(Some(1000), cx);
view
});
let entity_id = terminal.entity_id();
self.views.insert(entity_id, terminal_view.into_any());
}
}
#[cfg(test)]
pub fn len(&self) -> usize {
self.views.len()
}
}
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
TextStyleRefinement {
font_size: Some(
TextSize::Small
.rems(cx)
.to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
.into(),
),
..Default::default()
}
}
impl Default for Entry {
fn default() -> Self {
Self {
// Avoid allocating in the heap by default
views: HashMap::with_capacity(0),
}
}
}
#[cfg(test)]
mod tests {
use std::{path::Path, rc::Rc};
use acp_thread::{AgentConnection, StubAgentConnection};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use editor::{EditorSettings, RowInfo};
use fs::FakeFs;
use gpui::{SemanticVersion, TestAppContext};
use multi_buffer::MultiBufferRow;
use pretty_assertions::assert_matches;
use project::Project;
use serde_json::json;
use settings::{Settings as _, SettingsStore};
use theme::ThemeSettings;
use util::path;
use workspace::Workspace;
use crate::acp::entry_view_state::EntryViewState;
#[gpui::test]
async fn test_diff_sync(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/project",
json!({
"hello.txt": "hi world"
}),
)
.await;
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let tool_call = acp::ToolCall {
id: acp::ToolCallId("tool".into()),
title: "Tool call".into(),
kind: acp::ToolKind::Other,
status: acp::ToolCallStatus::InProgress,
content: vec![acp::ToolCallContent::Diff {
diff: acp::Diff {
path: "/project/hello.txt".into(),
old_text: Some("hi world".into()),
new_text: "hello world".into(),
},
}],
locations: vec![],
raw_input: None,
raw_output: None,
};
let connection = Rc::new(StubAgentConnection::new());
let thread = cx
.update(|_, cx| {
connection
.clone()
.new_thread(project, Path::new(path!("/project")), cx)
})
.await
.unwrap();
let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
cx.update(|_, cx| {
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
});
let mut view_state = EntryViewState::default();
cx.update(|window, cx| {
view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
});
let multibuffer = thread.read_with(cx, |thread, cx| {
thread
.entries()
.get(0)
.unwrap()
.diffs()
.next()
.unwrap()
.read(cx)
.multibuffer()
.clone()
});
cx.run_until_parked();
let entry = view_state.entry(0).unwrap();
let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
assert_eq!(
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
"hi world\nhello world"
);
let row_infos = diff_editor.read_with(cx, |editor, cx| {
let multibuffer = editor.buffer().read(cx);
multibuffer
.snapshot(cx)
.row_infos(MultiBufferRow(0))
.collect::<Vec<_>>()
});
assert_matches!(
row_infos.as_slice(),
[
RowInfo {
multibuffer_row: Some(MultiBufferRow(0)),
diff_status: Some(DiffHunkStatus {
kind: DiffHunkStatusKind::Deleted,
..
}),
..
},
RowInfo {
multibuffer_row: Some(MultiBufferRow(1)),
diff_status: Some(DiffHunkStatus {
kind: DiffHunkStatusKind::Added,
..
}),
..
}
]
);
}
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
AgentSettings::register(cx);
workspace::init_settings(cx);
ThemeSettings::register(cx);
release_channel::init(SemanticVersion::default(), cx);
EditorSettings::register(cx);
});
}
}

View File

@@ -1,684 +0,0 @@
use crate::acp::completion_provider::ContextPickerCompletionProvider;
use crate::acp::completion_provider::MentionImage;
use crate::acp::completion_provider::MentionSet;
use acp_thread::MentionUri;
use agent::TextThreadStore;
use agent::ThreadStore;
use agent_client_protocol as acp;
use anyhow::Result;
use collections::HashSet;
use editor::ExcerptId;
use editor::actions::Paste;
use editor::display_map::CreaseId;
use editor::{
AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, EditorMode,
EditorStyle, MultiBuffer,
};
use futures::FutureExt as _;
use gpui::ClipboardEntry;
use gpui::Image;
use gpui::ImageFormat;
use gpui::{
AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, Task, TextStyle, WeakEntity,
};
use language::Buffer;
use language::Language;
use language_model::LanguageModelImage;
use parking_lot::Mutex;
use project::{CompletionIntent, Project};
use settings::Settings;
use std::fmt::Write;
use std::path::Path;
use std::rc::Rc;
use std::sync::Arc;
use theme::ThemeSettings;
use ui::IconName;
use ui::SharedString;
use ui::{
ActiveTheme, App, InteractiveElement, IntoElement, ParentElement, Render, Styled, TextSize,
Window, div,
};
use util::ResultExt;
use workspace::Workspace;
use workspace::notifications::NotifyResultExt as _;
use zed_actions::agent::Chat;
use super::completion_provider::Mention;
pub struct MessageEditor {
editor: Entity<Editor>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
mention_set: Arc<Mutex<MentionSet>>,
}
pub enum MessageEditorEvent {
Send,
Cancel,
}
impl EventEmitter<MessageEditorEvent> for MessageEditor {}
impl MessageEditor {
pub fn new(
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
mode: EditorMode,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
let language = Language::new(
language::LanguageConfig {
completion_query_characters: HashSet::from_iter(['.', '-', '_', '@']),
..Default::default()
},
None,
);
let mention_set = Arc::new(Mutex::new(MentionSet::default()));
let editor = cx.new(|cx| {
let buffer = cx.new(|cx| Buffer::local("", cx).with_language(Arc::new(language), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let mut editor = Editor::new(mode, buffer, None, window, cx);
editor.set_placeholder_text("Message the agent @ to include files", cx);
editor.set_show_indent_guides(false, cx);
editor.set_soft_wrap();
editor.set_use_modal_editing(true);
editor.set_completion_provider(Some(Rc::new(ContextPickerCompletionProvider::new(
mention_set.clone(),
workspace,
thread_store.downgrade(),
text_thread_store.downgrade(),
cx.weak_entity(),
))));
editor.set_context_menu_options(ContextMenuOptions {
min_entries_visible: 12,
max_entries_visible: 12,
placement: Some(ContextMenuPlacement::Above),
});
editor
});
Self {
editor,
project,
mention_set,
thread_store,
text_thread_store,
}
}
pub fn is_empty(&self, cx: &App) -> bool {
self.editor.read(cx).is_empty(cx)
}
pub fn contents(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> Task<Result<Vec<acp::ContentBlock>>> {
let contents = self.mention_set.lock().contents(
self.project.clone(),
self.thread_store.clone(),
self.text_thread_store.clone(),
window,
cx,
);
let editor = self.editor.clone();
cx.spawn(async move |_, cx| {
let contents = contents.await?;
editor.update(cx, |editor, cx| {
let mut ix = 0;
let mut chunks: Vec<acp::ContentBlock> = Vec::new();
let text = editor.text(cx);
editor.display_map.update(cx, |map, cx| {
let snapshot = map.snapshot(cx);
for (crease_id, crease) in snapshot.crease_snapshot.creases() {
// Skip creases that have been edited out of the message buffer.
if !crease.range().start.is_valid(&snapshot.buffer_snapshot) {
continue;
}
let Some(mention) = contents.get(&crease_id) else {
continue;
};
let crease_range = crease.range().to_offset(&snapshot.buffer_snapshot);
if crease_range.start > ix {
chunks.push(text[ix..crease_range.start].into());
}
let chunk = match mention {
Mention::Text { uri, content } => {
acp::ContentBlock::Resource(acp::EmbeddedResource {
annotations: None,
resource: acp::EmbeddedResourceResource::TextResourceContents(
acp::TextResourceContents {
mime_type: None,
text: content.clone(),
uri: uri.to_uri().to_string(),
},
),
})
}
Mention::Image(mention_image) => {
acp::ContentBlock::Image(acp::ImageContent {
annotations: None,
data: mention_image.data.to_string(),
mime_type: mention_image.format.mime_type().into(),
uri: mention_image
.abs_path
.as_ref()
.map(|path| format!("file://{}", path.display())),
})
}
};
chunks.push(chunk);
ix = crease_range.end;
}
if ix < text.len() {
let last_chunk = text[ix..].trim_end();
if !last_chunk.is_empty() {
chunks.push(last_chunk.into());
}
}
});
chunks
})
})
}
pub fn clear(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
editor.clear(window, cx);
editor.remove_creases(self.mention_set.lock().drain(), cx)
});
}
fn chat(&mut self, _: &Chat, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Send)
}
fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
cx.emit(MessageEditorEvent::Cancel)
}
fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context<Self>) {
let images = cx
.read_from_clipboard()
.map(|item| {
item.into_entries()
.filter_map(|entry| {
if let ClipboardEntry::Image(image) = entry {
Some(image)
} else {
None
}
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if images.is_empty() {
return;
}
cx.stop_propagation();
let replacement_text = "image";
for image in images {
let (excerpt_id, anchor) = self.editor.update(cx, |message_editor, cx| {
let snapshot = message_editor.snapshot(window, cx);
let (excerpt_id, _, snapshot) = snapshot.buffer_snapshot.as_singleton().unwrap();
let anchor = snapshot.anchor_before(snapshot.len());
message_editor.edit(
[(
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
format!("{replacement_text} "),
)],
cx,
);
(*excerpt_id, anchor)
});
self.insert_image(
excerpt_id,
anchor,
replacement_text.len(),
Arc::new(image),
None,
window,
cx,
);
}
}
pub fn insert_dragged_files(
&self,
paths: Vec<project::ProjectPath>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let buffer = self.editor.read(cx).buffer().clone();
let Some((&excerpt_id, _, _)) = buffer.read(cx).snapshot(cx).as_singleton() else {
return;
};
let Some(buffer) = buffer.read(cx).as_singleton() else {
return;
};
for path in paths {
let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else {
continue;
};
let Some(abs_path) = self.project.read(cx).absolute_path(&path, cx) else {
continue;
};
let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len()));
let path_prefix = abs_path
.file_name()
.unwrap_or(path.path.as_os_str())
.display()
.to_string();
let Some(completion) = ContextPickerCompletionProvider::completion_for_path(
path,
&path_prefix,
false,
entry.is_dir(),
excerpt_id,
anchor..anchor,
self.editor.clone(),
self.mention_set.clone(),
self.project.clone(),
cx,
) else {
continue;
};
self.editor.update(cx, |message_editor, cx| {
message_editor.edit(
[(
multi_buffer::Anchor::max()..multi_buffer::Anchor::max(),
completion.new_text,
)],
cx,
);
});
if let Some(confirm) = completion.confirm.clone() {
confirm(CompletionIntent::Complete, window, cx);
}
}
}
fn insert_image(
&mut self,
excerpt_id: ExcerptId,
crease_start: text::Anchor,
content_len: usize,
image: Arc<Image>,
abs_path: Option<Arc<Path>>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let Some(crease_id) = insert_crease_for_image(
excerpt_id,
crease_start,
content_len,
self.editor.clone(),
window,
cx,
) else {
return;
};
self.editor.update(cx, |_editor, cx| {
let format = image.format;
let convert = LanguageModelImage::from_image(image, cx);
let task = cx
.spawn_in(window, async move |editor, cx| {
if let Some(image) = convert.await {
Ok(MentionImage {
abs_path,
data: image.source,
format,
})
} else {
editor
.update(cx, |editor, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let Some(anchor) =
snapshot.anchor_in_excerpt(excerpt_id, crease_start)
else {
return;
};
editor.display_map.update(cx, |display_map, cx| {
display_map.unfold_intersecting(vec![anchor..anchor], true, cx);
});
editor.remove_creases([crease_id], cx);
})
.ok();
Err("Failed to convert image".to_string())
}
})
.shared();
cx.spawn_in(window, {
let task = task.clone();
async move |_, cx| task.clone().await.notify_async_err(cx)
})
.detach();
self.mention_set.lock().insert_image(crease_id, task);
});
}
pub fn set_mode(&mut self, mode: EditorMode, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
editor.set_mode(mode);
cx.notify()
});
}
pub fn set_message(
&mut self,
message: Vec<acp::ContentBlock>,
window: &mut Window,
cx: &mut Context<Self>,
) {
let mut text = String::new();
let mut mentions = Vec::new();
let mut images = Vec::new();
for chunk in message {
match chunk {
acp::ContentBlock::Text(text_content) => {
text.push_str(&text_content.text);
}
acp::ContentBlock::Resource(acp::EmbeddedResource {
resource: acp::EmbeddedResourceResource::TextResourceContents(resource),
..
}) => {
if let Some(mention_uri) = MentionUri::parse(&resource.uri).log_err() {
let start = text.len();
write!(&mut text, "{}", mention_uri.as_link()).ok();
let end = text.len();
mentions.push((start..end, mention_uri));
}
}
acp::ContentBlock::Image(content) => {
let start = text.len();
text.push_str("image");
let end = text.len();
images.push((start..end, content));
}
acp::ContentBlock::Audio(_)
| acp::ContentBlock::Resource(_)
| acp::ContentBlock::ResourceLink(_) => {}
}
}
let snapshot = self.editor.update(cx, |editor, cx| {
editor.set_text(text, window, cx);
editor.buffer().read(cx).snapshot(cx)
});
self.mention_set.lock().clear();
for (range, mention_uri) in mentions {
let anchor = snapshot.anchor_before(range.start);
let crease_id = crate::context_picker::insert_crease_for_mention(
anchor.excerpt_id,
anchor.text_anchor,
range.end - range.start,
mention_uri.name().into(),
mention_uri.icon_path(cx),
self.editor.clone(),
window,
cx,
);
if let Some(crease_id) = crease_id {
self.mention_set.lock().insert_uri(crease_id, mention_uri);
}
}
for (range, content) in images {
let Some(format) = ImageFormat::from_mime_type(&content.mime_type) else {
continue;
};
let anchor = snapshot.anchor_before(range.start);
let abs_path = content
.uri
.as_ref()
.and_then(|uri| uri.strip_prefix("file://").map(|s| Path::new(s).into()));
let name = content
.uri
.as_ref()
.and_then(|uri| {
uri.strip_prefix("file://")
.and_then(|path| Path::new(path).file_name())
})
.map(|name| name.to_string_lossy().to_string())
.unwrap_or("Image".to_owned());
let crease_id = crate::context_picker::insert_crease_for_mention(
anchor.excerpt_id,
anchor.text_anchor,
range.end - range.start,
name.into(),
IconName::Image.path().into(),
self.editor.clone(),
window,
cx,
);
let data: SharedString = content.data.to_string().into();
if let Some(crease_id) = crease_id {
self.mention_set.lock().insert_image(
crease_id,
Task::ready(Ok(MentionImage {
abs_path,
data,
format,
}))
.shared(),
);
}
}
cx.notify();
}
#[cfg(test)]
pub fn set_text(&mut self, text: &str, window: &mut Window, cx: &mut Context<Self>) {
self.editor.update(cx, |editor, cx| {
editor.set_text(text, window, cx);
});
}
}
impl Focusable for MessageEditor {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.editor.focus_handle(cx)
}
}
impl Render for MessageEditor {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.key_context("MessageEditor")
.on_action(cx.listener(Self::chat))
.on_action(cx.listener(Self::cancel))
.capture_action(cx.listener(Self::paste))
.flex_1()
.child({
let settings = ThemeSettings::get_global(cx);
let font_size = TextSize::Small
.rems(cx)
.to_pixels(settings.agent_font_size(cx));
let line_height = settings.buffer_line_height.value() * font_size;
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.buffer_font.family.clone(),
font_fallbacks: settings.buffer_font.fallbacks.clone(),
font_features: settings.buffer_font.features.clone(),
font_size: font_size.into(),
line_height: line_height.into(),
..Default::default()
};
EditorElement::new(
&self.editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
syntax: cx.theme().syntax().clone(),
..Default::default()
},
)
})
}
}
pub(crate) fn insert_crease_for_image(
excerpt_id: ExcerptId,
anchor: text::Anchor,
content_len: usize,
editor: Entity<Editor>,
window: &mut Window,
cx: &mut App,
) -> Option<CreaseId> {
crate::context_picker::insert_crease_for_mention(
excerpt_id,
anchor,
content_len,
"Image".into(),
IconName::Image.path().into(),
editor,
window,
cx,
)
}
#[cfg(test)]
mod tests {
use std::path::Path;
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol as acp;
use editor::EditorMode;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use lsp::{CompletionContext, CompletionTriggerKind};
use project::{CompletionIntent, Project};
use serde_json::json;
use util::path;
use workspace::Workspace;
use crate::acp::{message_editor::MessageEditor, thread_view::tests::init_test};
#[gpui::test]
async fn test_at_mention_removal(cx: &mut TestAppContext) {
init_test(cx);
let fs = FakeFs::new(cx.executor());
fs.insert_tree("/project", json!({"file": ""})).await;
let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
let message_editor = cx.update(|window, cx| {
cx.new(|cx| {
MessageEditor::new(
workspace.downgrade(),
project.clone(),
thread_store.clone(),
text_thread_store.clone(),
EditorMode::AutoHeight {
min_lines: 1,
max_lines: None,
},
window,
cx,
)
})
});
let editor = message_editor.update(cx, |message_editor, _| message_editor.editor.clone());
cx.run_until_parked();
let excerpt_id = editor.update(cx, |editor, cx| {
editor
.buffer()
.read(cx)
.excerpt_ids()
.into_iter()
.next()
.unwrap()
});
let completions = editor.update_in(cx, |editor, window, cx| {
editor.set_text("Hello @file ", window, cx);
let buffer = editor.buffer().read(cx).as_singleton().unwrap();
let completion_provider = editor.completion_provider().unwrap();
completion_provider.completions(
excerpt_id,
&buffer,
text::Anchor::MAX,
CompletionContext {
trigger_kind: CompletionTriggerKind::TRIGGER_CHARACTER,
trigger_character: Some("@".into()),
},
window,
cx,
)
});
let [_, completion]: [_; 2] = completions
.await
.unwrap()
.into_iter()
.flat_map(|response| response.completions)
.collect::<Vec<_>>()
.try_into()
.unwrap();
editor.update_in(cx, |editor, window, cx| {
let snapshot = editor.buffer().read(cx).snapshot(cx);
let start = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.start)
.unwrap();
let end = snapshot
.anchor_in_excerpt(excerpt_id, completion.replace_range.end)
.unwrap();
editor.edit([(start..end, completion.new_text)], cx);
(completion.confirm.unwrap())(CompletionIntent::Complete, window, cx);
});
cx.run_until_parked();
// Backspace over the inserted crease (and the following space).
editor.update_in(cx, |editor, window, cx| {
editor.backspace(&Default::default(), window, cx);
editor.backspace(&Default::default(), window, cx);
});
let content = message_editor
.update_in(cx, |message_editor, window, cx| {
message_editor.contents(window, cx)
})
.await
.unwrap();
// We don't send a resource link for the deleted crease.
pretty_assertions::assert_matches!(content.as_slice(), [acp::ContentBlock::Text { .. }]);
}
}

View File

@@ -0,0 +1,92 @@
pub struct MessageHistory<T> {
items: Vec<T>,
current: Option<usize>,
}
impl<T> Default for MessageHistory<T> {
fn default() -> Self {
MessageHistory {
items: Vec::new(),
current: None,
}
}
}
impl<T> MessageHistory<T> {
pub fn push(&mut self, message: T) {
self.current.take();
self.items.push(message);
}
pub fn reset_position(&mut self) {
self.current.take();
}
pub fn prev(&mut self) -> Option<&T> {
if self.items.is_empty() {
return None;
}
let new_ix = self
.current
.get_or_insert(self.items.len())
.saturating_sub(1);
self.current = Some(new_ix);
self.items.get(new_ix)
}
pub fn next(&mut self) -> Option<&T> {
let current = self.current.as_mut()?;
*current += 1;
self.items.get(*current).or_else(|| {
self.current.take();
None
})
}
#[cfg(test)]
pub fn items(&self) -> &[T] {
&self.items
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prev_next() {
let mut history = MessageHistory::default();
// Test empty history
assert_eq!(history.prev(), None);
assert_eq!(history.next(), None);
// Add some messages
history.push("first");
history.push("second");
history.push("third");
// Test prev navigation
assert_eq!(history.prev(), Some(&"third"));
assert_eq!(history.prev(), Some(&"second"));
assert_eq!(history.prev(), Some(&"first"));
assert_eq!(history.prev(), Some(&"first"));
assert_eq!(history.next(), Some(&"second"));
// Test mixed navigation
history.push("fourth");
assert_eq!(history.prev(), Some(&"fourth"));
assert_eq!(history.prev(), Some(&"third"));
assert_eq!(history.next(), Some(&"fourth"));
assert_eq!(history.next(), None);
// Test that push resets navigation
history.prev();
history.prev();
history.push("fifth");
assert_eq!(history.prev(), Some(&"fifth"));
}
}

View File

@@ -1,472 +0,0 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
use agent_client_protocol as acp;
use anyhow::Result;
use collections::IndexMap;
use futures::FutureExt;
use fuzzy::{StringMatchCandidate, match_strings};
use gpui::{Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, Task, WeakEntity};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use ui::{
AnyElement, App, Context, IntoElement, ListItem, ListItemSpacing, SharedString, Window,
prelude::*, rems,
};
use util::ResultExt;
pub type AcpModelSelector = Picker<AcpModelPickerDelegate>;
pub fn acp_model_selector(
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
window: &mut Window,
cx: &mut Context<AcpModelSelector>,
) -> AcpModelSelector {
let delegate = AcpModelPickerDelegate::new(session_id, selector, window, cx);
Picker::list(delegate, window, cx)
.show_scrollbar(true)
.width(rems(20.))
.max_height(Some(rems(20.).into()))
}
enum AcpModelPickerEntry {
Separator(SharedString),
Model(AgentModelInfo),
}
pub struct AcpModelPickerDelegate {
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
filtered_entries: Vec<AcpModelPickerEntry>,
models: Option<AgentModelList>,
selected_index: usize,
selected_model: Option<AgentModelInfo>,
_refresh_models_task: Task<()>,
}
impl AcpModelPickerDelegate {
fn new(
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
window: &mut Window,
cx: &mut Context<AcpModelSelector>,
) -> Self {
let mut rx = selector.watch(cx);
let refresh_models_task = cx.spawn_in(window, {
let session_id = session_id.clone();
async move |this, cx| {
async fn refresh(
this: &WeakEntity<Picker<AcpModelPickerDelegate>>,
session_id: &acp::SessionId,
cx: &mut AsyncWindowContext,
) -> Result<()> {
let (models_task, selected_model_task) = this.update(cx, |this, cx| {
(
this.delegate.selector.list_models(cx),
this.delegate.selector.selected_model(session_id, cx),
)
})?;
let (models, selected_model) = futures::join!(models_task, selected_model_task);
this.update_in(cx, |this, window, cx| {
this.delegate.models = models.ok();
this.delegate.selected_model = selected_model.ok();
this.delegate.update_matches(this.query(cx), window, cx)
})?
.await;
Ok(())
}
refresh(&this, &session_id, cx).await.log_err();
while let Ok(()) = rx.recv().await {
refresh(&this, &session_id, cx).await.log_err();
}
}
});
Self {
session_id,
selector,
filtered_entries: Vec::new(),
models: None,
selected_model: None,
selected_index: 0,
_refresh_models_task: refresh_models_task,
}
}
pub fn active_model(&self) -> Option<&AgentModelInfo> {
self.selected_model.as_ref()
}
}
impl PickerDelegate for AcpModelPickerDelegate {
type ListItem = AnyElement;
fn match_count(&self) -> usize {
self.filtered_entries.len()
}
fn selected_index(&self) -> usize {
self.selected_index
}
fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
cx.notify();
}
fn can_select(
&mut self,
ix: usize,
_window: &mut Window,
_cx: &mut Context<Picker<Self>>,
) -> bool {
match self.filtered_entries.get(ix) {
Some(AcpModelPickerEntry::Model(_)) => true,
Some(AcpModelPickerEntry::Separator(_)) | None => false,
}
}
fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
"Select a model…".into()
}
fn update_matches(
&mut self,
query: String,
window: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Task<()> {
cx.spawn_in(window, async move |this, cx| {
let filtered_models = match this
.read_with(cx, |this, cx| {
this.delegate.models.clone().map(move |models| {
fuzzy_search(models, query, cx.background_executor().clone())
})
})
.ok()
.flatten()
{
Some(task) => task.await,
None => AgentModelList::Flat(vec![]),
};
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries =
info_list_to_picker_entries(filtered_models).collect();
// Finds the currently selected model in the list
let new_index = this
.delegate
.selected_model
.as_ref()
.and_then(|selected| {
this.delegate.filtered_entries.iter().position(|entry| {
if let AcpModelPickerEntry::Model(model_info) = entry {
model_info.id == selected.id
} else {
false
}
})
})
.unwrap_or(0);
this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
cx.notify();
})
.ok();
})
}
fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
if let Some(AcpModelPickerEntry::Model(model_info)) =
self.filtered_entries.get(self.selected_index)
{
self.selector
.select_model(self.session_id.clone(), model_info.id.clone(), cx)
.detach_and_log_err(cx);
self.selected_model = Some(model_info.clone());
let current_index = self.selected_index;
self.set_selected_index(current_index, window, cx);
cx.emit(DismissEvent);
}
}
fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
cx.emit(DismissEvent);
}
fn render_match(
&self,
ix: usize,
selected: bool,
_: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<Self::ListItem> {
match self.filtered_entries.get(ix)? {
AcpModelPickerEntry::Separator(title) => Some(
div()
.px_2()
.pb_1()
.when(ix > 1, |this| {
this.mt_1()
.pt_2()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
})
.child(
Label::new(title)
.size(LabelSize::XSmall)
.color(Color::Muted),
)
.into_any_element(),
),
AcpModelPickerEntry::Model(model_info) => {
let is_selected = Some(model_info) == self.selected_model.as_ref();
let model_icon_color = if is_selected {
Color::Accent
} else {
Color::Muted
};
Some(
ListItem::new(ix)
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
.start_slot::<Icon>(model_info.icon.map(|icon| {
Icon::new(icon)
.color(model_icon_color)
.size(IconSize::Small)
}))
.child(
h_flex()
.w_full()
.pl_0p5()
.gap_1p5()
.w(px(240.))
.child(Label::new(model_info.name.clone()).truncate()),
)
.end_slot(div().pr_3().when(is_selected, |this| {
this.child(
Icon::new(IconName::Check)
.color(Color::Accent)
.size(IconSize::Small),
)
}))
.into_any_element(),
)
}
}
}
fn render_footer(
&self,
_: &mut Window,
cx: &mut Context<Picker<Self>>,
) -> Option<gpui::AnyElement> {
Some(
h_flex()
.w_full()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.p_1()
.gap_4()
.justify_between()
.child(
Button::new("configure", "Configure")
.icon(IconName::Settings)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.icon_position(IconPosition::Start)
.on_click(|_, window, cx| {
window.dispatch_action(
zed_actions::agent::OpenSettings.boxed_clone(),
cx,
);
}),
)
.into_any(),
)
}
}
fn info_list_to_picker_entries(
model_list: AgentModelList,
) -> impl Iterator<Item = AcpModelPickerEntry> {
match model_list {
AgentModelList::Flat(list) => {
itertools::Either::Left(list.into_iter().map(AcpModelPickerEntry::Model))
}
AgentModelList::Grouped(index_map) => {
itertools::Either::Right(index_map.into_iter().flat_map(|(group_name, models)| {
std::iter::once(AcpModelPickerEntry::Separator(group_name.0))
.chain(models.into_iter().map(AcpModelPickerEntry::Model))
}))
}
}
}
async fn fuzzy_search(
model_list: AgentModelList,
query: String,
executor: BackgroundExecutor,
) -> AgentModelList {
async fn fuzzy_search_list(
model_list: Vec<AgentModelInfo>,
query: &str,
executor: BackgroundExecutor,
) -> Vec<AgentModelInfo> {
let candidates = model_list
.iter()
.enumerate()
.map(|(ix, model)| {
StringMatchCandidate::new(ix, &format!("{}/{}", model.id, model.name))
})
.collect::<Vec<_>>();
let mut matches = match_strings(
&candidates,
&query,
false,
true,
100,
&Default::default(),
executor,
)
.await;
matches.sort_unstable_by_key(|mat| {
let candidate = &candidates[mat.candidate_id];
(Reverse(OrderedFloat(mat.score)), candidate.id)
});
matches
.into_iter()
.map(|mat| model_list[mat.candidate_id].clone())
.collect()
}
match model_list {
AgentModelList::Flat(model_list) => {
AgentModelList::Flat(fuzzy_search_list(model_list, &query, executor).await)
}
AgentModelList::Grouped(index_map) => {
let groups =
futures::future::join_all(index_map.into_iter().map(|(group_name, models)| {
fuzzy_search_list(models, &query, executor.clone())
.map(|results| (group_name, results))
}))
.await;
AgentModelList::Grouped(IndexMap::from_iter(
groups
.into_iter()
.filter(|(_, results)| !results.is_empty()),
))
}
}
}
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
use super::*;
fn create_model_list(grouped_models: Vec<(&str, Vec<&str>)>) -> AgentModelList {
AgentModelList::Grouped(IndexMap::from_iter(grouped_models.into_iter().map(
|(group, models)| {
(
acp_thread::AgentModelGroupName(group.to_string().into()),
models
.into_iter()
.map(|model| acp_thread::AgentModelInfo {
id: acp_thread::AgentModelId(model.to_string().into()),
name: model.to_string().into(),
icon: None,
})
.collect::<Vec<_>>(),
)
},
)))
}
fn assert_models_eq(result: AgentModelList, expected: Vec<(&str, Vec<&str>)>) {
let AgentModelList::Grouped(groups) = result else {
panic!("Expected LanguageModelInfoList::Grouped, got {:?}", result);
};
assert_eq!(
groups.len(),
expected.len(),
"Number of groups doesn't match"
);
for (i, (expected_group, expected_models)) in expected.iter().enumerate() {
let (actual_group, actual_models) = groups.get_index(i).unwrap();
assert_eq!(
actual_group.0.as_ref(),
*expected_group,
"Group at position {} doesn't match expected group",
i
);
assert_eq!(
actual_models.len(),
expected_models.len(),
"Number of models in group {} doesn't match",
expected_group
);
for (j, expected_model_name) in expected_models.iter().enumerate() {
assert_eq!(
actual_models[j].name, *expected_model_name,
"Model at position {} in group {} doesn't match expected model",
j, expected_group
);
}
}
}
#[gpui::test]
async fn test_fuzzy_match(cx: &mut TestAppContext) {
let models = create_model_list(vec![
(
"zed",
vec![
"Claude 3.7 Sonnet",
"Claude 3.7 Sonnet Thinking",
"gpt-4.1",
"gpt-4.1-nano",
],
),
("openai", vec!["gpt-3.5-turbo", "gpt-4.1", "gpt-4.1-nano"]),
("ollama", vec!["mistral", "deepseek"]),
]);
// Results should preserve models order whenever possible.
// In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
// similarity scores, but `zed/gpt-4.1` was higher in the models list,
// so it should appear first in the results.
let results = fuzzy_search(models.clone(), "41".into(), cx.executor()).await;
assert_models_eq(
results,
vec![
("zed", vec!["gpt-4.1", "gpt-4.1-nano"]),
("openai", vec!["gpt-4.1", "gpt-4.1-nano"]),
],
);
// Fuzzy search
let results = fuzzy_search(models.clone(), "4n".into(), cx.executor()).await;
assert_models_eq(
results,
vec![
("zed", vec!["gpt-4.1-nano"]),
("openai", vec!["gpt-4.1-nano"]),
],
);
}
}

View File

@@ -1,85 +0,0 @@
use std::rc::Rc;
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
use gpui::{Entity, FocusHandle};
use picker::popover_menu::PickerPopoverMenu;
use ui::{
ButtonLike, Context, IntoElement, PopoverMenuHandle, SharedString, Tooltip, Window, prelude::*,
};
use zed_actions::agent::ToggleModelSelector;
use crate::acp::{AcpModelSelector, model_selector::acp_model_selector};
pub struct AcpModelSelectorPopover {
selector: Entity<AcpModelSelector>,
menu_handle: PopoverMenuHandle<AcpModelSelector>,
focus_handle: FocusHandle,
}
impl AcpModelSelectorPopover {
pub(crate) fn new(
session_id: acp::SessionId,
selector: Rc<dyn AgentModelSelector>,
menu_handle: PopoverMenuHandle<AcpModelSelector>,
focus_handle: FocusHandle,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
Self {
selector: cx.new(move |cx| acp_model_selector(session_id, selector, window, cx)),
menu_handle,
focus_handle,
}
}
pub fn toggle(&self, window: &mut Window, cx: &mut Context<Self>) {
self.menu_handle.toggle(window, cx);
}
}
impl Render for AcpModelSelectorPopover {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let model = self.selector.read(cx).delegate.active_model();
let model_name = model
.as_ref()
.map(|model| model.name.clone())
.unwrap_or_else(|| SharedString::from("Select a Model"));
let model_icon = model.as_ref().and_then(|model| model.icon);
let focus_handle = self.focus_handle.clone();
PickerPopoverMenu::new(
self.selector.clone(),
ButtonLike::new("active-model")
.when_some(model_icon, |this, icon| {
this.child(Icon::new(icon).color(Color::Muted).size(IconSize::XSmall))
})
.child(
Label::new(model_name)
.color(Color::Muted)
.size(LabelSize::Small)
.ml_0p5(),
)
.child(
Icon::new(IconName::ChevronDown)
.color(Color::Muted)
.size(IconSize::XSmall),
),
move |window, cx| {
Tooltip::for_action_in(
"Change Model",
&ToggleModelSelector,
&focus_handle,
window,
cx,
)
},
gpui::Corner::BottomRight,
cx,
)
.with_handle(self.menu_handle.clone())
.render(window, cx)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1521,8 +1521,7 @@ impl AgentDiff {
self.update_reviewing_editors(workspace, window, cx);
}
}
AcpThreadEvent::EntriesRemoved(_)
| AcpThreadEvent::Stopped
AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Error
| AcpThreadEvent::ServerExited(_) => {}

File diff suppressed because it is too large Load Diff

View File

@@ -64,8 +64,6 @@ actions!(
NewTextThread,
/// Toggles the context picker interface for adding files, symbols, or other context.
ToggleContextPicker,
/// Toggles the menu to create new agent threads.
ToggleNewThreadMenu,
/// Toggles the navigation menu for switching between threads and views.
ToggleNavigationMenu,
/// Toggles the options menu for agent settings and preferences.
@@ -157,11 +155,11 @@ enum ExternalAgent {
}
impl ExternalAgent {
pub fn server(&self, fs: Arc<dyn fs::Fs>) -> Rc<dyn agent_servers::AgentServer> {
pub fn server(&self) -> Rc<dyn agent_servers::AgentServer> {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs)),
ExternalAgent::NativeAgent => Rc::new(agent2::NativeAgentServer),
}
}
}

View File

@@ -1,16 +1,15 @@
mod completion_provider;
pub(crate) mod fetch_context_picker;
mod fetch_context_picker;
pub(crate) mod file_context_picker;
pub(crate) mod rules_context_picker;
pub(crate) mod symbol_context_picker;
pub(crate) mod thread_context_picker;
mod rules_context_picker;
mod symbol_context_picker;
mod thread_context_picker;
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::{Result, anyhow};
use collections::HashSet;
pub use completion_provider::ContextPickerCompletionProvider;
use editor::display_map::{Crease, CreaseId, CreaseMetadata, FoldId};
use editor::{Anchor, AnchorRangeExt as _, Editor, ExcerptId, FoldPlaceholder, ToOffset};
@@ -46,7 +45,7 @@ use agent::{
};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextPickerEntry {
enum ContextPickerEntry {
Mode(ContextPickerMode),
Action(ContextPickerAction),
}
@@ -75,7 +74,7 @@ impl ContextPickerEntry {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextPickerMode {
enum ContextPickerMode {
File,
Symbol,
Fetch,
@@ -84,7 +83,7 @@ pub(crate) enum ContextPickerMode {
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextPickerAction {
enum ContextPickerAction {
AddSelections,
}
@@ -532,7 +531,7 @@ impl ContextPicker {
return vec![];
};
recent_context_picker_entries_with_store(
recent_context_picker_entries(
context_store,
self.thread_store.clone(),
self.text_thread_store.clone(),
@@ -586,8 +585,7 @@ impl Render for ContextPicker {
})
}
}
pub(crate) enum RecentEntry {
enum RecentEntry {
File {
project_path: ProjectPath,
path_prefix: Arc<str>,
@@ -595,7 +593,7 @@ pub(crate) enum RecentEntry {
Thread(ThreadContextEntry),
}
pub(crate) fn available_context_picker_entries(
fn available_context_picker_entries(
prompt_store: &Option<Entity<PromptStore>>,
thread_store: &Option<WeakEntity<ThreadStore>>,
workspace: &Entity<Workspace>,
@@ -632,56 +630,24 @@ pub(crate) fn available_context_picker_entries(
entries
}
fn recent_context_picker_entries_with_store(
fn recent_context_picker_entries(
context_store: Entity<ContextStore>,
thread_store: Option<WeakEntity<ThreadStore>>,
text_thread_store: Option<WeakEntity<TextThreadStore>>,
workspace: Entity<Workspace>,
exclude_path: Option<ProjectPath>,
cx: &App,
) -> Vec<RecentEntry> {
let project = workspace.read(cx).project();
let mut exclude_paths = context_store.read(cx).file_paths(cx);
exclude_paths.extend(exclude_path);
let exclude_paths = exclude_paths
.into_iter()
.filter_map(|project_path| project.read(cx).absolute_path(&project_path, cx))
.collect();
let exclude_threads = context_store.read(cx).thread_ids();
recent_context_picker_entries(
thread_store,
text_thread_store,
workspace,
&exclude_paths,
exclude_threads,
cx,
)
}
pub(crate) fn recent_context_picker_entries(
thread_store: Option<WeakEntity<ThreadStore>>,
text_thread_store: Option<WeakEntity<TextThreadStore>>,
workspace: Entity<Workspace>,
exclude_paths: &HashSet<PathBuf>,
exclude_threads: &HashSet<ThreadId>,
cx: &App,
) -> Vec<RecentEntry> {
let mut recent = Vec::with_capacity(6);
let mut current_files = context_store.read(cx).file_paths(cx);
current_files.extend(exclude_path);
let workspace = workspace.read(cx);
let project = workspace.project().read(cx);
recent.extend(
workspace
.recent_navigation_history_iter(cx)
.filter(|(_, abs_path)| {
abs_path
.as_ref()
.map_or(true, |path| !exclude_paths.contains(path.as_path()))
})
.filter(|(path, _)| !current_files.contains(path))
.take(4)
.filter_map(|(project_path, _)| {
project
@@ -693,6 +659,8 @@ pub(crate) fn recent_context_picker_entries(
}),
);
let current_threads = context_store.read(cx).thread_ids();
let active_thread_id = workspace
.panel::<AgentPanel>(cx)
.and_then(|panel| Some(panel.read(cx).active_thread(cx)?.read(cx).id()));
@@ -704,7 +672,7 @@ pub(crate) fn recent_context_picker_entries(
let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx)
.filter(|(_, thread)| match thread {
ThreadContextEntry::Thread { id, .. } => {
Some(id) != active_thread_id && !exclude_threads.contains(id)
Some(id) != active_thread_id && !current_threads.contains(id)
}
ThreadContextEntry::Context { .. } => true,
})
@@ -742,7 +710,7 @@ fn add_selections_as_context(
})
}
pub(crate) fn selection_ranges(
fn selection_ranges(
workspace: &Entity<Workspace>,
cx: &mut App,
) -> Vec<(Entity<Buffer>, Range<text::Anchor>)> {

View File

@@ -35,7 +35,7 @@ use super::symbol_context_picker::search_symbols;
use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads};
use super::{
ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry,
available_context_picker_entries, recent_context_picker_entries_with_store, selection_ranges,
available_context_picker_entries, recent_context_picker_entries, selection_ranges,
};
use crate::message_editor::ContextCreasesAddon;
@@ -787,7 +787,7 @@ impl CompletionProvider for ContextPickerCompletionProvider {
.and_then(|b| b.read(cx).file())
.map(|file| ProjectPath::from_file(file.as_ref(), cx));
let recent_entries = recent_context_picker_entries_with_store(
let recent_entries = recent_context_picker_entries(
context_store.clone(),
thread_store.clone(),
text_thread_store.clone(),

View File

@@ -2,7 +2,7 @@ mod agent_notification;
mod burn_mode_tooltip;
mod context_pill;
mod end_trial_upsell;
// mod new_thread_button;
mod new_thread_button;
mod onboarding_modal;
pub mod preview;
@@ -10,5 +10,5 @@ pub use agent_notification::*;
pub use burn_mode_tooltip::*;
pub use context_pill::*;
pub use end_trial_upsell::*;
// pub use new_thread_button::*;
pub use new_thread_button::*;
pub use onboarding_modal::*;

View File

@@ -11,7 +11,7 @@ pub struct NewThreadButton {
}
impl NewThreadButton {
fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
pub fn new(id: impl Into<ElementId>, label: impl Into<SharedString>, icon: IconName) -> Self {
Self {
id: id.into(),
label: label.into(),
@@ -21,12 +21,12 @@ impl NewThreadButton {
}
}
fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
pub fn keybinding(mut self, keybinding: Option<ui::KeyBinding>) -> Self {
self.keybinding = keybinding;
self
}
fn on_click<F>(mut self, handler: F) -> Self
pub fn on_click<F>(mut self, handler: F) -> Self
where
F: Fn(&mut Window, &mut App) + 'static,
{

View File

@@ -58,7 +58,9 @@ impl Assets {
pub fn load_test_fonts(&self, cx: &App) {
cx.text_system()
.add_fonts(vec![
self.load("fonts/lilex/Lilex-Regular.ttf").unwrap().unwrap(),
self.load("fonts/plex-mono/ZedPlexMono-Regular.ttf")
.unwrap()
.unwrap(),
])
.unwrap()
}

View File

@@ -11,9 +11,6 @@ workspace = true
[lib]
path = "src/assistant_context.rs"
[features]
test-support = []
[dependencies]
agent_settings.workspace = true
anyhow.workspace = true

View File

@@ -138,27 +138,6 @@ impl ContextStore {
})
}
#[cfg(any(test, feature = "test-support"))]
pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
Self {
contexts: Default::default(),
contexts_metadata: Default::default(),
context_server_slash_command_ids: Default::default(),
host_contexts: Default::default(),
fs: project.read(cx).fs().clone(),
languages: project.read(cx).languages().clone(),
slash_commands: Arc::default(),
telemetry: project.read(cx).client().telemetry().clone(),
_watch_updates: Task::ready(None),
client: project.read(cx).client(),
project,
project_is_shared: false,
client_subscription: None,
_project_subscriptions: Default::default(),
prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
}
}
async fn handle_advertise_contexts(
this: Entity<Self>,
envelope: TypedEnvelope<proto::AdvertiseContexts>,

View File

@@ -65,7 +65,7 @@ pub enum EditAgentOutputEvent {
ResolvingEditRange(Range<Anchor>),
UnresolvedEditRange,
AmbiguousEditRange(Vec<Range<usize>>),
Edited(Range<Anchor>),
Edited,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
@@ -178,9 +178,7 @@ impl EditAgent {
)
});
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited(
language::Anchor::MIN..language::Anchor::MAX,
))
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
})?;
@@ -202,9 +200,7 @@ impl EditAgent {
});
})?;
output_events_tx
.unbounded_send(EditAgentOutputEvent::Edited(
language::Anchor::MIN..language::Anchor::MAX,
))
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
}
@@ -340,8 +336,8 @@ impl EditAgent {
// Edit the buffer and report edits to the action log as part of the
// same effect cycle, otherwise the edit will be reported as if the
// user made it.
let (min_edit_start, max_edit_end) = cx.update(|cx| {
let (min_edit_start, max_edit_end) = buffer.update(cx, |buffer, cx| {
cx.update(|cx| {
let max_edit_end = buffer.update(cx, |buffer, cx| {
buffer.edit(edits.iter().cloned(), None, cx);
let max_edit_end = buffer
.summaries_for_anchors::<Point, _>(
@@ -349,16 +345,7 @@ impl EditAgent {
)
.max()
.unwrap();
let min_edit_start = buffer
.summaries_for_anchors::<Point, _>(
edits.iter().map(|(range, _)| &range.start),
)
.min()
.unwrap();
(
buffer.anchor_after(min_edit_start),
buffer.anchor_before(max_edit_end),
)
buffer.anchor_before(max_edit_end)
});
self.action_log
.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
@@ -371,10 +358,9 @@ impl EditAgent {
cx,
);
});
(min_edit_start, max_edit_end)
})?;
output_events
.unbounded_send(EditAgentOutputEvent::Edited(min_edit_start..max_edit_end))
.unbounded_send(EditAgentOutputEvent::Edited)
.ok();
}
@@ -769,7 +755,6 @@ mod tests {
use gpui::{AppContext, TestAppContext};
use indoc::indoc;
use language_model::fake_provider::FakeLanguageModel;
use pretty_assertions::assert_matches;
use project::{AgentLocation, Project};
use rand::prelude::*;
use rand::rngs::StdRng;
@@ -1007,10 +992,7 @@ mod tests {
model.send_last_completion_stream_text_chunk("<new_text>abX");
cx.run_until_parked();
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited(_)]
);
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXc\ndef\nghi\njkl"
@@ -1025,10 +1007,7 @@ mod tests {
model.send_last_completion_stream_text_chunk("cY");
cx.run_until_parked();
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
);
assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
"abXcY\ndef\nghi\njkl"
@@ -1139,9 +1118,9 @@ mod tests {
model.send_last_completion_stream_text_chunk("GHI</new_text>");
cx.run_until_parked();
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1186,9 +1165,9 @@ mod tests {
);
cx.run_until_parked();
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited(_)]
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1204,9 +1183,9 @@ mod tests {
chunks_tx.unbounded_send("```\njkl\n").unwrap();
cx.run_until_parked();
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1222,9 +1201,9 @@ mod tests {
chunks_tx.unbounded_send("mno\n").unwrap();
cx.run_until_parked();
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited { .. }]
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
@@ -1240,9 +1219,9 @@ mod tests {
chunks_tx.unbounded_send("pqr\n```").unwrap();
cx.run_until_parked();
assert_matches!(
drain_events(&mut events).as_slice(),
[EditAgentOutputEvent::Edited(_)],
assert_eq!(
drain_events(&mut events),
vec![EditAgentOutputEvent::Edited]
);
assert_eq!(
buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),

View File

@@ -307,7 +307,7 @@ impl Tool for EditFileTool {
let mut ambiguous_ranges = Vec::new();
while let Some(event) = events.next().await {
match event {
EditAgentOutputEvent::Edited { .. } => {
EditAgentOutputEvent::Edited => {
if let Some(card) = card_clone.as_ref() {
card.update(cx, |card, cx| card.update_diff(cx))?;
}

View File

@@ -18,6 +18,6 @@ collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
parking_lot.workspace = true
rodio = { workspace = true, features = ["wav", "playback", "tracing"] }
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
util.workspace = true
workspace-hack.workspace = true

View File

@@ -59,9 +59,16 @@ pub enum VersionCheckType {
pub enum AutoUpdateStatus {
Idle,
Checking,
Downloading { version: VersionCheckType },
Installing { version: VersionCheckType },
Updated { version: VersionCheckType },
Downloading {
version: VersionCheckType,
},
Installing {
version: VersionCheckType,
},
Updated {
binary_path: PathBuf,
version: VersionCheckType,
},
Errored,
}
@@ -76,7 +83,6 @@ pub struct AutoUpdater {
current_version: SemanticVersion,
http_client: Arc<HttpClientWithUrl>,
pending_poll: Option<Task<Option<()>>>,
quit_subscription: Option<gpui::Subscription>,
}
#[derive(Deserialize, Clone, Debug)]
@@ -158,7 +164,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
AutoUpdateSetting::register(cx);
cx.observe_new(|workspace: &mut Workspace, _window, _cx| {
workspace.register_action(|_, action, window, cx| check(action, window, cx));
workspace.register_action(|_, action: &Check, window, cx| check(action, window, cx));
workspace.register_action(|_, action, _, cx| {
view_release_notes(action, cx);
@@ -168,7 +174,7 @@ pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
let version = release_channel::AppVersion::global(cx);
let auto_updater = cx.new(|cx| {
let updater = AutoUpdater::new(version, http_client, cx);
let updater = AutoUpdater::new(version, http_client);
let poll_for_updates = ReleaseChannel::try_global(cx)
.map(|channel| channel.poll_for_updates())
@@ -315,34 +321,12 @@ impl AutoUpdater {
cx.default_global::<GlobalAutoUpdate>().0.clone()
}
fn new(
current_version: SemanticVersion,
http_client: Arc<HttpClientWithUrl>,
cx: &mut Context<Self>,
) -> Self {
// On windows, executable files cannot be overwritten while they are
// running, so we must wait to overwrite the application until quitting
// or restarting. When quitting the app, we spawn the auto update helper
// to finish the auto update process after Zed exits. When restarting
// the app after an update, we use `set_restart_path` to run the auto
// update helper instead of the app, so that it can overwrite the app
// and then spawn the new binary.
let quit_subscription = Some(cx.on_app_quit(|_, _| async move {
#[cfg(target_os = "windows")]
finalize_auto_update_on_quit();
}));
cx.on_app_restart(|this, _| {
this.quit_subscription.take();
})
.detach();
fn new(current_version: SemanticVersion, http_client: Arc<HttpClientWithUrl>) -> Self {
Self {
status: AutoUpdateStatus::Idle,
current_version,
http_client,
pending_poll: None,
quit_subscription,
}
}
@@ -552,8 +536,6 @@ impl AutoUpdater {
)
})?;
Self::check_dependencies()?;
this.update(&mut cx, |this, cx| {
this.status = AutoUpdateStatus::Checking;
cx.notify();
@@ -600,15 +582,13 @@ impl AutoUpdater {
cx.notify();
})?;
let new_binary_path = Self::install_release(installer_dir, target_path, &cx).await?;
if let Some(new_binary_path) = new_binary_path {
cx.update(|cx| cx.set_restart_path(new_binary_path))?;
}
let binary_path = Self::binary_path(installer_dir, target_path, &cx).await?;
this.update(&mut cx, |this, cx| {
this.set_should_show_update_notification(true, cx)
.detach_and_log_err(cx);
this.status = AutoUpdateStatus::Updated {
binary_path,
version: newer_version,
};
cx.notify();
@@ -659,15 +639,6 @@ impl AutoUpdater {
}
}
fn check_dependencies() -> Result<()> {
#[cfg(not(target_os = "windows"))]
anyhow::ensure!(
which::which("rsync").is_ok(),
"Aborting. Could not find rsync which is required for auto-updates."
);
Ok(())
}
async fn target_path(installer_dir: &InstallerDir) -> Result<PathBuf> {
let filename = match OS {
"macos" => anyhow::Ok("Zed.dmg"),
@@ -676,14 +647,20 @@ impl AutoUpdater {
unsupported_os => anyhow::bail!("not supported: {unsupported_os}"),
}?;
#[cfg(not(target_os = "windows"))]
anyhow::ensure!(
which::which("rsync").is_ok(),
"Aborting. Could not find rsync which is required for auto-updates."
);
Ok(installer_dir.path().join(filename))
}
async fn install_release(
async fn binary_path(
installer_dir: InstallerDir,
target_path: PathBuf,
cx: &AsyncApp,
) -> Result<Option<PathBuf>> {
) -> Result<PathBuf> {
match OS {
"macos" => install_release_macos(&installer_dir, target_path, cx).await,
"linux" => install_release_linux(&installer_dir, target_path, cx).await,
@@ -824,7 +801,7 @@ async fn install_release_linux(
temp_dir: &InstallerDir,
downloaded_tar_gz: PathBuf,
cx: &AsyncApp,
) -> Result<Option<PathBuf>> {
) -> Result<PathBuf> {
let channel = cx.update(|cx| ReleaseChannel::global(cx).dev_name())?;
let home_dir = PathBuf::from(env::var("HOME").context("no HOME env var set")?);
let running_app_path = cx.update(|cx| cx.app_path())??;
@@ -884,14 +861,14 @@ async fn install_release_linux(
String::from_utf8_lossy(&output.stderr)
);
Ok(Some(to.join(expected_suffix)))
Ok(to.join(expected_suffix))
}
async fn install_release_macos(
temp_dir: &InstallerDir,
downloaded_dmg: PathBuf,
cx: &AsyncApp,
) -> Result<Option<PathBuf>> {
) -> Result<PathBuf> {
let running_app_path = cx.update(|cx| cx.app_path())??;
let running_app_filename = running_app_path
.file_name()
@@ -933,10 +910,10 @@ async fn install_release_macos(
String::from_utf8_lossy(&output.stderr)
);
Ok(None)
Ok(running_app_path)
}
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<Option<PathBuf>> {
async fn install_release_windows(downloaded_installer: PathBuf) -> Result<PathBuf> {
let output = Command::new(downloaded_installer)
.arg("/verysilent")
.arg("/update=true")
@@ -949,36 +926,29 @@ async fn install_release_windows(downloaded_installer: PathBuf) -> Result<Option
"failed to start installer: {:?}",
String::from_utf8_lossy(&output.stderr)
);
// We return the path to the update helper program, because it will
// perform the final steps of the update process, copying the new binary,
// deleting the old one, and launching the new binary.
let helper_path = std::env::current_exe()?
.parent()
.context("No parent dir for Zed.exe")?
.join("tools\\auto_update_helper.exe");
Ok(Some(helper_path))
Ok(std::env::current_exe()?)
}
pub fn finalize_auto_update_on_quit() {
pub fn check_pending_installation() -> bool {
let Some(installer_path) = std::env::current_exe()
.ok()
.and_then(|p| p.parent().map(|p| p.join("updates")))
else {
return;
return false;
};
// The installer will create a flag file after it finishes updating
let flag_file = installer_path.join("versions.txt");
if flag_file.exists()
&& let Some(helper) = installer_path
if flag_file.exists() {
if let Some(helper) = installer_path
.parent()
.map(|p| p.join("tools\\auto_update_helper.exe"))
{
let mut command = std::process::Command::new(helper);
command.arg("--launch");
command.arg("false");
let _ = command.spawn();
{
let _ = std::process::Command::new(helper).spawn();
return true;
}
}
false
}
#[cfg(test)]
@@ -1032,6 +1002,7 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
};
let fetched_version = SemanticVersion::new(1, 0, 1);
@@ -1053,6 +1024,7 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
};
let fetched_version = SemanticVersion::new(1, 0, 2);
@@ -1118,6 +1090,7 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "b".to_string();
@@ -1139,6 +1112,7 @@ mod tests {
let app_commit_sha = Ok(Some("a".to_string()));
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "c".to_string();
@@ -1186,6 +1160,7 @@ mod tests {
let app_commit_sha = Ok(None);
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "b".to_string();
@@ -1208,6 +1183,7 @@ mod tests {
let app_commit_sha = Ok(None);
let installed_version = SemanticVersion::new(1, 0, 0);
let status = AutoUpdateStatus::Updated {
binary_path: PathBuf::new(),
version: VersionCheckType::Sha(AppCommitSha::new("b".to_string())),
};
let fetched_sha = "c".to_string();

View File

@@ -18,7 +18,7 @@ fn main() {}
#[cfg(target_os = "windows")]
mod windows_impl {
use std::{borrow::Cow, path::Path};
use std::path::Path;
use super::dialog::create_dialog_window;
use super::updater::perform_update;
@@ -37,11 +37,6 @@ mod windows_impl {
pub(crate) const WM_JOB_UPDATED: u32 = WM_USER + 1;
pub(crate) const WM_TERMINATE: u32 = WM_USER + 2;
#[derive(Debug, Default)]
struct Args {
launch: bool,
}
pub(crate) fn run() -> Result<()> {
let helper_dir = std::env::current_exe()?
.parent()
@@ -56,9 +51,8 @@ mod windows_impl {
log::info!("======= Starting Zed update =======");
let (tx, rx) = std::sync::mpsc::channel();
let hwnd = create_dialog_window(rx)?.0 as isize;
let args = parse_args(std::env::args().skip(1));
std::thread::spawn(move || {
let result = perform_update(app_dir.as_path(), Some(hwnd), args.launch);
let result = perform_update(app_dir.as_path(), Some(hwnd));
tx.send(result).ok();
unsafe { PostMessageW(Some(HWND(hwnd as _)), WM_TERMINATE, WPARAM(0), LPARAM(0)) }.ok();
});
@@ -83,29 +77,6 @@ mod windows_impl {
Ok(())
}
fn parse_args(input: impl IntoIterator<Item = String>) -> Args {
let mut args: Args = Args { launch: true };
let mut input = input.into_iter();
if let Some(arg) = input.next() {
let launch_arg;
if arg == "--launch" {
launch_arg = input.next().map(Cow::Owned);
} else if let Some(rest) = arg.strip_prefix("--launch=") {
launch_arg = Some(Cow::Borrowed(rest));
} else {
launch_arg = None;
}
if launch_arg.as_deref() == Some("false") {
args.launch = false;
}
}
args
}
pub(crate) fn show_error(mut content: String) {
if content.len() > 600 {
content.truncate(600);
@@ -120,31 +91,4 @@ mod windows_impl {
)
};
}
#[cfg(test)]
mod tests {
use crate::windows_impl::parse_args;
#[test]
fn test_parse_args() {
// launch can be specified via two separate arguments
assert_eq!(parse_args(["--launch".into(), "true".into()]).launch, true);
assert_eq!(
parse_args(["--launch".into(), "false".into()]).launch,
false
);
// launch can be specified via one single argument
assert_eq!(parse_args(["--launch=true".into()]).launch, true);
assert_eq!(parse_args(["--launch=false".into()]).launch, false);
// launch defaults to true on no arguments
assert_eq!(parse_args([]).launch, true);
// launch defaults to true on invalid arguments
assert_eq!(parse_args(["--launch".into()]).launch, true);
assert_eq!(parse_args(["--launch=".into()]).launch, true);
assert_eq!(parse_args(["--launch=invalid".into()]).launch, true);
}
}
}

View File

@@ -72,7 +72,7 @@ pub(crate) fn create_dialog_window(receiver: Receiver<Result<()>>) -> Result<HWN
let hwnd = CreateWindowExW(
WS_EX_TOPMOST,
class_name,
windows::core::w!("Zed"),
windows::core::w!("Zed Editor"),
WS_VISIBLE | WS_POPUP | WS_CAPTION,
rect.right / 2 - width / 2,
rect.bottom / 2 - height / 2,
@@ -171,7 +171,7 @@ unsafe extern "system" fn wnd_proc(
&HSTRING::from(font_name),
);
let temp = SelectObject(hdc, font.into());
let string = HSTRING::from("Updating Zed...");
let string = HSTRING::from("Zed Editor is updating...");
return_if_failed!(TextOutW(hdc, 20, 15, &string).ok());
return_if_failed!(DeleteObject(temp).ok());

View File

@@ -118,7 +118,7 @@ pub(crate) const JOBS: [Job; 2] = [
},
];
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>, launch: bool) -> Result<()> {
pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>) -> Result<()> {
let hwnd = hwnd.map(|ptr| HWND(ptr as _));
for job in JOBS.iter() {
@@ -145,11 +145,9 @@ pub(crate) fn perform_update(app_dir: &Path, hwnd: Option<isize>, launch: bool)
}
}
}
if launch {
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
.spawn();
}
let _ = std::process::Command::new(app_dir.join("Zed.exe"))
.creation_flags(CREATE_NEW_PROCESS_GROUP.0)
.spawn();
log::info!("Update completed successfully");
Ok(())
}
@@ -161,11 +159,11 @@ mod test {
#[test]
fn test_perform_update() {
let app_dir = std::path::Path::new("C:/");
assert!(perform_update(app_dir, None, false).is_ok());
assert!(perform_update(app_dir, None).is_ok());
// Simulate a timeout
unsafe { std::env::set_var("ZED_AUTO_UPDATE", "err") };
let ret = perform_update(app_dir, None, false);
let ret = perform_update(app_dir, None);
assert!(ret.is_err_and(|e| e.to_string().as_str() == "Timed out"));
}
}

View File

@@ -10,10 +10,10 @@ use client::{
};
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::StreamExt;
use futures::{FutureExt, StreamExt};
use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FutureExt as _,
ScreenCaptureSource, ScreenCaptureStream, Task, Timeout, WeakEntity,
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, ScreenCaptureSource,
ScreenCaptureStream, Task, WeakEntity,
};
use gpui_tokio::Tokio;
use language::LanguageRegistry;
@@ -370,53 +370,57 @@ impl Room {
})?;
// Wait for client to re-establish a connection to the server.
let executor = cx.background_executor().clone();
let client_reconnection = async {
let mut remaining_attempts = 3;
while remaining_attempts > 0 {
if client_status.borrow().is_connected() {
log::info!("client reconnected, attempting to rejoin room");
let Some(this) = this.upgrade() else { break };
match this.update(cx, |this, cx| this.rejoin(cx)) {
Ok(task) => {
if task.await.log_err().is_some() {
return true;
} else {
remaining_attempts -= 1;
}
}
Err(_app_dropped) => return false,
}
} else if client_status.borrow().is_signed_out() {
return false;
}
log::info!(
"waiting for client status change, remaining attempts {}",
remaining_attempts
);
client_status.next().await;
}
false
};
match client_reconnection
.with_timeout(RECONNECT_TIMEOUT, &executor)
.await
{
Ok(true) => {
log::info!("successfully reconnected to room");
// If we successfully joined the room, go back around the loop
// waiting for future connection status changes.
continue;
let mut reconnection_timeout =
cx.background_executor().timer(RECONNECT_TIMEOUT).fuse();
let client_reconnection = async {
let mut remaining_attempts = 3;
while remaining_attempts > 0 {
if client_status.borrow().is_connected() {
log::info!("client reconnected, attempting to rejoin room");
let Some(this) = this.upgrade() else { break };
match this.update(cx, |this, cx| this.rejoin(cx)) {
Ok(task) => {
if task.await.log_err().is_some() {
return true;
} else {
remaining_attempts -= 1;
}
}
Err(_app_dropped) => return false,
}
} else if client_status.borrow().is_signed_out() {
return false;
}
log::info!(
"waiting for client status change, remaining attempts {}",
remaining_attempts
);
client_status.next().await;
}
false
}
Ok(false) => break,
Err(Timeout) => {
log::info!("room reconnection timeout expired");
break;
.fuse();
futures::pin_mut!(client_reconnection);
futures::select_biased! {
reconnected = client_reconnection => {
if reconnected {
log::info!("successfully reconnected to room");
// If we successfully joined the room, go back around the loop
// waiting for future connection status changes.
continue;
}
}
_ = reconnection_timeout => {
log::info!("room reconnection timeout expired");
}
}
}
break;
}
}

View File

@@ -957,14 +957,17 @@ mod mac_os {
) -> Result<()> {
use anyhow::bail;
let app_path_prompt = format!(
"POSIX path of (path to application \"{}\")",
channel.display_name()
);
let app_path_output = Command::new("osascript")
let app_id_prompt = format!("id of app \"{}\"", channel.display_name());
let app_id_output = Command::new("osascript")
.arg("-e")
.arg(&app_path_prompt)
.arg(&app_id_prompt)
.output()?;
if !app_id_output.status.success() {
bail!("Could not determine app id for {}", channel.display_name());
}
let app_name = String::from_utf8(app_id_output.stdout)?.trim().to_owned();
let app_path_prompt = format!("kMDItemCFBundleIdentifier == '{app_name}'");
let app_path_output = Command::new("mdfind").arg(app_path_prompt).output()?;
if !app_path_output.status.success() {
bail!(
"Could not determine app path for {}",

View File

@@ -340,35 +340,22 @@ impl Telemetry {
}
pub fn log_edit_event(self: &Arc<Self>, environment: &'static str, is_via_ssh: bool) {
static LAST_EVENT_TIME: Mutex<Option<Instant>> = Mutex::new(None);
let mut state = self.state.lock();
let period_data = state.event_coalescer.log_event(environment);
drop(state);
if let Some(mut last_event) = LAST_EVENT_TIME.try_lock() {
let current_time = std::time::Instant::now();
let last_time = last_event.get_or_insert(current_time);
if let Some((start, end, environment)) = period_data {
let duration = end
.saturating_duration_since(start)
.min(Duration::from_secs(60 * 60 * 24))
.as_millis() as i64;
if current_time.duration_since(*last_time) > Duration::from_secs(60 * 10) {
*last_time = current_time;
} else {
return;
}
if let Some((start, end, environment)) = period_data {
let duration = end
.saturating_duration_since(start)
.min(Duration::from_secs(60 * 60 * 24))
.as_millis() as i64;
telemetry::event!(
"Editor Edited",
duration = duration,
environment = environment,
is_via_ssh = is_via_ssh
);
}
telemetry::event!(
"Editor Edited",
duration = duration,
environment = environment,
is_via_ssh = is_via_ssh
);
}
}

View File

@@ -21,7 +21,7 @@ use language::{
point_from_lsp, point_to_lsp,
};
use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId, LanguageServerName};
use node_runtime::{NodeRuntime, VersionStrategy};
use node_runtime::NodeRuntime;
use parking_lot::Mutex;
use project::DisableAiSettings;
use request::StatusNotification;
@@ -349,11 +349,7 @@ impl Copilot {
this.start_copilot(true, false, cx);
cx.observe_global::<SettingsStore>(move |this, cx| {
this.start_copilot(true, false, cx);
if let Ok(server) = this.server.as_running() {
notify_did_change_config_to_server(&server.lsp, cx)
.context("copilot setting change: did change configuration")
.log_err();
}
this.send_configuration_update(cx);
})
.detach();
this
@@ -442,6 +438,43 @@ impl Copilot {
if env.is_empty() { None } else { Some(env) }
}
fn send_configuration_update(&mut self, cx: &mut Context<Self>) {
let copilot_settings = all_language_settings(None, cx)
.edit_predictions
.copilot
.clone();
let settings = json!({
"http": {
"proxy": copilot_settings.proxy,
"proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false)
},
"github-enterprise": {
"uri": copilot_settings.enterprise_uri
}
});
if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
copilot_chat.update(cx, |chat, cx| {
chat.set_configuration(
copilot_chat::CopilotChatConfiguration {
enterprise_uri: copilot_settings.enterprise_uri.clone(),
},
cx,
);
});
}
if let Ok(server) = self.server.as_running() {
server
.lsp
.notify::<lsp::notification::DidChangeConfiguration>(
&lsp::DidChangeConfigurationParams { settings },
)
.log_err();
}
}
#[cfg(any(test, feature = "test-support"))]
pub fn fake(cx: &mut gpui::TestAppContext) -> (Entity<Self>, lsp::FakeLanguageServer) {
use fs::FakeFs;
@@ -540,9 +573,6 @@ impl Copilot {
})?
.await?;
this.update(cx, |_, cx| notify_did_change_config_to_server(&server, cx))?
.context("copilot: did change configuration")?;
let status = server
.request::<request::CheckStatus>(request::CheckStatusParams {
local_checks_only: false,
@@ -568,6 +598,8 @@ impl Copilot {
});
cx.emit(Event::CopilotLanguageServerStarted);
this.update_sign_in_status(status, cx);
// Send configuration now that the LSP is fully started
this.send_configuration_update(cx);
}
Err(error) => {
this.server = CopilotServer::Error(error.to_string().into());
@@ -1124,41 +1156,6 @@ fn uri_for_buffer(buffer: &Entity<Buffer>, cx: &App) -> Result<lsp::Url, ()> {
}
}
fn notify_did_change_config_to_server(
server: &Arc<LanguageServer>,
cx: &mut Context<Copilot>,
) -> std::result::Result<(), anyhow::Error> {
let copilot_settings = all_language_settings(None, cx)
.edit_predictions
.copilot
.clone();
if let Some(copilot_chat) = copilot_chat::CopilotChat::global(cx) {
copilot_chat.update(cx, |chat, cx| {
chat.set_configuration(
copilot_chat::CopilotChatConfiguration {
enterprise_uri: copilot_settings.enterprise_uri.clone(),
},
cx,
);
});
}
let settings = json!({
"http": {
"proxy": copilot_settings.proxy,
"proxyStrictSSL": !copilot_settings.proxy_no_verify.unwrap_or(false)
},
"github-enterprise": {
"uri": copilot_settings.enterprise_uri
}
});
server.notify::<lsp::notification::DidChangeConfiguration>(&lsp::DidChangeConfigurationParams {
settings,
})
}
async fn clear_copilot_dir() {
remove_matching(paths::copilot_dir(), |_| true).await
}
@@ -1184,7 +1181,7 @@ async fn get_copilot_lsp(fs: Arc<dyn Fs>, node_runtime: NodeRuntime) -> anyhow::
PACKAGE_NAME,
&server_path,
paths::copilot_dir(),
VersionStrategy::Latest(&latest_version),
&latest_version,
)
.await;
if should_install {

View File

@@ -273,16 +273,6 @@ pub enum UuidVersion {
V7,
}
/// Splits selection into individual lines.
#[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)]
#[action(namespace = editor)]
#[serde(deny_unknown_fields)]
pub struct SplitSelectionIntoLines {
/// Keep the text selected after splitting instead of collapsing to cursors.
#[serde(default)]
pub keep_selections: bool,
}
/// Goes to the next diagnostic in the file.
#[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema, Action)]
#[action(namespace = editor)]
@@ -682,6 +672,8 @@ actions!(
SortLinesCaseInsensitive,
/// Sorts selected lines case-sensitively.
SortLinesCaseSensitive,
/// Splits selection into individual lines.
SplitSelectionIntoLines,
/// Stops the language server for the current file.
StopLanguageServer,
/// Switches between source and header files.

View File

@@ -2290,6 +2290,8 @@ mod tests {
fn test_blocks_on_wrapped_lines(cx: &mut gpui::TestAppContext) {
cx.update(init_test);
let _font_id = cx.text_system().font_id(&font("Helvetica")).unwrap();
let text = "one two three\nfour five six\nseven eight";
let buffer = cx.update(|cx| MultiBuffer::build_simple(text, cx));

View File

@@ -1223,7 +1223,7 @@ mod tests {
let tab_size = NonZeroU32::new(rng.gen_range(1..=4)).unwrap();
let font = test_font();
let _font_id = text_system.resolve_font(&font);
let _font_id = text_system.font_id(&font);
let font_size = px(14.0);
log::info!("Tab size: {}", tab_size);

View File

@@ -250,24 +250,6 @@ pub type RenderDiffHunkControlsFn = Arc<
) -> AnyElement,
>;
enum ReportEditorEvent {
Saved { auto_saved: bool },
EditorOpened,
ZetaTosClicked,
Closed,
}
impl ReportEditorEvent {
pub fn event_type(&self) -> &'static str {
match self {
Self::Saved { .. } => "Editor Saved",
Self::EditorOpened => "Editor Opened",
Self::ZetaTosClicked => "Edit Prediction Provider ToS Clicked",
Self::Closed => "Editor Closed",
}
}
}
struct InlineValueCache {
enabled: bool,
inlays: Vec<InlayId>,
@@ -2343,7 +2325,7 @@ impl Editor {
}
if editor.mode.is_full() {
editor.report_editor_event(ReportEditorEvent::EditorOpened, None, cx);
editor.report_editor_event("Editor Opened", None, cx);
}
editor
@@ -9142,7 +9124,7 @@ impl Editor {
.on_mouse_down(MouseButton::Left, |_, window, _| window.prevent_default())
.on_click(cx.listener(|this, _event, window, cx| {
cx.stop_propagation();
this.report_editor_event(ReportEditorEvent::ZetaTosClicked, None, cx);
this.report_editor_event("Edit Prediction Provider ToS Clicked", None, cx);
window.dispatch_action(
zed_actions::OpenZedPredictOnboarding.boxed_clone(),
cx,
@@ -12176,8 +12158,6 @@ impl Editor {
let clipboard_text = Cow::Borrowed(text);
self.transact(window, cx, |this, window, cx| {
let had_active_edit_prediction = this.has_active_edit_prediction();
if let Some(mut clipboard_selections) = clipboard_selections {
let old_selections = this.selections.all::<usize>(cx);
let all_selections_were_entire_line =
@@ -12250,11 +12230,6 @@ impl Editor {
} else {
this.insert(&clipboard_text, window, cx);
}
let trigger_in_words =
this.show_edit_predictions_in_menu() || !had_active_edit_prediction;
this.trigger_completion_on_input(&text, trigger_in_words, window, cx);
});
}
@@ -13612,7 +13587,7 @@ impl Editor {
pub fn split_selection_into_lines(
&mut self,
action: &SplitSelectionIntoLines,
_: &SplitSelectionIntoLines,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -13629,21 +13604,8 @@ impl Editor {
let buffer = self.buffer.read(cx).read(cx);
for selection in selections {
for row in selection.start.row..selection.end.row {
let line_start = Point::new(row, 0);
let line_end = Point::new(row, buffer.line_len(MultiBufferRow(row)));
if action.keep_selections {
// Keep the selection range for each line
let selection_start = if row == selection.start.row {
selection.start
} else {
line_start
};
new_selection_ranges.push(selection_start..line_end);
} else {
// Collapse to cursor at end of line
new_selection_ranges.push(line_end..line_end);
}
let cursor = Point::new(row, buffer.line_len(MultiBufferRow(row)));
new_selection_ranges.push(cursor..cursor);
}
let is_multiline_selection = selection.start.row != selection.end.row;
@@ -13651,16 +13613,7 @@ impl Editor {
// so this action feels more ergonomic when paired with other selection operations
let should_skip_last = is_multiline_selection && selection.end.column == 0;
if !should_skip_last {
if action.keep_selections {
if is_multiline_selection {
let line_start = Point::new(selection.end.row, 0);
new_selection_ranges.push(line_start..selection.end);
} else {
new_selection_ranges.push(selection.start..selection.end);
}
} else {
new_selection_ranges.push(selection.end..selection.end);
}
new_selection_ranges.push(selection.end..selection.end);
}
}
}
@@ -15864,25 +15817,19 @@ impl Editor {
let tab_kind = match kind {
Some(GotoDefinitionKind::Implementation) => "Implementations",
Some(GotoDefinitionKind::Symbol) | None => "Definitions",
Some(GotoDefinitionKind::Declaration) => "Declarations",
Some(GotoDefinitionKind::Type) => "Types",
_ => "Definitions",
};
let title = editor
.update_in(acx, |_, _, cx| {
let target = locations
.iter()
.map(|location| {
location
.buffer
.read(cx)
.text_for_range(location.range.clone())
.collect::<String>()
})
.unique()
.take(3)
.join(", ");
format!("{tab_kind} for {target}")
let origin = locations.first().unwrap();
let buffer = origin.buffer.read(cx);
format!(
"{} for {}",
tab_kind,
buffer
.text_for_range(origin.range.clone())
.collect::<String>()
)
})
.context("buffer title")?;
@@ -16078,19 +16025,19 @@ impl Editor {
}
workspace.update_in(cx, |workspace, window, cx| {
let target = locations
.iter()
let title = locations
.first()
.as_ref()
.map(|location| {
location
.buffer
.read(cx)
.text_for_range(location.range.clone())
.collect::<String>()
let buffer = location.buffer.read(cx);
format!(
"References to `{}`",
buffer
.text_for_range(location.range.clone())
.collect::<String>()
)
})
.unique()
.take(3)
.join(", ");
let title = format!("References to {target}");
.unwrap();
Self::open_locations_in_multibuffer(
workspace,
locations,
@@ -20235,7 +20182,6 @@ impl Editor {
);
let old_cursor_shape = self.cursor_shape;
let old_show_breadcrumbs = self.show_breadcrumbs;
{
let editor_settings = EditorSettings::get_global(cx);
@@ -20249,10 +20195,6 @@ impl Editor {
cx.emit(EditorEvent::CursorShapeChanged);
}
if old_show_breadcrumbs != self.show_breadcrumbs {
cx.emit(EditorEvent::BreadcrumbsChanged);
}
let project_settings = ProjectSettings::get_global(cx);
self.serialize_dirty_buffers =
!self.mode.is_minimap() && project_settings.session.restore_unsaved_buffers;
@@ -20605,7 +20547,7 @@ impl Editor {
fn report_editor_event(
&self,
reported_event: ReportEditorEvent,
event_type: &'static str,
file_extension: Option<String>,
cx: &App,
) {
@@ -20639,30 +20581,15 @@ impl Editor {
.show_edit_predictions;
let project = project.read(cx);
let event_type = reported_event.event_type();
if let ReportEditorEvent::Saved { auto_saved } = reported_event {
telemetry::event!(
event_type,
type = if auto_saved {"autosave"} else {"manual"},
file_extension,
vim_mode,
copilot_enabled,
copilot_enabled_for_language,
edit_predictions_provider,
is_via_ssh = project.is_via_ssh(),
);
} else {
telemetry::event!(
event_type,
file_extension,
vim_mode,
copilot_enabled,
copilot_enabled_for_language,
edit_predictions_provider,
is_via_ssh = project.is_via_ssh(),
);
};
telemetry::event!(
event_type,
file_extension,
vim_mode,
copilot_enabled,
copilot_enabled_for_language,
edit_predictions_provider,
is_via_ssh = project.is_via_ssh(),
);
}
/// Copy the highlighted chunks to the clipboard as JSON. The format is an array of lines,
@@ -22874,7 +22801,6 @@ pub enum EditorEvent {
},
Reloaded,
CursorShapeChanged,
BreadcrumbsChanged,
PushedToNavHistory {
anchor: Anchor,
is_deactivate: bool,

View File

@@ -1901,51 +1901,6 @@ fn test_beginning_of_line_stop_at_indent(cx: &mut TestAppContext) {
});
}
#[gpui::test]
fn test_beginning_of_line_with_cursor_between_line_start_and_indent(cx: &mut TestAppContext) {
init_test(cx, |_| {});
let move_to_beg = MoveToBeginningOfLine {
stop_at_soft_wraps: true,
stop_at_indent: true,
};
let editor = cx.add_window(|window, cx| {
let buffer = MultiBuffer::build_simple(" hello\nworld", cx);
build_editor(buffer, window, cx)
});
_ = editor.update(cx, |editor, window, cx| {
// test cursor between line_start and indent_start
editor.change_selections(SelectionEffects::no_scroll(), window, cx, |s| {
s.select_display_ranges([
DisplayPoint::new(DisplayRow(0), 3)..DisplayPoint::new(DisplayRow(0), 3)
]);
});
// cursor should move to line_start
editor.move_to_beginning_of_line(&move_to_beg, window, cx);
assert_eq!(
editor.selections.display_ranges(cx),
&[DisplayPoint::new(DisplayRow(0), 0)..DisplayPoint::new(DisplayRow(0), 0)]
);
// cursor should move to indent_start
editor.move_to_beginning_of_line(&move_to_beg, window, cx);
assert_eq!(
editor.selections.display_ranges(cx),
&[DisplayPoint::new(DisplayRow(0), 4)..DisplayPoint::new(DisplayRow(0), 4)]
);
// cursor should move to back to line_start
editor.move_to_beginning_of_line(&move_to_beg, window, cx);
assert_eq!(
editor.selections.display_ranges(cx),
&[DisplayPoint::new(DisplayRow(0), 0)..DisplayPoint::new(DisplayRow(0), 0)]
);
});
}
#[gpui::test]
fn test_prev_next_word_boundary(cx: &mut TestAppContext) {
init_test(cx, |_| {});
@@ -6401,7 +6356,7 @@ async fn test_split_selection_into_lines(cx: &mut TestAppContext) {
fn test(cx: &mut EditorTestContext, initial_state: &'static str, expected_state: &'static str) {
cx.set_state(initial_state);
cx.update_editor(|e, window, cx| {
e.split_selection_into_lines(&Default::default(), window, cx)
e.split_selection_into_lines(&SplitSelectionIntoLines, window, cx)
});
cx.assert_editor_state(expected_state);
}
@@ -6489,7 +6444,7 @@ async fn test_split_selection_into_lines_interacting_with_creases(cx: &mut TestA
DisplayPoint::new(DisplayRow(4), 4)..DisplayPoint::new(DisplayRow(4), 4),
])
});
editor.split_selection_into_lines(&Default::default(), window, cx);
editor.split_selection_into_lines(&SplitSelectionIntoLines, window, cx);
assert_eq!(
editor.display_text(cx),
"aaaaa\nbbbbb\nccc⋯eeee\nfffff\nggggg\n⋯i"
@@ -6505,7 +6460,7 @@ async fn test_split_selection_into_lines_interacting_with_creases(cx: &mut TestA
DisplayPoint::new(DisplayRow(5), 0)..DisplayPoint::new(DisplayRow(0), 1)
])
});
editor.split_selection_into_lines(&Default::default(), window, cx);
editor.split_selection_into_lines(&SplitSelectionIntoLines, window, cx);
assert_eq!(
editor.display_text(cx),
"aaaaa\nbbbbb\nccccc\nddddd\neeeee\nfffff\nggggg\nhhhhh\niiiii"
@@ -22501,7 +22456,7 @@ async fn test_invisible_worktree_servers(cx: &mut TestAppContext) {
);
cx.update(|_, cx| {
workspace::reload(cx);
workspace::reload(&workspace::Reload::default(), cx);
});
assert_language_servers_count(
1,

View File

@@ -3011,7 +3011,7 @@ impl EditorElement {
.icon_color(Color::Custom(cx.theme().colors().editor_line_number))
.selected_icon_color(Color::Custom(cx.theme().colors().editor_foreground))
.icon_size(IconSize::Custom(rems(editor_font_size / window.rem_size())))
.width(width)
.width(width.into())
.on_click(move |_, window, cx| {
editor.update(cx, |editor, cx| {
editor.expand_excerpt(excerpt_id, direction, window, cx);
@@ -3627,7 +3627,7 @@ impl EditorElement {
ButtonLike::new("toggle-buffer-fold")
.style(ui::ButtonStyle::Transparent)
.height(px(28.).into())
.width(px(28.))
.width(px(28.).into())
.children(toggle_chevron_icon)
.tooltip({
let focus_handle = focus_handle.clone();

View File

@@ -1,7 +1,7 @@
use crate::{
Anchor, Autoscroll, Editor, EditorEvent, EditorSettings, ExcerptId, ExcerptRange, FormatTarget,
MultiBuffer, MultiBufferSnapshot, NavigationData, ReportEditorEvent, SearchWithinRange,
SelectionEffects, ToPoint as _,
MultiBuffer, MultiBufferSnapshot, NavigationData, SearchWithinRange, SelectionEffects,
ToPoint as _,
display_map::HighlightKey,
editor_settings::SeedQuerySetting,
persistence::{DB, SerializedEditor},
@@ -776,10 +776,6 @@ impl Item for Editor {
}
}
fn on_removed(&self, cx: &App) {
self.report_editor_event(ReportEditorEvent::Closed, None, cx);
}
fn deactivated(&mut self, _: &mut Window, cx: &mut Context<Self>) {
let selection = self.selections.newest_anchor();
self.push_to_nav_history(selection.head(), None, true, false, cx);
@@ -819,9 +815,9 @@ impl Item for Editor {
) -> Task<Result<()>> {
// Add meta data tracking # of auto saves
if options.autosave {
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: true }, None, cx);
self.report_editor_event("Editor Autosaved", None, cx);
} else {
self.report_editor_event(ReportEditorEvent::Saved { auto_saved: false }, None, cx);
self.report_editor_event("Editor Saved", None, cx);
}
let buffers = self.buffer().clone().read(cx).all_buffers();
@@ -900,11 +896,7 @@ impl Item for Editor {
.path
.extension()
.map(|a| a.to_string_lossy().to_string());
self.report_editor_event(
ReportEditorEvent::Saved { auto_saved: false },
file_extension,
cx,
);
self.report_editor_event("Editor Saved", file_extension, cx);
project.update(cx, |project, cx| project.save_buffer_as(buffer, path, cx))
}
@@ -1005,16 +997,12 @@ impl Item for Editor {
) {
self.workspace = Some((workspace.weak_handle(), workspace.database_id()));
if let Some(workspace) = &workspace.weak_handle().upgrade() {
cx.subscribe(
&workspace,
|editor, _, event: &workspace::Event, _cx| match event {
workspace::Event::ModalOpened => {
editor.mouse_context_menu.take();
editor.inline_blame_popover.take();
}
_ => {}
},
)
cx.subscribe(&workspace, |editor, _, event: &workspace::Event, _cx| {
if matches!(event, workspace::Event::ModalOpened) {
editor.mouse_context_menu.take();
editor.inline_blame_popover.take();
}
})
.detach();
}
}
@@ -1036,10 +1024,6 @@ impl Item for Editor {
f(ItemEvent::UpdateBreadcrumbs);
}
EditorEvent::BreadcrumbsChanged => {
f(ItemEvent::UpdateBreadcrumbs);
}
EditorEvent::DirtyChanged => {
f(ItemEvent::UpdateTab);
}

View File

@@ -230,7 +230,7 @@ pub fn indented_line_beginning(
if stop_at_soft_boundaries && soft_line_start > indent_start && display_point != soft_line_start
{
soft_line_start
} else if stop_at_indent && (display_point > indent_start || display_point == line_start) {
} else if stop_at_indent && display_point != indent_start {
indent_start
} else {
line_start

View File

@@ -53,7 +53,7 @@ pub fn marked_display_snapshot(
let (unmarked_text, markers) = marked_text_offsets(text);
let font = Font {
family: ".ZedMono".into(),
family: "Zed Plex Mono".into(),
features: FontFeatures::default(),
fallbacks: None,
weight: FontWeight::default(),

View File

@@ -1118,17 +1118,15 @@ impl ExtensionStore {
extensions_to_unload.len() - reload_count
);
let extension_ids = extensions_to_load
.iter()
.filter_map(|id| {
Some((
id.clone(),
new_index.extensions.get(id)?.manifest.version.clone(),
))
})
.collect::<Vec<_>>();
telemetry::event!("Extensions Loaded", id_and_versions = extension_ids);
for extension_id in &extensions_to_load {
if let Some(extension) = new_index.extensions.get(extension_id) {
telemetry::event!(
"Extension Loaded",
extension_id,
version = extension.manifest.version
);
}
}
let themes_to_remove = old_index
.themes

View File

@@ -33,23 +33,13 @@ impl FileIcons {
// TODO: Associate a type with the languages and have the file's language
// override these associations
if let Some(mut typ) = path.file_name().and_then(|typ| typ.to_str()) {
// check if file name is in suffixes
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
// check if file name is in suffixes
// e.g. catch file named `eslint.config.js` instead of `.eslint.config.js`
if let Some(typ) = path.file_name().and_then(|typ| typ.to_str()) {
let maybe_path = get_icon_from_suffix(typ);
if maybe_path.is_some() {
return maybe_path;
}
// check if suffix based on first dot is in suffixes
// e.g. consider `module.js` as suffix to angular's module file named `auth.module.js`
while let Some((_, suffix)) = typ.split_once('.') {
let maybe_path = get_icon_from_suffix(suffix);
if maybe_path.is_some() {
return maybe_path;
}
typ = suffix;
}
}
// primary case: check if the files extension or the hidden file name

View File

@@ -51,7 +51,6 @@ ashpd.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
git = { workspace = true, features = ["test-support"] }
[features]
test-support = ["gpui/test-support", "git/test-support"]

View File

@@ -1,9 +1,8 @@
use crate::{FakeFs, FakeFsEntry, Fs};
use crate::{FakeFs, Fs};
use anyhow::{Context as _, Result};
use collections::{HashMap, HashSet};
use futures::future::{self, BoxFuture, join_all};
use git::{
Oid,
blame::Blame,
repository::{
AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
@@ -11,9 +10,8 @@ use git::{
},
status::{FileStatus, GitStatus, StatusCode, TrackedStatus, UnmergedStatus},
};
use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
use gpui::{AsyncApp, BackgroundExecutor, SharedString};
use ignore::gitignore::GitignoreBuilder;
use parking_lot::Mutex;
use rope::Rope;
use smol::future::FutureExt as _;
use std::{path::PathBuf, sync::Arc};
@@ -21,7 +19,6 @@ use std::{path::PathBuf, sync::Arc};
#[derive(Clone)]
pub struct FakeGitRepository {
pub(crate) fs: Arc<FakeFs>,
pub(crate) checkpoints: Arc<Mutex<HashMap<Oid, FakeFsEntry>>>,
pub(crate) executor: BackgroundExecutor,
pub(crate) dot_git_path: PathBuf,
pub(crate) repository_dir_path: PathBuf,
@@ -186,7 +183,7 @@ impl GitRepository for FakeGitRepository {
async move { None }.boxed()
}
fn status(&self, path_prefixes: &[RepoPath]) -> Task<Result<GitStatus>> {
fn status(&self, path_prefixes: &[RepoPath]) -> BoxFuture<'_, Result<GitStatus>> {
let workdir_path = self.dot_git_path.parent().unwrap();
// Load gitignores
@@ -314,10 +311,7 @@ impl GitRepository for FakeGitRepository {
entries: entries.into(),
})
});
Task::ready(match result {
Ok(result) => result,
Err(e) => Err(e),
})
async move { result? }.boxed()
}
fn branches(&self) -> BoxFuture<'_, Result<Vec<Branch>>> {
@@ -472,57 +466,22 @@ impl GitRepository for FakeGitRepository {
}
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
let executor = self.executor.clone();
let fs = self.fs.clone();
let checkpoints = self.checkpoints.clone();
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
async move {
executor.simulate_random_delay().await;
let oid = Oid::random(&mut executor.rng());
let entry = fs.entry(&repository_dir_path)?;
checkpoints.lock().insert(oid, entry);
Ok(GitRepositoryCheckpoint { commit_sha: oid })
}
.boxed()
unimplemented!()
}
fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
let executor = self.executor.clone();
let fs = self.fs.clone();
let checkpoints = self.checkpoints.clone();
let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
async move {
executor.simulate_random_delay().await;
let checkpoints = checkpoints.lock();
let entry = checkpoints
.get(&checkpoint.commit_sha)
.context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?;
fs.insert_entry(&repository_dir_path, entry.clone())?;
Ok(())
}
.boxed()
fn restore_checkpoint(
&self,
_checkpoint: GitRepositoryCheckpoint,
) -> BoxFuture<'_, Result<()>> {
unimplemented!()
}
fn compare_checkpoints(
&self,
left: GitRepositoryCheckpoint,
right: GitRepositoryCheckpoint,
_left: GitRepositoryCheckpoint,
_right: GitRepositoryCheckpoint,
) -> BoxFuture<'_, Result<bool>> {
let executor = self.executor.clone();
let checkpoints = self.checkpoints.clone();
async move {
executor.simulate_random_delay().await;
let checkpoints = checkpoints.lock();
let left = checkpoints
.get(&left.commit_sha)
.context(format!("invalid left checkpoint: {}", left.commit_sha))?;
let right = checkpoints
.get(&right.commit_sha)
.context(format!("invalid right checkpoint: {}", right.commit_sha))?;
Ok(left == right)
}
.boxed()
unimplemented!()
}
fn diff_checkpoints(
@@ -537,63 +496,3 @@ impl GitRepository for FakeGitRepository {
unimplemented!()
}
}
#[cfg(test)]
mod tests {
use crate::{FakeFs, Fs};
use gpui::BackgroundExecutor;
use serde_json::json;
use std::path::Path;
use util::path;
#[gpui::test]
async fn test_checkpoints(executor: BackgroundExecutor) {
let fs = FakeFs::new(executor);
fs.insert_tree(
path!("/"),
json!({
"bar": {
"baz": "qux"
},
"foo": {
".git": {},
"a": "lorem",
"b": "ipsum",
},
}),
)
.await;
fs.with_git_state(Path::new("/foo/.git"), true, |_git| {})
.unwrap();
let repository = fs.open_repo(Path::new("/foo/.git")).unwrap();
let checkpoint_1 = repository.checkpoint().await.unwrap();
fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap();
fs.write(Path::new("/foo/c"), b"dolor").await.unwrap();
let checkpoint_2 = repository.checkpoint().await.unwrap();
let checkpoint_3 = repository.checkpoint().await.unwrap();
assert!(
repository
.compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
.await
.unwrap()
);
assert!(
!repository
.compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
.await
.unwrap()
);
repository.restore_checkpoint(checkpoint_1).await.unwrap();
assert_eq!(
fs.files_with_contents(Path::new("")),
[
(Path::new("/bar/baz").into(), b"qux".into()),
(Path::new("/foo/a").into(), b"lorem".into()),
(Path::new("/foo/b").into(), b"ipsum".into())
]
);
}
}

View File

@@ -924,7 +924,7 @@ pub struct FakeFs {
#[cfg(any(test, feature = "test-support"))]
struct FakeFsState {
root: FakeFsEntry,
root: Arc<Mutex<FakeFsEntry>>,
next_inode: u64,
next_mtime: SystemTime,
git_event_tx: smol::channel::Sender<PathBuf>,
@@ -939,7 +939,7 @@ struct FakeFsState {
}
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone, Debug)]
#[derive(Debug)]
enum FakeFsEntry {
File {
inode: u64,
@@ -953,7 +953,7 @@ enum FakeFsEntry {
inode: u64,
mtime: MTime,
len: u64,
entries: BTreeMap<String, FakeFsEntry>,
entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
},
Symlink {
@@ -961,67 +961,6 @@ enum FakeFsEntry {
},
}
#[cfg(any(test, feature = "test-support"))]
impl PartialEq for FakeFsEntry {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(
Self::File {
inode: l_inode,
mtime: l_mtime,
len: l_len,
content: l_content,
git_dir_path: l_git_dir_path,
},
Self::File {
inode: r_inode,
mtime: r_mtime,
len: r_len,
content: r_content,
git_dir_path: r_git_dir_path,
},
) => {
l_inode == r_inode
&& l_mtime == r_mtime
&& l_len == r_len
&& l_content == r_content
&& l_git_dir_path == r_git_dir_path
}
(
Self::Dir {
inode: l_inode,
mtime: l_mtime,
len: l_len,
entries: l_entries,
git_repo_state: l_git_repo_state,
},
Self::Dir {
inode: r_inode,
mtime: r_mtime,
len: r_len,
entries: r_entries,
git_repo_state: r_git_repo_state,
},
) => {
let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) {
(Some(l), Some(r)) => Arc::ptr_eq(l, r),
(None, None) => true,
_ => false,
};
l_inode == r_inode
&& l_mtime == r_mtime
&& l_len == r_len
&& l_entries == r_entries
&& same_repo_state
}
(Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => {
l_target == r_target
}
_ => false,
}
}
}
#[cfg(any(test, feature = "test-support"))]
impl FakeFsState {
fn get_and_increment_mtime(&mut self) -> MTime {
@@ -1036,9 +975,25 @@ impl FakeFsState {
inode
}
fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option<PathBuf> {
let mut canonical_path = PathBuf::new();
fn read_path(&self, target: &Path) -> Result<Arc<Mutex<FakeFsEntry>>> {
Ok(self
.try_read_path(target, true)
.ok_or_else(|| {
anyhow!(io::Error::new(
io::ErrorKind::NotFound,
format!("not found: {target:?}")
))
})?
.0)
}
fn try_read_path(
&self,
target: &Path,
follow_symlink: bool,
) -> Option<(Arc<Mutex<FakeFsEntry>>, PathBuf)> {
let mut path = target.to_path_buf();
let mut canonical_path = PathBuf::new();
let mut entry_stack = Vec::new();
'outer: loop {
let mut path_components = path.components().peekable();
@@ -1048,7 +1003,7 @@ impl FakeFsState {
Component::Prefix(prefix_component) => prefix = Some(prefix_component),
Component::RootDir => {
entry_stack.clear();
entry_stack.push(&self.root);
entry_stack.push(self.root.clone());
canonical_path.clear();
match prefix {
Some(prefix_component) => {
@@ -1065,18 +1020,20 @@ impl FakeFsState {
canonical_path.pop();
}
Component::Normal(name) => {
let current_entry = *entry_stack.last()?;
if let FakeFsEntry::Dir { entries, .. } = current_entry {
let entry = entries.get(name.to_str().unwrap())?;
let current_entry = entry_stack.last().cloned()?;
let current_entry = current_entry.lock();
if let FakeFsEntry::Dir { entries, .. } = &*current_entry {
let entry = entries.get(name.to_str().unwrap()).cloned()?;
if path_components.peek().is_some() || follow_symlink {
if let FakeFsEntry::Symlink { target, .. } = entry {
let entry = entry.lock();
if let FakeFsEntry::Symlink { target, .. } = &*entry {
let mut target = target.clone();
target.extend(path_components);
path = target;
continue 'outer;
}
}
entry_stack.push(entry);
entry_stack.push(entry.clone());
canonical_path = canonical_path.join(name);
} else {
return None;
@@ -1086,72 +1043,19 @@ impl FakeFsState {
}
break;
}
if entry_stack.is_empty() {
None
} else {
Some(canonical_path)
}
Some((entry_stack.pop()?, canonical_path))
}
fn try_entry(
&mut self,
target: &Path,
follow_symlink: bool,
) -> Option<(&mut FakeFsEntry, PathBuf)> {
let canonical_path = self.canonicalize(target, follow_symlink)?;
let mut components = canonical_path.components();
let Some(Component::RootDir) = components.next() else {
panic!(
"the path {:?} was not canonicalized properly {:?}",
target, canonical_path
)
};
let mut entry = &mut self.root;
for component in components {
match component {
Component::Normal(name) => {
if let FakeFsEntry::Dir { entries, .. } = entry {
entry = entries.get_mut(name.to_str().unwrap())?;
} else {
return None;
}
}
_ => {
panic!(
"the path {:?} was not canonicalized properly {:?}",
target, canonical_path
)
}
}
}
Some((entry, canonical_path))
}
fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> {
Ok(self
.try_entry(target, true)
.ok_or_else(|| {
anyhow!(io::Error::new(
io::ErrorKind::NotFound,
format!("not found: {target:?}")
))
})?
.0)
}
fn write_path<Fn, T>(&mut self, path: &Path, callback: Fn) -> Result<T>
fn write_path<Fn, T>(&self, path: &Path, callback: Fn) -> Result<T>
where
Fn: FnOnce(btree_map::Entry<String, FakeFsEntry>) -> Result<T>,
Fn: FnOnce(btree_map::Entry<String, Arc<Mutex<FakeFsEntry>>>) -> Result<T>,
{
let path = normalize_path(path);
let filename = path.file_name().context("cannot overwrite the root")?;
let parent_path = path.parent().unwrap();
let parent = self.entry(parent_path)?;
let parent = self.read_path(parent_path)?;
let mut parent = parent.lock();
let new_entry = parent
.dir_entries(parent_path)?
.entry(filename.to_str().unwrap().into());
@@ -1201,13 +1105,13 @@ impl FakeFs {
this: this.clone(),
executor: executor.clone(),
state: Arc::new(Mutex::new(FakeFsState {
root: FakeFsEntry::Dir {
root: Arc::new(Mutex::new(FakeFsEntry::Dir {
inode: 0,
mtime: MTime(UNIX_EPOCH),
len: 0,
entries: Default::default(),
git_repo_state: None,
},
})),
git_event_tx: tx,
next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
next_inode: 1,
@@ -1257,15 +1161,15 @@ impl FakeFs {
.write_path(path, move |entry| {
match entry {
btree_map::Entry::Vacant(e) => {
e.insert(FakeFsEntry::File {
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
inode: new_inode,
mtime: new_mtime,
content: Vec::new(),
len: 0,
git_dir_path: None,
});
})));
}
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() {
btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
FakeFsEntry::Symlink { .. } => {}
@@ -1284,7 +1188,7 @@ impl FakeFs {
pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
let mut state = self.state.lock();
let path = path.as_ref();
let file = FakeFsEntry::Symlink { target };
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
state
.write_path(path.as_ref(), move |e| match e {
btree_map::Entry::Vacant(e) => {
@@ -1317,13 +1221,13 @@ impl FakeFs {
match entry {
btree_map::Entry::Vacant(e) => {
kind = Some(PathEventKind::Created);
e.insert(FakeFsEntry::File {
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
inode: new_inode,
mtime: new_mtime,
len: new_len,
content: new_content,
git_dir_path: None,
});
})));
}
btree_map::Entry::Occupied(mut e) => {
kind = Some(PathEventKind::Changed);
@@ -1333,7 +1237,7 @@ impl FakeFs {
len,
content,
..
} = e.get_mut()
} = &mut *e.get_mut().lock()
{
*mtime = new_mtime;
*content = new_content;
@@ -1355,8 +1259,9 @@ impl FakeFs {
pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
let path = path.as_ref();
let path = normalize_path(path);
let mut state = self.state.lock();
let entry = state.entry(&path)?;
let state = self.state.lock();
let entry = state.read_path(&path)?;
let entry = entry.lock();
entry.file_content(&path).cloned()
}
@@ -1364,8 +1269,9 @@ impl FakeFs {
let path = path.as_ref();
let path = normalize_path(path);
self.simulate_random_delay().await;
let mut state = self.state.lock();
let entry = state.entry(&path)?;
let state = self.state.lock();
let entry = state.read_path(&path)?;
let entry = entry.lock();
entry.file_content(&path).cloned()
}
@@ -1386,25 +1292,6 @@ impl FakeFs {
self.state.lock().flush_events(count);
}
pub(crate) fn entry(&self, target: &Path) -> Result<FakeFsEntry> {
self.state.lock().entry(target).cloned()
}
pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> {
let mut state = self.state.lock();
state.write_path(target, |entry| {
match entry {
btree_map::Entry::Vacant(vacant_entry) => {
vacant_entry.insert(new_entry);
}
btree_map::Entry::Occupied(mut occupied_entry) => {
occupied_entry.insert(new_entry);
}
}
Ok(())
})
}
#[must_use]
pub fn insert_tree<'a>(
&'a self,
@@ -1474,19 +1361,20 @@ impl FakeFs {
F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
{
let mut state = self.state.lock();
let git_event_tx = state.git_event_tx.clone();
let entry = state.entry(dot_git).context("open .git")?;
let entry = state.read_path(dot_git).context("open .git")?;
let mut entry = entry.lock();
if let FakeFsEntry::Dir { git_repo_state, .. } = entry {
if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry {
let repo_state = git_repo_state.get_or_insert_with(|| {
log::debug!("insert git state for {dot_git:?}");
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
Arc::new(Mutex::new(FakeGitRepositoryState::new(
state.git_event_tx.clone(),
)))
});
let mut repo_state = repo_state.lock();
let result = f(&mut repo_state, dot_git, dot_git);
drop(repo_state);
if emit_git_event {
state.emit_event([(dot_git, None)]);
}
@@ -1510,20 +1398,21 @@ impl FakeFs {
}
}
.clone();
let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else {
drop(entry);
let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
anyhow::bail!("pointed-to git dir {path:?} not found")
};
let FakeFsEntry::Dir {
git_repo_state,
entries,
..
} = git_dir_entry
} = &mut *git_dir_entry.lock()
else {
anyhow::bail!("gitfile points to a non-directory")
};
let common_dir = if let Some(child) = entries.get("commondir") {
Path::new(
std::str::from_utf8(child.file_content("commondir".as_ref())?)
std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
.context("commondir content")?,
)
.to_owned()
@@ -1531,14 +1420,15 @@ impl FakeFs {
canonical_path.clone()
};
let repo_state = git_repo_state.get_or_insert_with(|| {
Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
Arc::new(Mutex::new(FakeGitRepositoryState::new(
state.git_event_tx.clone(),
)))
});
let mut repo_state = repo_state.lock();
let result = f(&mut repo_state, &canonical_path, &common_dir);
if emit_git_event {
drop(repo_state);
state.emit_event([(canonical_path, None)]);
}
@@ -1765,12 +1655,14 @@ impl FakeFs {
pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
while let Some((path, entry)) = queue.pop_front() {
if let FakeFsEntry::Dir { entries, .. } = entry {
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
for (name, entry) in entries {
queue.push_back((path.join(name), entry));
queue.push_back((path.join(name), entry.clone()));
}
}
if include_dot_git
@@ -1787,12 +1679,14 @@ impl FakeFs {
pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
while let Some((path, entry)) = queue.pop_front() {
if let FakeFsEntry::Dir { entries, .. } = entry {
if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
for (name, entry) in entries {
queue.push_back((path.join(name), entry));
queue.push_back((path.join(name), entry.clone()));
}
if include_dot_git
|| !path
@@ -1809,14 +1703,17 @@ impl FakeFs {
pub fn files(&self) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
while let Some((path, entry)) = queue.pop_front() {
match entry {
let e = entry.lock();
match &*e {
FakeFsEntry::File { .. } => result.push(path),
FakeFsEntry::Dir { entries, .. } => {
for (name, entry) in entries {
queue.push_back((path.join(name), entry));
queue.push_back((path.join(name), entry.clone()));
}
}
FakeFsEntry::Symlink { .. } => {}
@@ -1828,10 +1725,13 @@ impl FakeFs {
pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
let state = &*self.state.lock();
queue.push_back((PathBuf::from(util::path!("/")), &state.root));
queue.push_back((
PathBuf::from(util::path!("/")),
self.state.lock().root.clone(),
));
while let Some((path, entry)) = queue.pop_front() {
match entry {
let e = entry.lock();
match &*e {
FakeFsEntry::File { content, .. } => {
if path.starts_with(prefix) {
result.push((path, content.clone()));
@@ -1839,7 +1739,7 @@ impl FakeFs {
}
FakeFsEntry::Dir { entries, .. } => {
for (name, entry) in entries {
queue.push_back((path.join(name), entry));
queue.push_back((path.join(name), entry.clone()));
}
}
FakeFsEntry::Symlink { .. } => {}
@@ -1905,7 +1805,10 @@ impl FakeFsEntry {
}
}
fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap<String, FakeFsEntry>> {
fn dir_entries(
&mut self,
path: &Path,
) -> Result<&mut BTreeMap<String, Arc<Mutex<FakeFsEntry>>>> {
if let Self::Dir { entries, .. } = self {
Ok(entries)
} else {
@@ -1952,12 +1855,12 @@ struct FakeHandle {
impl FileHandle for FakeHandle {
fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
let fs = fs.as_fake();
let mut state = fs.state.lock();
let Some(target) = state.moves.get(&self.inode).cloned() else {
let state = fs.state.lock();
let Some(target) = state.moves.get(&self.inode) else {
anyhow::bail!("fake fd not moved")
};
if state.try_entry(&target, false).is_some() {
if state.try_read_path(&target, false).is_some() {
return Ok(target.clone());
}
anyhow::bail!("fake fd target not found")
@@ -1985,13 +1888,13 @@ impl Fs for FakeFs {
state.write_path(&cur_path, |entry| {
entry.or_insert_with(|| {
created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
FakeFsEntry::Dir {
Arc::new(Mutex::new(FakeFsEntry::Dir {
inode,
mtime,
len: 0,
entries: Default::default(),
git_repo_state: None,
}
}))
});
Ok(())
})?
@@ -2006,13 +1909,13 @@ impl Fs for FakeFs {
let mut state = self.state.lock();
let inode = state.get_and_increment_inode();
let mtime = state.get_and_increment_mtime();
let file = FakeFsEntry::File {
let file = Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
len: 0,
content: Vec::new(),
git_dir_path: None,
};
}));
let mut kind = Some(PathEventKind::Created);
state.write_path(path, |entry| {
match entry {
@@ -2036,7 +1939,7 @@ impl Fs for FakeFs {
async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
let mut state = self.state.lock();
let file = FakeFsEntry::Symlink { target };
let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
state
.write_path(path.as_ref(), move |e| match e {
btree_map::Entry::Vacant(e) => {
@@ -2099,7 +2002,7 @@ impl Fs for FakeFs {
}
})?;
let inode = match moved_entry {
let inode = match *moved_entry.lock() {
FakeFsEntry::File { inode, .. } => inode,
FakeFsEntry::Dir { inode, .. } => inode,
_ => 0,
@@ -2148,8 +2051,8 @@ impl Fs for FakeFs {
let mut state = self.state.lock();
let mtime = state.get_and_increment_mtime();
let inode = state.get_and_increment_inode();
let source_entry = state.entry(&source)?;
let content = source_entry.file_content(&source)?.clone();
let source_entry = state.read_path(&source)?;
let content = source_entry.lock().file_content(&source)?.clone();
let mut kind = Some(PathEventKind::Created);
state.write_path(&target, |e| match e {
btree_map::Entry::Occupied(e) => {
@@ -2163,13 +2066,13 @@ impl Fs for FakeFs {
}
}
btree_map::Entry::Vacant(e) => Ok(Some(
e.insert(FakeFsEntry::File {
e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
inode,
mtime,
len: content.len() as u64,
content,
git_dir_path: None,
})
})))
.clone(),
)),
})?;
@@ -2185,7 +2088,8 @@ impl Fs for FakeFs {
let base_name = path.file_name().context("cannot remove the root")?;
let mut state = self.state.lock();
let parent_entry = state.entry(parent_path)?;
let parent_entry = state.read_path(parent_path)?;
let mut parent_entry = parent_entry.lock();
let entry = parent_entry
.dir_entries(parent_path)?
.entry(base_name.to_str().unwrap().into());
@@ -2196,14 +2100,15 @@ impl Fs for FakeFs {
anyhow::bail!("{path:?} does not exist");
}
}
btree_map::Entry::Occupied(mut entry) => {
btree_map::Entry::Occupied(e) => {
{
let children = entry.get_mut().dir_entries(&path)?;
let mut entry = e.get().lock();
let children = entry.dir_entries(&path)?;
if !options.recursive && !children.is_empty() {
anyhow::bail!("{path:?} is not empty");
}
}
entry.remove();
e.remove();
}
}
state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2217,7 +2122,8 @@ impl Fs for FakeFs {
let parent_path = path.parent().context("cannot remove the root")?;
let base_name = path.file_name().unwrap();
let mut state = self.state.lock();
let parent_entry = state.entry(parent_path)?;
let parent_entry = state.read_path(parent_path)?;
let mut parent_entry = parent_entry.lock();
let entry = parent_entry
.dir_entries(parent_path)?
.entry(base_name.to_str().unwrap().into());
@@ -2227,9 +2133,9 @@ impl Fs for FakeFs {
anyhow::bail!("{path:?} does not exist");
}
}
btree_map::Entry::Occupied(mut entry) => {
entry.get_mut().file_content(&path)?;
entry.remove();
btree_map::Entry::Occupied(e) => {
e.get().lock().file_content(&path)?;
e.remove();
}
}
state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2243,10 +2149,12 @@ impl Fs for FakeFs {
async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
self.simulate_random_delay().await;
let mut state = self.state.lock();
let inode = match state.entry(&path)? {
FakeFsEntry::File { inode, .. } => *inode,
FakeFsEntry::Dir { inode, .. } => *inode,
let state = self.state.lock();
let entry = state.read_path(&path)?;
let entry = entry.lock();
let inode = match *entry {
FakeFsEntry::File { inode, .. } => inode,
FakeFsEntry::Dir { inode, .. } => inode,
_ => unreachable!(),
};
Ok(Arc::new(FakeHandle { inode }))
@@ -2296,8 +2204,8 @@ impl Fs for FakeFs {
let path = normalize_path(path);
self.simulate_random_delay().await;
let state = self.state.lock();
let canonical_path = state
.canonicalize(&path, true)
let (_, canonical_path) = state
.try_read_path(&path, true)
.with_context(|| format!("path does not exist: {path:?}"))?;
Ok(canonical_path)
}
@@ -2305,9 +2213,9 @@ impl Fs for FakeFs {
async fn is_file(&self, path: &Path) -> bool {
let path = normalize_path(path);
self.simulate_random_delay().await;
let mut state = self.state.lock();
if let Some((entry, _)) = state.try_entry(&path, true) {
entry.is_file()
let state = self.state.lock();
if let Some((entry, _)) = state.try_read_path(&path, true) {
entry.lock().is_file()
} else {
false
}
@@ -2324,16 +2232,17 @@ impl Fs for FakeFs {
let path = normalize_path(path);
let mut state = self.state.lock();
state.metadata_call_count += 1;
if let Some((mut entry, _)) = state.try_entry(&path, false) {
let is_symlink = entry.is_symlink();
if let Some((mut entry, _)) = state.try_read_path(&path, false) {
let is_symlink = entry.lock().is_symlink();
if is_symlink {
if let Some(e) = state.try_entry(&path, true).map(|e| e.0) {
if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) {
entry = e;
} else {
return Ok(None);
}
}
let entry = entry.lock();
Ok(Some(match &*entry {
FakeFsEntry::File {
inode, mtime, len, ..
@@ -2365,11 +2274,12 @@ impl Fs for FakeFs {
async fn read_link(&self, path: &Path) -> Result<PathBuf> {
self.simulate_random_delay().await;
let path = normalize_path(path);
let mut state = self.state.lock();
let state = self.state.lock();
let (entry, _) = state
.try_entry(&path, false)
.try_read_path(&path, false)
.with_context(|| format!("path does not exist: {path:?}"))?;
if let FakeFsEntry::Symlink { target } = entry {
let entry = entry.lock();
if let FakeFsEntry::Symlink { target } = &*entry {
Ok(target.clone())
} else {
anyhow::bail!("not a symlink: {path:?}")
@@ -2384,7 +2294,8 @@ impl Fs for FakeFs {
let path = normalize_path(path);
let mut state = self.state.lock();
state.read_dir_call_count += 1;
let entry = state.entry(&path)?;
let entry = state.read_path(&path)?;
let mut entry = entry.lock();
let children = entry.dir_entries(&path)?;
let paths = children
.keys()
@@ -2448,7 +2359,6 @@ impl Fs for FakeFs {
dot_git_path: abs_dot_git.to_path_buf(),
repository_dir_path: repository_dir_path.to_owned(),
common_dir_path: common_dir_path.to_owned(),
checkpoints: Arc::default(),
}) as _
},
)

View File

@@ -12,7 +12,7 @@ workspace = true
path = "src/git.rs"
[features]
test-support = ["rand"]
test-support = []
[dependencies]
anyhow.workspace = true
@@ -26,7 +26,6 @@ http_client.workspace = true
log.workspace = true
parking_lot.workspace = true
regex.workspace = true
rand = { workspace = true, optional = true }
rope.workspace = true
schemars.workspace = true
serde.workspace = true
@@ -48,4 +47,3 @@ text = { workspace = true, features = ["test-support"] }
unindent.workspace = true
gpui = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
rand.workspace = true

View File

@@ -73,7 +73,6 @@ async fn run_git_blame(
.current_dir(working_directory)
.arg("blame")
.arg("--incremental")
.arg("-w")
.arg("--contents")
.arg("-")
.arg(path.as_os_str())

Some files were not shown because too many files have changed in this diff Show More