Compare commits

..

6 Commits

Author SHA1 Message Date
Richard Feldman
4bfc8954f3 Switch back to getting mistralrs from GitHub 2025-07-29 22:07:19 -04:00
Richard Feldman
18ca69f07f Get a smaller model working 2025-07-29 21:51:06 -04:00
Richard Feldman
f90459656f Fix local model authentication 2025-07-29 18:56:27 -04:00
Richard Feldman
5830628568 Trying local mistralrs build 2025-07-29 18:24:33 -04:00
Richard Feldman
f62e693b8f Add local model provider 2025-07-29 15:39:36 -04:00
Richard Feldman
4abdec044f Have agent servers respect always_allow_tool_actions 2025-07-29 15:29:51 -04:00
148 changed files with 5522 additions and 8699 deletions

View File

@@ -771,8 +771,7 @@ jobs:
timeout-minutes: 120
name: Create a Windows installer
runs-on: [self-hosted, Windows, X64]
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
# if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
needs: [windows_tests]
env:
AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}

4824
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,13 @@
[workspace]
resolver = "2"
members = [
"crates/acp_thread",
"crates/activity_indicator",
"crates/agent",
"crates/agent_servers",
"crates/agent_settings",
"crates/acp_thread",
"crates/agent_ui",
"crates/agent",
"crates/agent_settings",
"crates/ai_onboarding",
"crates/agent_servers",
"crates/anthropic",
"crates/askpass",
"crates/assets",
@@ -29,9 +29,6 @@ members = [
"crates/cli",
"crates/client",
"crates/clock",
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -51,8 +48,8 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/editor",
"crates/eval",
"crates/explorer_command_injector",
"crates/eval",
"crates/extension",
"crates/extension_api",
"crates/extension_cli",
@@ -73,6 +70,7 @@ members = [
"crates/gpui",
"crates/gpui_macros",
"crates/gpui_tokio",
"crates/html_to_markdown",
"crates/http_client",
"crates/http_client_tls",
@@ -101,6 +99,7 @@ members = [
"crates/markdown_preview",
"crates/media",
"crates/menu",
"crates/svg_preview",
"crates/migrator",
"crates/mistral",
"crates/multi_buffer",
@@ -141,7 +140,6 @@ members = [
"crates/semantic_version",
"crates/session",
"crates/settings",
"crates/settings_profile_selector",
"crates/settings_ui",
"crates/snippet",
"crates/snippet_provider",
@@ -154,7 +152,6 @@ members = [
"crates/sum_tree",
"crates/supermaven",
"crates/supermaven_api",
"crates/svg_preview",
"crates/tab_switcher",
"crates/task",
"crates/tasks_ui",
@@ -254,9 +251,6 @@ channel = { path = "crates/channel" }
cli = { path = "crates/cli" }
client = { path = "crates/client" }
clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
collab = { path = "crates/collab" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
@@ -343,7 +337,6 @@ picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
plugin_macros = { path = "crates/plugin_macros" }
prettier = { path = "crates/prettier" }
settings_profile_selector = { path = "crates/settings_profile_selector" }
project = { path = "crates/project" }
project_panel = { path = "crates/project_panel" }
project_symbols = { path = "crates/project_symbols" }
@@ -435,7 +428,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8
async-recursion = "1.0.0"
async-tar = "0.5.0"
async-trait = "0.1"
async-tungstenite = "0.29.1"
async-tungstenite = "0.30.0"
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
aws-credential-types = { version = "1.2.2", features = [
@@ -522,7 +515,7 @@ objc = "0.2"
open = "5.0.0"
ordered-float = "2.1.1"
palette = { version = "0.7.5", default-features = false, features = ["std"] }
parking_lot = "0.12.1"
parking_lot = "0.12.4"
partial-json-fixer = "0.5.3"
parse_int = "0.9"
pathdiff = "0.2"
@@ -652,6 +645,7 @@ which = "6.0.0"
windows-core = "0.61"
wit-component = "0.221"
workspace-hack = "0.1.0"
zed_llm_client = "= 0.8.6"
zstd = "0.11"
[workspace.dependencies.async-stripe]
@@ -680,13 +674,8 @@ features = [
"Win32_Globalization",
"Win32_Graphics_Direct2D",
"Win32_Graphics_Direct2D_Common",
"Win32_Graphics_Direct3D",
"Win32_Graphics_Direct3D11",
"Win32_Graphics_Direct3D_Fxc",
"Win32_Graphics_DirectComposition",
"Win32_Graphics_DirectWrite",
"Win32_Graphics_Dwm",
"Win32_Graphics_Dxgi",
"Win32_Graphics_Dxgi_Common",
"Win32_Graphics_Gdi",
"Win32_Graphics_Imaging",

View File

@@ -232,7 +232,7 @@
"ctrl-n": "agent::NewThread",
"ctrl-alt-n": "agent::NewTextThread",
"ctrl-shift-h": "agent::OpenHistory",
"ctrl-alt-c": "agent::OpenSettings",
"ctrl-alt-c": "agent::OpenConfiguration",
"ctrl-alt-p": "agent::OpenRulesLibrary",
"ctrl-i": "agent::ToggleProfileSelector",
"ctrl-alt-/": "agent::ToggleModelSelector",

View File

@@ -272,7 +272,7 @@
"cmd-n": "agent::NewThread",
"cmd-alt-n": "agent::NewTextThread",
"cmd-shift-h": "agent::OpenHistory",
"cmd-alt-c": "agent::OpenSettings",
"cmd-alt-c": "agent::OpenConfiguration",
"cmd-alt-p": "agent::OpenRulesLibrary",
"cmd-i": "agent::ToggleProfileSelector",
"cmd-alt-/": "agent::ToggleModelSelector",

View File

@@ -8,7 +8,7 @@
"ctrl-shift-i": "agent::ToggleFocus",
"ctrl-l": "agent::ToggleFocus",
"ctrl-shift-l": "agent::ToggleFocus",
"ctrl-shift-j": "agent::OpenSettings"
"ctrl-shift-j": "agent::OpenConfiguration"
}
},
{

View File

@@ -8,7 +8,7 @@
"cmd-shift-i": "agent::ToggleFocus",
"cmd-l": "agent::ToggleFocus",
"cmd-shift-l": "agent::ToggleFocus",
"cmd-shift-j": "agent::OpenSettings"
"cmd-shift-j": "agent::OpenConfiguration"
}
},
{

View File

@@ -1877,8 +1877,5 @@
"save_breakpoints": true,
"dock": "bottom",
"button": true
},
// Configures any number of settings profiles that are temporarily applied
// when selected from `settings profile selector: toggle`.
"profiles": []
}
}

View File

@@ -25,7 +25,6 @@ assistant_context.workspace = true
assistant_tool.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
context_server.workspace = true
@@ -36,9 +35,9 @@ futures.workspace = true
git.workspace = true
gpui.workspace = true
heed.workspace = true
http_client.workspace = true
icons.workspace = true
indoc.workspace = true
http_client.workspace = true
itertools.workspace = true
language.workspace = true
language_model.workspace = true
@@ -64,6 +63,7 @@ time.workspace = true
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
zstd.workspace = true
[dev-dependencies]

View File

@@ -13,7 +13,6 @@ use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use collections::HashMap;
use feature_flags::{self, FeatureFlagAppExt};
use futures::{FutureExt, StreamExt as _, future::Shared};
@@ -50,6 +49,7 @@ use std::{
use thiserror::Error;
use util::{ResultExt as _, post_inc};
use uuid::Uuid;
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
const MAX_RETRY_ATTEMPTS: u8 = 4;
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
@@ -1681,7 +1681,7 @@ impl Thread {
let completion_mode = request
.mode
.unwrap_or(cloud_llm_client::CompletionMode::Normal);
.unwrap_or(zed_llm_client::CompletionMode::Normal);
self.last_received_chunk_at = Some(Instant::now());

View File

@@ -19,6 +19,7 @@ doctest = false
[dependencies]
acp_thread.workspace = true
agent-client-protocol.workspace = true
agent_settings.workspace = true
agentic-coding-protocol.workspace = true
anyhow.workspace = true
collections.workspace = true

View File

@@ -3,6 +3,7 @@ use std::path::PathBuf;
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
use acp_thread::AcpThread;
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use anyhow::{Context, Result};
use collections::HashMap;
use context_server::listener::{McpServerTool, ToolResponse};
@@ -13,6 +14,7 @@ use context_server::types::{
use gpui::{App, AsyncApp, Task, WeakEntity};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
pub struct ClaudeZedMcpServer {
server: context_server::listener::McpServer,
@@ -114,6 +116,7 @@ pub struct PermissionToolParams {
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
#[cfg_attr(test, derive(serde::Deserialize))]
pub struct PermissionToolResponse {
behavior: PermissionToolBehavior,
updated_input: serde_json::Value,
@@ -121,7 +124,8 @@ pub struct PermissionToolResponse {
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum PermissionToolBehavior {
#[cfg_attr(test, derive(serde::Deserialize))]
pub enum PermissionToolBehavior {
Allow,
Deny,
}
@@ -141,6 +145,26 @@ impl McpServerTool for PermissionTool {
input: Self::Input,
cx: &mut AsyncApp,
) -> Result<ToolResponse<Self::Output>> {
// Check if we should automatically allow tool actions
let always_allow =
cx.update(|cx| AgentSettings::get_global(cx).always_allow_tool_actions)?;
if always_allow {
// If always_allow_tool_actions is true, immediately return Allow without prompting
let response = PermissionToolResponse {
behavior: PermissionToolBehavior::Allow,
updated_input: input.input,
};
return Ok(ToolResponse {
content: vec![ToolResponseContent::Text {
text: serde_json::to_string(&response)?,
}],
structured_content: (),
});
}
// Otherwise, proceed with the normal permission flow
let mut thread_rx = self.thread_rx.clone();
let Some(thread) = thread_rx.recv().await?.upgrade() else {
anyhow::bail!("Thread closed");
@@ -300,3 +324,78 @@ impl McpServerTool for EditTool {
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use project::Project;
use settings::{Settings, SettingsStore};
#[gpui::test]
async fn test_permission_tool_respects_always_allow_setting(cx: &mut TestAppContext) {
// Initialize settings
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
agent_settings::init(cx);
});
// Create a test thread
let project = cx.update(|cx| gpui::Entity::new(cx, |_cx| Project::local()));
let thread = cx.update(|cx| {
gpui::Entity::new(cx, |_cx| {
acp_thread::AcpThread::new(
acp::ConnectionId("test".into()),
project,
std::path::Path::new("/tmp"),
)
})
});
let (tx, rx) = watch::channel(thread.downgrade());
let tool = PermissionTool { thread_rx: rx };
// Test with always_allow_tool_actions = true
cx.update(|cx| {
AgentSettings::override_global(
AgentSettings {
always_allow_tool_actions: true,
..Default::default()
},
cx,
);
});
let input = PermissionToolParams {
tool_name: "test_tool".to_string(),
input: serde_json::json!({"test": "data"}),
tool_use_id: Some("test_id".to_string()),
};
let result = tool.run(input.clone(), &mut cx.to_async()).await.unwrap();
// Should return Allow without prompting
assert_eq!(result.content.len(), 1);
if let ToolResponseContent::Text { text } = &result.content[0] {
let response: PermissionToolResponse = serde_json::from_str(text).unwrap();
assert!(matches!(response.behavior, PermissionToolBehavior::Allow));
} else {
panic!("Expected text response");
}
// Test with always_allow_tool_actions = false
cx.update(|cx| {
AgentSettings::override_global(
AgentSettings {
always_allow_tool_actions: false,
..Default::default()
},
cx,
);
});
// This test would require mocking the permission prompt response
// In the real scenario, it would wait for user input
}
}

View File

@@ -7,6 +7,7 @@ use std::{
use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use futures::{FutureExt, StreamExt, channel::mpsc, select};
use gpui::{Entity, TestAppContext};
@@ -241,6 +242,57 @@ pub async fn test_tool_call_with_confirmation(
});
}
pub async fn test_tool_call_always_allow(
server: impl AgentServer + 'static,
cx: &mut TestAppContext,
) {
let fs = init_test(cx).await;
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
// Enable always_allow_tool_actions
cx.update(|cx| {
AgentSettings::override_global(
AgentSettings {
always_allow_tool_actions: true,
..Default::default()
},
cx,
);
});
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
let full_turn = thread.update(cx, |thread, cx| {
thread.send_raw(
r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
cx,
)
});
// Wait for the tool call to complete
full_turn.await.unwrap();
thread.read_with(cx, |thread, _cx| {
// With always_allow_tool_actions enabled, the tool call should be immediately allowed
// without waiting for confirmation
let tool_call_entry = thread
.entries()
.iter()
.find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
.expect("Expected a tool call entry");
let AgentThreadEntry::ToolCall(tool_call) = tool_call_entry else {
panic!("Expected tool call entry");
};
// Should be allowed, not waiting for confirmation
assert!(
matches!(tool_call.status, ToolCallStatus::Allowed { .. }),
"Expected tool call to be allowed automatically, but got {:?}",
tool_call.status
);
});
}
pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
let fs = init_test(cx).await;
@@ -351,6 +403,12 @@ macro_rules! common_e2e_tests {
async fn cancel(cx: &mut ::gpui::TestAppContext) {
$crate::e2e_tests::test_cancel($server, cx).await;
}
#[::gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn tool_call_always_allow(cx: &mut ::gpui::TestAppContext) {
$crate::e2e_tests::test_tool_call_always_allow($server, cx).await;
}
}
};
}

View File

@@ -13,7 +13,6 @@ path = "src/agent_settings.rs"
[dependencies]
anyhow.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
gpui.workspace = true
language_model.workspace = true
@@ -21,6 +20,7 @@ schemars.workspace = true
serde.workspace = true
settings.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
fs.workspace = true

View File

@@ -321,11 +321,11 @@ pub enum CompletionMode {
Burn,
}
impl From<CompletionMode> for cloud_llm_client::CompletionMode {
impl From<CompletionMode> for zed_llm_client::CompletionMode {
fn from(value: CompletionMode) -> Self {
match value {
CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal,
CompletionMode::Burn => cloud_llm_client::CompletionMode::Max,
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
}
}
}

View File

@@ -31,7 +31,6 @@ audio.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
@@ -47,9 +46,9 @@ futures.workspace = true
fuzzy.workspace = true
gpui.workspace = true
html_to_markdown.workspace = true
indoc.workspace = true
http_client.workspace = true
indexed_docs.workspace = true
indoc.workspace = true
inventory.workspace = true
itertools.workspace = true
jsonschema.workspace = true
@@ -98,6 +97,7 @@ watch.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
assistant_tools.workspace = true

View File

@@ -14,7 +14,6 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
use anyhow::Context as _;
use assistant_tool::ToolUseStatus;
use audio::{Audio, Sound};
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::scroll::Autoscroll;
@@ -53,6 +52,7 @@ use util::ResultExt as _;
use util::markdown::MarkdownCodeBlock;
use workspace::{CollaboratorId, Workspace};
use zed_actions::assistant::OpenRulesLibrary;
use zed_llm_client::CompletionIntent;
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1;

View File

@@ -38,6 +38,13 @@ impl AgentModelSelector {
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
// Authenticate the provider when a model is selected
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id()) {
provider.authenticate(cx).detach();
}
match &model_usage_context {
ModelUsageContext::Thread(thread) => {
thread.update(cx, |thread, cx| {

View File

@@ -44,7 +44,6 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
use assistant_slash_command::SlashCommandWorkingSet;
use assistant_tool::ToolWorkingSet;
use client::{DisableAiSettings, UserStore, zed_urls};
use cloud_llm_client::{CompletionIntent, UsageLimit};
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
use feature_flags::{self, FeatureFlagAppExt};
use fs::Fs;
@@ -78,9 +77,10 @@ use workspace::{
};
use zed_actions::{
DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize,
agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector},
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector},
assistant::{OpenRulesLibrary, ToggleFocus},
};
use zed_llm_client::{CompletionIntent, UsageLimit};
const AGENT_PANEL_KEY: &str = "agent_panel";
@@ -105,7 +105,7 @@ pub fn init(cx: &mut App) {
panel.update(cx, |panel, cx| panel.open_history(window, cx));
}
})
.register_action(|workspace, _: &OpenSettings, window, cx| {
.register_action(|workspace, _: &OpenConfiguration, window, cx| {
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
workspace.focus_panel::<AgentPanel>(window, cx);
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
@@ -2088,7 +2088,7 @@ impl AgentPanel {
menu = menu
.action("Rules…", Box::new(OpenRulesLibrary::default()))
.action("Settings", Box::new(OpenSettings))
.action("Settings", Box::new(OpenConfiguration))
.action(zoom_in_label, Box::new(ToggleZoom));
menu
}))
@@ -2482,14 +2482,14 @@ impl AgentPanel {
.icon_color(Color::Muted)
.full_width()
.key_binding(KeyBinding::for_action_in(
&OpenSettings,
&OpenConfiguration,
&focus_handle,
window,
cx,
))
.on_click(|_event, window, cx| {
window.dispatch_action(
OpenSettings.boxed_clone(),
OpenConfiguration.boxed_clone(),
cx,
)
}),
@@ -2713,11 +2713,16 @@ impl AgentPanel {
.style(ButtonStyle::Tinted(ui::TintColor::Warning))
.label_size(LabelSize::Small)
.key_binding(
KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx)
.map(|kb| kb.size(rems_from_px(12.))),
KeyBinding::for_action_in(
&OpenConfiguration,
&focus_handle,
window,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
.on_click(|_event, window, cx| {
window.dispatch_action(OpenSettings.boxed_clone(), cx)
window.dispatch_action(OpenConfiguration.boxed_clone(), cx)
}),
),
ConfigurationError::ProviderPendingTermsAcceptance(provider) => {
@@ -3221,7 +3226,7 @@ impl Render for AgentPanel {
.on_action(cx.listener(|this, _: &OpenHistory, window, cx| {
this.open_history(window, cx);
}))
.on_action(cx.listener(|this, _: &OpenSettings, window, cx| {
.on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| {
this.open_configuration(window, cx);
}))
.on_action(cx.listener(Self::open_active_thread_as_markdown))

View File

@@ -265,8 +265,8 @@ fn update_command_palette_filter(cx: &mut App) {
filter.hide_namespace("agent");
filter.hide_namespace("assistant");
filter.hide_namespace("copilot");
filter.hide_namespace("supermaven");
filter.hide_namespace("zed_predict_onboarding");
filter.hide_namespace("edit_prediction");
use editor::actions::{

View File

@@ -6,7 +6,6 @@ use agent::{
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use futures::{
@@ -36,6 +35,7 @@ use std::{
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use zed_llm_client::CompletionIntent;
pub struct BufferCodegen {
alternatives: Vec<Entity<CodegenAlternative>>,

View File

@@ -1,10 +1,10 @@
#![allow(unused, dead_code)]
use client::{ModelRequestUsage, RequestUsage};
use cloud_llm_client::{Plan, UsageLimit};
use gpui::Global;
use std::ops::{Deref, DerefMut};
use ui::prelude::*;
use zed_llm_client::{Plan, UsageLimit};
/// Debug only: Used for testing various account states
///

View File

@@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _};
use ui::prelude::*;
use util::{RangeExt, ResultExt, maybe};
use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId};
use zed_actions::agent::OpenSettings;
use zed_actions::agent::OpenConfiguration;
pub fn init(
fs: Arc<dyn Fs>,
@@ -345,7 +345,7 @@ impl InlineAssistant {
if let Some(answer) = answer {
if answer == 0 {
cx.update(|window, cx| {
window.dispatch_action(Box::new(OpenSettings), cx)
window.dispatch_action(Box::new(OpenConfiguration), cx)
})
.ok();
}

View File

@@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.icon_position(IconPosition::Start)
.on_click(|_, window, cx| {
window.dispatch_action(
zed_actions::agent::OpenSettings.boxed_clone(),
zed_actions::agent::OpenConfiguration.boxed_clone(),
cx,
);
}),

View File

@@ -18,7 +18,6 @@ use agent_settings::{AgentSettings, CompletionMode};
use ai_onboarding::ApiKeysWithProviders;
use buffer_diff::BufferDiff;
use client::UserStore;
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet};
use editor::actions::{MoveUp, Paste};
use editor::display_map::CreaseId;
@@ -54,6 +53,7 @@ use util::ResultExt as _;
use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::Chat;
use zed_actions::agent::ToggleModelSelector;
use zed_llm_client::CompletionIntent;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -1300,11 +1300,11 @@ impl MessageEditor {
let plan = user_store
.current_plan()
.map(|plan| match plan {
Plan::Free => cloud_llm_client::Plan::ZedFree,
Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
Plan::Free => zed_llm_client::Plan::ZedFree,
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
})
.unwrap_or(cloud_llm_client::Plan::ZedFree);
.unwrap_or(zed_llm_client::Plan::ZedFree);
let usage = user_store.model_request_usage()?;

View File

@@ -10,7 +10,6 @@ use agent::{
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, VecDeque};
use editor::{MultiBuffer, actions::SelectAll};
use fs::Fs;
@@ -28,6 +27,7 @@ use terminal_view::TerminalView;
use ui::prelude::*;
use util::ResultExt;
use workspace::{Toast, Workspace, notifications::NotificationId};
use zed_llm_client::CompletionIntent;
pub fn init(
fs: Arc<dyn Fs>,

View File

@@ -1,8 +1,8 @@
use client::{ModelRequestUsage, RequestUsage, zed_urls};
use cloud_llm_client::{Plan, UsageLimit};
use component::{empty_example, example_group_with_title, single_example};
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
use ui::{Callout, prelude::*};
use zed_llm_client::{Plan, UsageLimit};
#[derive(IntoElement, RegisterComponent)]
pub struct UsageCallout {

View File

@@ -136,7 +136,10 @@ impl RenderOnce for ApiKeysWithoutProviders {
.full_width()
.style(ButtonStyle::Outlined)
.on_click(move |_, window, cx| {
window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx);
window.dispatch_action(
zed_actions::agent::OpenConfiguration.boxed_clone(),
cx,
);
}),
)
}

View File

@@ -19,7 +19,6 @@ assistant_slash_commands.workspace = true
chrono.workspace = true
client.workspace = true
clock.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
context_server.workspace = true
fs.workspace = true
@@ -49,6 +48,7 @@ util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
indoc.workspace = true

View File

@@ -11,7 +11,6 @@ use assistant_slash_command::{
use assistant_slash_commands::FileCommandMetadata;
use client::{self, Client, proto, telemetry::Telemetry};
use clock::ReplicaId;
use cloud_llm_client::CompletionIntent;
use collections::{HashMap, HashSet};
use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared};
@@ -47,6 +46,7 @@ use text::{BufferSnapshot, ToPoint};
use ui::IconName;
use util::{ResultExt, TryFutureExt, post_inc};
use uuid::Uuid;
use zed_llm_client::CompletionIntent;
pub use crate::context_store::*;

View File

@@ -21,11 +21,9 @@ assistant_tool.workspace = true
buffer_diff.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
derive_more.workspace = true
diffy = "0.4.2"
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
@@ -65,6 +63,8 @@ web_search.workspace = true
which.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_llm_client.workspace = true
diffy = "0.4.2"
[dev-dependencies]
lsp = { workspace = true, features = ["test-support"] }

View File

@@ -7,7 +7,6 @@ mod streaming_fuzzy_matcher;
use crate::{Template, Templates};
use anyhow::Result;
use assistant_tool::ActionLog;
use cloud_llm_client::CompletionIntent;
use create_file_parser::{CreateFileParser, CreateFileParserEvent};
pub use edit_parser::EditFormat;
use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
@@ -30,6 +29,7 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::
use streaming_diff::{CharOperation, StreamingDiff};
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
use util::debug_panic;
use zed_llm_client::CompletionIntent;
#[derive(Serialize)]
struct CreateFilePromptTemplate {

View File

@@ -6,7 +6,6 @@ use anyhow::{Context as _, Result, anyhow};
use assistant_tool::{
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
};
use cloud_llm_client::{WebSearchResponse, WebSearchResult};
use futures::{Future, FutureExt, TryFutureExt};
use gpui::{
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
@@ -18,6 +17,7 @@ use serde::{Deserialize, Serialize};
use ui::{IconName, Tooltip, prelude::*};
use web_search::WebSearchRegistry;
use workspace::Workspace;
use zed_llm_client::{WebSearchResponse, WebSearchResult};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {

View File

@@ -18,6 +18,6 @@ collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
parking_lot.workspace = true
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
rodio = { version = "0.20.0", default-features = false, features = ["wav"] }
util.workspace = true
workspace-hack.workspace = true

View File

@@ -3,9 +3,12 @@ use std::{io::Cursor, sync::Arc};
use anyhow::{Context as _, Result};
use collections::HashMap;
use gpui::{App, AssetSource, Global};
use rodio::{Decoder, Source, source::Buffered};
use rodio::{
Decoder, Source,
source::{Buffered, SamplesConverter},
};
type Sound = Buffered<Decoder<Cursor<Vec<u8>>>>;
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
pub struct SoundRegistry {
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
@@ -45,7 +48,7 @@ impl SoundRegistry {
.with_context(|| format!("No asset available for path {path}"))??
.into_owned();
let cursor = Cursor::new(bytes);
let source = Decoder::new(cursor)?.buffered();
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
self.cache.lock().insert(name.to_string(), source.clone());

View File

@@ -1,7 +1,7 @@
use assets::SoundRegistry;
use derive_more::{Deref, DerefMut};
use gpui::{App, AssetSource, BorrowAppContext, Global};
use rodio::{OutputStream, OutputStreamBuilder};
use rodio::{OutputStream, OutputStreamHandle};
use util::ResultExt;
mod assets;
@@ -37,7 +37,8 @@ impl Sound {
#[derive(Default)]
pub struct Audio {
output_handle: Option<OutputStream>,
_output_stream: Option<OutputStream>,
output_handle: Option<OutputStreamHandle>,
}
#[derive(Deref, DerefMut)]
@@ -50,9 +51,11 @@ impl Audio {
Self::default()
}
fn ensure_output_exists(&mut self) -> Option<&OutputStream> {
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
if self.output_handle.is_none() {
self.output_handle = OutputStreamBuilder::open_default_stream().log_err();
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
self.output_handle = output_handle;
self._output_stream = _output_stream;
}
self.output_handle.as_ref()
@@ -66,7 +69,7 @@ impl Audio {
cx.update_global::<GlobalAudio, _>(|this, cx| {
let output_handle = this.ensure_output_exists()?;
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
output_handle.mixer().add(source);
output_handle.play_raw(source).log_err()?;
Some(())
});
}
@@ -77,6 +80,7 @@ impl Audio {
}
cx.update_global::<GlobalAudio, _>(|this, _| {
this._output_stream.take();
this.output_handle.take();
});
}

View File

@@ -22,8 +22,6 @@ async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manua
base64.workspace = true
chrono = { workspace = true, features = ["serde"] }
clock.workspace = true
cloud_api_client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
derive_more.workspace = true
@@ -35,8 +33,8 @@ http_client.workspace = true
http_client_tls.workspace = true
httparse = "1.10"
log.workspace = true
parking_lot.workspace = true
paths.workspace = true
parking_lot.workspace = true
postage.workspace = true
rand.workspace = true
regex.workspace = true
@@ -48,18 +46,19 @@ serde_json.workspace = true
settings.workspace = true
sha2.workspace = true
smol.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
tokio.workspace = true
url.workspace = true
util.workspace = true
workspace-hack.workspace = true
worktree.workspace = true
telemetry.workspace = true
tokio.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }

View File

@@ -1,7 +1,6 @@
#[cfg(any(test, feature = "test-support"))]
pub mod test;
mod cloud;
mod proxy;
pub mod telemetry;
pub mod user;
@@ -16,7 +15,6 @@ use async_tungstenite::tungstenite::{
};
use chrono::{DateTime, Utc};
use clock::SystemClock;
use cloud_api_client::CloudApiClient;
use credentials_provider::CredentialsProvider;
use futures::{
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
@@ -33,6 +31,7 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
use std::pin::Pin;
use std::{
any::TypeId,
convert::TryFrom,
@@ -46,14 +45,12 @@ use std::{
},
time::{Duration, Instant},
};
use std::{cmp, pin::Pin};
use telemetry::Telemetry;
use thiserror::Error;
use tokio::net::TcpStream;
use url::Url;
use util::{ConnectionResult, ResultExt};
pub use cloud::*;
pub use rpc::*;
pub use telemetry_events::Event;
pub use user::*;
@@ -81,7 +78,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
actions!(
@@ -216,7 +213,6 @@ pub struct Client {
id: AtomicU64,
peer: Arc<Peer>,
http: Arc<HttpClientWithUrl>,
cloud_client: Arc<CloudApiClient>,
telemetry: Arc<Telemetry>,
credentials_provider: ClientCredentialsProvider,
state: RwLock<ClientState>,
@@ -590,7 +586,6 @@ impl Client {
id: AtomicU64::new(0),
peer: Peer::new(0),
telemetry: Telemetry::new(clock, http.clone(), cx),
cloud_client: Arc::new(CloudApiClient::new(http.clone())),
http,
credentials_provider: ClientCredentialsProvider::new(cx),
state: Default::default(),
@@ -623,10 +618,6 @@ impl Client {
self.http.clone()
}
pub fn cloud_client(&self) -> Arc<CloudApiClient> {
self.cloud_client.clone()
}
pub fn set_id(&self, id: u64) -> &Self {
self.id.store(id, Ordering::SeqCst);
self
@@ -736,10 +727,11 @@ impl Client {
},
&cx,
);
let jitter =
Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64));
cx.background_executor().timer(delay + jitter).await;
delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
cx.background_executor().timer(delay).await;
delay = delay
.mul_f32(rng.gen_range(0.5..=2.5))
.max(INITIAL_RECONNECTION_DELAY)
.min(MAX_RECONNECTION_DELAY);
} else {
break;
}
@@ -939,8 +931,6 @@ impl Client {
}
let credentials = credentials.unwrap();
self.set_id(credentials.user_id);
self.cloud_client
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
if was_disconnected {
self.set_status(Status::Connecting, cx);

View File

@@ -1,3 +0,0 @@
mod user_store;
pub use user_store::*;

View File

@@ -1,41 +0,0 @@
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context as _;
use cloud_api_client::{AuthenticatedUser, CloudApiClient};
use gpui::{Context, Task};
use util::{ResultExt as _, maybe};
pub struct CloudUserStore {
authenticated_user: Option<AuthenticatedUser>,
_fetch_authenticated_user_task: Task<()>,
}
impl CloudUserStore {
pub fn new(cloud_client: Arc<CloudApiClient>, cx: &mut Context<Self>) -> Self {
Self {
authenticated_user: None,
_fetch_authenticated_user_task: cx.spawn(async move |this, cx| {
maybe!(async move {
loop {
if cloud_client.has_credentials() {
break;
}
cx.background_executor()
.timer(Duration::from_millis(100))
.await;
}
let response = cloud_client.get_authenticated_user().await?;
this.update(cx, |this, _cx| {
this.authenticated_user = Some(response.user);
})
})
.await
.context("failed to fetch authenticated user")
.log_err();
}),
}
}
}

View File

@@ -1,10 +1,6 @@
use super::{Client, Status, TypedEnvelope, proto};
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use cloud_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
use collections::{HashMap, HashSet, hash_map::Entry};
use derive_more::Deref;
use feature_flags::FeatureFlagAppExt;
@@ -21,6 +17,10 @@ use std::{
};
use text::ReplicaId;
use util::{TryFutureExt as _, maybe};
use zed_llm_client::{
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
pub type UserId = u64;

View File

@@ -1,21 +0,0 @@
[package]
name = "cloud_api_client"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "Apache-2.0"
[lints]
workspace = true
[lib]
path = "src/cloud_api_client.rs"
[dependencies]
anyhow.workspace = true
cloud_api_types.workspace = true
futures.workspace = true
http_client.workspace = true
parking_lot.workspace = true
serde_json.workspace = true
workspace-hack.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-APACHE

View File

@@ -1,79 +0,0 @@
use std::sync::Arc;
use anyhow::{Result, anyhow};
pub use cloud_api_types::*;
use futures::AsyncReadExt as _;
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request};
use parking_lot::RwLock;
struct Credentials {
user_id: u32,
access_token: String,
}
pub struct CloudApiClient {
credentials: RwLock<Option<Credentials>>,
http_client: Arc<HttpClientWithUrl>,
}
impl CloudApiClient {
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
Self {
credentials: RwLock::new(None),
http_client,
}
}
pub fn has_credentials(&self) -> bool {
self.credentials.read().is_some()
}
pub fn set_credentials(&self, user_id: u32, access_token: String) {
*self.credentials.write() = Some(Credentials {
user_id,
access_token,
});
}
fn authorization_header(&self) -> Result<String> {
let guard = self.credentials.read();
let credentials = guard
.as_ref()
.ok_or_else(|| anyhow!("No credentials provided"))?;
Ok(format!(
"{} {}",
credentials.user_id, credentials.access_token
))
}
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
let request = Request::builder()
.method(Method::GET)
.uri(
self.http_client
.build_zed_cloud_url("/client/users/me", &[])?
.as_ref(),
)
.header("Content-Type", "application/json")
.header("Authorization", self.authorization_header()?)
.body(AsyncBody::default())?;
let mut response = self.http_client.send(request).await?;
if !response.status().is_success() {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::bail!(
"Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
response.status()
)
}
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
Ok(serde_json::from_str(&body)?)
}
}

View File

@@ -1,16 +0,0 @@
[package]
name = "cloud_api_types"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "Apache-2.0"
[lints]
workspace = true
[lib]
path = "src/cloud_api_types.rs"
[dependencies]
serde.workspace = true
workspace-hack.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-APACHE

View File

@@ -1,14 +0,0 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct GetAuthenticatedUserResponse {
pub user: AuthenticatedUser,
}
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct AuthenticatedUser {
pub id: i32,
pub avatar_url: String,
pub github_login: String,
pub name: Option<String>,
}

View File

@@ -1,23 +0,0 @@
[package]
name = "cloud_llm_client"
version = "0.1.0"
publish.workspace = true
edition.workspace = true
license = "Apache-2.0"
[lints]
workspace = true
[lib]
path = "src/cloud_llm_client.rs"
[dependencies]
anyhow.workspace = true
serde = { workspace = true, features = ["derive", "rc"] }
serde_json.workspace = true
strum = { workspace = true, features = ["derive"] }
uuid = { workspace = true, features = ["serde"] }
workspace-hack.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true

View File

@@ -1 +0,0 @@
../../LICENSE-APACHE

View File

@@ -1,370 +0,0 @@
use std::str::FromStr;
use std::sync::Arc;
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use strum::{Display, EnumIter, EnumString};
use uuid::Uuid;
/// The name of the header used to indicate which version of Zed the client is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
/// The name of the header used to indicate when a request failed due to an
/// expired LLM token.
///
/// The client may use this as a signal to refresh the token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
/// The name of the header used to indicate what plan the user is currently on.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
/// The name of the header used to indicate the usage limit for model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
/// The name of the header used to indicate the usage amount for model requests.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
/// The name of the header used to indicate the usage limit for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
/// The name of the header used to indicate the usage amount for edit predictions.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
/// The name of the header used to indicate the the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
"x-zed-client-supports-status-messages";
/// The name of the header used by the server to indicate to the client that it supports sending status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
"x-zed-server-supports-status-messages";
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UsageLimit {
Limited(i32),
Unlimited,
}
impl FromStr for UsageLimit {
type Err = anyhow::Error;
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value {
"unlimited" => Ok(Self::Unlimited),
limit => limit
.parse::<i32>()
.map(Self::Limited)
.context("failed to parse limit"),
}
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
#[default]
#[serde(alias = "Free")]
ZedFree,
#[serde(alias = "ZedPro")]
ZedPro,
#[serde(alias = "ZedProTrial")]
ZedProTrial,
}
impl Plan {
pub fn as_str(&self) -> &'static str {
match self {
Plan::ZedFree => "zed_free",
Plan::ZedPro => "zed_pro",
Plan::ZedProTrial => "zed_pro_trial",
}
}
pub fn model_requests_limit(&self) -> UsageLimit {
match self {
Plan::ZedPro => UsageLimit::Limited(500),
Plan::ZedProTrial => UsageLimit::Limited(150),
Plan::ZedFree => UsageLimit::Limited(50),
}
}
pub fn edit_predictions_limit(&self) -> UsageLimit {
match self {
Plan::ZedPro => UsageLimit::Unlimited,
Plan::ZedProTrial => UsageLimit::Unlimited,
Plan::ZedFree => UsageLimit::Limited(2_000),
}
}
}
impl FromStr for Plan {
type Err = anyhow::Error;
fn from_str(value: &str) -> Result<Self, Self::Err> {
match value {
"zed_free" => Ok(Plan::ZedFree),
"zed_pro" => Ok(Plan::ZedPro),
"zed_pro_trial" => Ok(Plan::ZedProTrial),
plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
}
}
}
#[derive(
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum LanguageModelProvider {
Anthropic,
OpenAi,
Google,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsBody {
#[serde(skip_serializing_if = "Option::is_none", default)]
pub outline: Option<String>,
pub input_events: String,
pub input_excerpt: String,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub speculated_output: Option<String>,
/// Whether the user provided consent for sampling this interaction.
#[serde(default, alias = "data_collection_permission")]
pub can_collect_data: bool,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
pub request_id: Uuid,
pub output_excerpt: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcceptEditPredictionBody {
pub request_id: Uuid,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionMode {
Normal,
Max,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionIntent {
UserPrompt,
ToolResults,
ThreadSummarization,
ThreadContextSummarization,
CreateFile,
EditFile,
InlineAssist,
TerminalInlineAssist,
GenerateGitCommitMessage,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionBody {
#[serde(skip_serializing_if = "Option::is_none", default)]
pub thread_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub prompt_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub intent: Option<CompletionIntent>,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub mode: Option<CompletionMode>,
pub provider: LanguageModelProvider,
pub model: String,
pub provider_request: serde_json::Value,
}
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionRequestStatus {
Queued {
position: usize,
},
Started,
Failed {
code: String,
message: String,
request_id: Uuid,
/// Retry duration in seconds.
retry_after: Option<f64>,
},
UsageUpdated {
amount: usize,
limit: UsageLimit,
},
ToolUseLimitReached,
}
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompletionEvent<T> {
Status(CompletionRequestStatus),
Event(T),
}
impl<T> CompletionEvent<T> {
pub fn into_status(self) -> Option<CompletionRequestStatus> {
match self {
Self::Status(status) => Some(status),
Self::Event(_) => None,
}
}
pub fn into_event(self) -> Option<T> {
match self {
Self::Event(event) => Some(event),
Self::Status(_) => None,
}
}
}
#[derive(Serialize, Deserialize)]
pub struct WebSearchBody {
pub query: String,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct WebSearchResponse {
pub results: Vec<WebSearchResult>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct WebSearchResult {
pub title: String,
pub url: String,
pub text: String,
}
#[derive(Serialize, Deserialize)]
pub struct CountTokensBody {
pub provider: LanguageModelProvider,
pub model: String,
pub provider_request: serde_json::Value,
}
#[derive(Serialize, Deserialize)]
pub struct CountTokensResponse {
pub tokens: usize,
}
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelId(pub Arc<str>);
impl std::fmt::Display for LanguageModelId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LanguageModel {
pub provider: LanguageModelProvider,
pub id: LanguageModelId,
pub display_name: String,
pub max_token_count: usize,
pub max_token_count_in_max_mode: Option<usize>,
pub max_output_tokens: usize,
pub supports_tools: bool,
pub supports_images: bool,
pub supports_thinking: bool,
pub supports_max_mode: bool,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ListModelsResponse {
pub models: Vec<LanguageModel>,
pub default_model: LanguageModelId,
pub default_fast_model: LanguageModelId,
pub recommended_models: Vec<LanguageModelId>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GetSubscriptionResponse {
pub plan: Plan,
pub usage: Option<CurrentUsage>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct CurrentUsage {
pub model_requests: UsageData,
pub edit_predictions: UsageData,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct UsageData {
pub used: u32,
pub limit: UsageLimit,
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;
use super::*;
#[test]
fn test_plan_deserialize_snake_case() {
let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
assert_eq!(plan, Plan::ZedFree);
let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
assert_eq!(plan, Plan::ZedPro);
let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
assert_eq!(plan, Plan::ZedProTrial);
}
#[test]
fn test_plan_deserialize_aliases() {
let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
assert_eq!(plan, Plan::ZedFree);
let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
assert_eq!(plan, Plan::ZedPro);
let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
assert_eq!(plan, Plan::ZedProTrial);
}
#[test]
fn test_usage_limit_from_str() {
let limit = UsageLimit::from_str("unlimited").unwrap();
assert!(matches!(limit, UsageLimit::Unlimited));
let limit = UsageLimit::from_str(&0.to_string()).unwrap();
assert!(matches!(limit, UsageLimit::Limited(0)));
let limit = UsageLimit::from_str(&50.to_string()).unwrap();
assert!(matches!(limit, UsageLimit::Limited(50)));
for value in ["not_a_number", "50xyz"] {
let limit = UsageLimit::from_str(value);
assert!(limit.is_err());
}
}
}

View File

@@ -23,14 +23,13 @@ async-stripe.workspace = true
async-trait.workspace = true
async-tungstenite.workspace = true
aws-config = { version = "1.1.5" }
aws-sdk-kinesis = "1.51.0"
aws-sdk-s3 = { version = "1.15.0" }
aws-sdk-kinesis = "1.51.0"
axum = { version = "0.6", features = ["json", "headers", "ws"] }
axum-extra = { version = "0.4", features = ["erased-json"] }
base64.workspace = true
chrono.workspace = true
clock.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
dashmap.workspace = true
derive_more.workspace = true
@@ -76,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re
util.workspace = true
uuid.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
agent_settings.workspace = true

View File

@@ -100,7 +100,7 @@ impl std::fmt::Display for SystemIdHeader {
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new()
.route("/user", get(legacy_update_or_create_authenticated_user))
.route("/user", get(update_or_create_authenticated_user))
.route("/users/look_up", get(look_up_user))
.route("/users/:id/access_tokens", post(create_access_token))
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
@@ -161,10 +161,7 @@ struct AuthenticatedUserResponse {
feature_flags: Vec<String>,
}
/// This is a legacy endpoint that is no longer used in production.
///
/// It currently only exists to be used when developing Collab locally.
async fn legacy_update_or_create_authenticated_user(
async fn update_or_create_authenticated_user(
Query(params): Query<AuthenticatedUserParams>,
Extension(app): Extension<Arc<AppState>>,
) -> Result<Json<AuthenticatedUserResponse>> {
@@ -356,9 +353,9 @@ async fn refresh_llm_tokens(
#[derive(Debug, Serialize, Deserialize)]
struct UpdatePlanBody {
pub plan: cloud_llm_client::Plan,
pub plan: zed_llm_client::Plan,
pub subscription_period: SubscriptionPeriod,
pub usage: cloud_llm_client::CurrentUsage,
pub usage: zed_llm_client::CurrentUsage,
pub trial_started_at: Option<DateTime<Utc>>,
pub is_usage_based_billing_enabled: bool,
pub is_account_too_young: bool,
@@ -380,9 +377,9 @@ async fn update_plan(
extract::Json(body): extract::Json<UpdatePlanBody>,
) -> Result<Json<UpdatePlanResponse>> {
let plan = match body.plan {
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
zed_llm_client::Plan::ZedFree => proto::Plan::Free,
zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
};
let update_user_plan = proto::UpdateUserPlan {
@@ -414,15 +411,15 @@ async fn update_plan(
Ok(Json(UpdatePlanResponse {}))
}
fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
proto::UsageLimit {
variant: Some(match limit {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),

View File

@@ -1,11 +1,11 @@
use anyhow::{Context as _, bail};
use chrono::{DateTime, Utc};
use cloud_llm_client::LanguageModelProvider;
use collections::{HashMap, HashSet};
use sea_orm::ActiveValue;
use std::{sync::Arc, time::Duration};
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
use util::{ResultExt, maybe};
use zed_llm_client::LanguageModelProvider;
use crate::AppState;
use crate::db::billing_subscription::{
@@ -87,14 +87,6 @@ async fn poll_stripe_events(
stripe_client: &Arc<dyn StripeClient>,
real_stripe_client: &stripe::Client,
) -> anyhow::Result<()> {
let feature_flags = app.db.list_feature_flags().await?;
let sync_events_using_cloud = feature_flags
.iter()
.any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
if sync_events_using_cloud {
return Ok(());
}
fn event_type_to_string(event_type: EventType) -> String {
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
// so we need to unquote it.
@@ -577,14 +569,6 @@ async fn sync_model_request_usage_with_stripe(
llm_db: &Arc<LlmDatabase>,
stripe_billing: &Arc<StripeBilling>,
) -> anyhow::Result<()> {
let feature_flags = app.db.list_feature_flags().await?;
let sync_model_request_usage_using_cloud = feature_flags
.iter()
.any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all);
if sync_model_request_usage_using_cloud {
return Ok(());
}
log::info!("Stripe usage sync: Starting");
let started_at = Utc::now();

View File

@@ -8,6 +8,7 @@ use axum::{
use chrono::{NaiveDateTime, SecondsFormat};
use serde::{Deserialize, Serialize};
use crate::api::AuthenticatedUserParams;
use crate::db::ContributorSelector;
use crate::{AppState, Result};
@@ -103,18 +104,9 @@ impl RenovateBot {
}
}
#[derive(Debug, Deserialize)]
struct AddContributorBody {
github_user_id: i32,
github_login: String,
github_email: Option<String>,
github_name: Option<String>,
github_user_created_at: chrono::DateTime<chrono::Utc>,
}
async fn add_contributor(
Extension(app): Extension<Arc<AppState>>,
extract::Json(params): extract::Json<AddContributorBody>,
extract::Json(params): extract::Json<AuthenticatedUserParams>,
) -> Result<()> {
let initial_channel_id = app.config.auto_join_channel_id;
app.db

View File

@@ -95,7 +95,7 @@ pub enum SubscriptionKind {
ZedFree,
}
impl From<SubscriptionKind> for cloud_llm_client::Plan {
impl From<SubscriptionKind> for zed_llm_client::Plan {
fn from(value: SubscriptionKind) -> Self {
match value {
SubscriptionKind::ZedPro => Self::ZedPro,

View File

@@ -6,11 +6,11 @@ mod tables;
#[cfg(test)]
mod tests;
use cloud_llm_client::LanguageModelProvider;
use collections::HashMap;
pub use ids::*;
pub use seed::*;
pub use tables::*;
use zed_llm_client::LanguageModelProvider;
#[cfg(test)]
pub use tests::TestLlmDb;

View File

@@ -1,5 +1,5 @@
use cloud_llm_client::LanguageModelProvider;
use pretty_assertions::assert_eq;
use zed_llm_client::LanguageModelProvider;
use crate::llm::db::LlmDatabase;
use crate::test_llm_db;

View File

@@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA
use crate::{Config, db::billing_preference};
use anyhow::{Context as _, Result};
use chrono::{NaiveDateTime, Utc};
use cloud_llm_client::Plan;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;
use zed_llm_client::Plan;
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]

View File

@@ -2868,12 +2868,12 @@ async fn make_update_user_plan_message(
}
fn model_requests_limit(
plan: cloud_llm_client::Plan,
plan: zed_llm_client::Plan,
feature_flags: &Vec<String>,
) -> cloud_llm_client::UsageLimit {
) -> zed_llm_client::UsageLimit {
match plan.model_requests_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == cloud_llm_client::Plan::ZedProTrial
zed_llm_client::UsageLimit::Limited(limit) => {
let limit = if plan == zed_llm_client::Plan::ZedProTrial
&& feature_flags
.iter()
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
@@ -2883,9 +2883,9 @@ fn model_requests_limit(
limit
};
cloud_llm_client::UsageLimit::Limited(limit)
zed_llm_client::UsageLimit::Limited(limit)
}
cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
}
}
@@ -2895,21 +2895,21 @@ fn subscription_usage_to_proto(
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: usage.model_requests as u32,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
@@ -2917,12 +2917,12 @@ fn subscription_usage_to_proto(
edit_predictions_usage_amount: usage.edit_predictions as u32,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
@@ -2935,21 +2935,21 @@ fn make_default_subscription_usage(
feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
let plan = match plan {
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
};
proto::SubscriptionUsage {
model_requests_usage_amount: 0,
model_requests_usage_limit: Some(proto::UsageLimit {
variant: Some(match model_requests_limit(plan, feature_flags) {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),
@@ -2957,12 +2957,12 @@ fn make_default_subscription_usage(
edit_predictions_usage_amount: 0,
edit_predictions_usage_limit: Some(proto::UsageLimit {
variant: Some(match plan.edit_predictions_limit() {
cloud_llm_client::UsageLimit::Limited(limit) => {
zed_llm_client::UsageLimit::Limited(limit) => {
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
limit: limit as u32,
})
}
cloud_llm_client::UsageLimit::Unlimited => {
zed_llm_client::UsageLimit::Unlimited => {
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
}
}),

View File

@@ -842,7 +842,7 @@ async fn test_client_disconnecting_from_room(
// Allow user A to reconnect to the server.
server.allow_connections();
executor.advance_clock(RECONNECT_TIMEOUT);
executor.advance_clock(RECEIVE_TIMEOUT);
// Call user B again from client A.
active_call_a
@@ -1358,7 +1358,7 @@ async fn test_calls_on_multiple_connections(
// User A reconnects automatically, then calls user B again.
server.allow_connections();
executor.advance_clock(RECONNECT_TIMEOUT);
executor.advance_clock(RECEIVE_TIMEOUT);
active_call_a
.update(cx_a, |call, cx| {
call.invite(client_b1.user_id().unwrap(), None, cx)

View File

@@ -8,7 +8,6 @@ use crate::{
use anyhow::anyhow;
use call::ActiveCall;
use channel::{ChannelBuffer, ChannelStore};
use client::CloudUserStore;
use client::{
self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
proto::PeerId,
@@ -282,14 +281,12 @@ impl TestServer {
.register_hosting_provider(Arc::new(git_hosting_providers::Github::public_instance()));
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
let cloud_user_store = cx.new(|cx| CloudUserStore::new(client.cloud_client(), cx));
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
let session = cx.new(|cx| AppSession::new(Session::test(), cx));
let app_state = Arc::new(workspace::AppState {
client: client.clone(),
user_store: user_store.clone(),
cloud_user_store,
workspace_store,
languages: language_registry,
fs: fs.clone(),

View File

@@ -7,17 +7,17 @@ license = "GPL-3.0-or-later"
[dependencies]
anyhow.workspace = true
command_palette.workspace = true
gpui.workspace = true
clap.workspace = true
mdbook = "0.4.40"
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
regex.workspace = true
util.workspace = true
workspace-hack.workspace = true
zed.workspace = true
zlog.workspace = true
gpui.workspace = true
command_palette.workspace = true
[lints]
workspace = true

View File

@@ -1,15 +1,14 @@
use anyhow::{Context, Result};
use anyhow::Result;
use clap::{Arg, ArgMatches, Command};
use mdbook::BookItem;
use mdbook::book::{Book, Chapter};
use mdbook::preprocess::CmdPreprocessor;
use regex::Regex;
use settings::KeymapFile;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::io::{self, Read};
use std::process;
use std::sync::LazyLock;
use util::paths::PathExt;
static KEYMAP_MACOS: LazyLock<KeymapFile> = LazyLock::new(|| {
load_keymap("keymaps/default-macos.json").expect("Failed to load MacOS keymap")
@@ -21,68 +20,60 @@ static KEYMAP_LINUX: LazyLock<KeymapFile> = LazyLock::new(|| {
static ALL_ACTIONS: LazyLock<Vec<ActionDef>> = LazyLock::new(dump_all_gpui_actions);
const FRONT_MATTER_COMMENT: &'static str = "<!-- ZED_META {} -->";
pub fn make_app() -> Command {
Command::new("zed-docs-preprocessor")
.about("Preprocesses Zed Docs content to provide rich action & keybinding support and more")
.subcommand(
Command::new("supports")
.arg(Arg::new("renderer").required(true))
.about("Check whether a renderer is supported by this preprocessor"),
)
}
fn main() -> Result<()> {
zlog::init();
zlog::init_output_stderr();
let matches = make_app().get_matches();
// call a zed:: function so everything in `zed` crate is linked and
// all actions in the actual app are registered
zed::stdout_is_a_pty();
let args = std::env::args().skip(1).collect::<Vec<_>>();
match args.get(0).map(String::as_str) {
Some("supports") => {
let renderer = args.get(1).expect("Required argument");
let supported = renderer != "not-supported";
if supported {
process::exit(0);
} else {
process::exit(1);
}
}
Some("postprocess") => handle_postprocessing()?,
_ => handle_preprocessing()?,
if let Some(sub_args) = matches.subcommand_matches("supports") {
handle_supports(sub_args);
} else {
handle_preprocessing()?;
}
Ok(())
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum PreprocessorError {
enum Error {
ActionNotFound { action_name: String },
DeprecatedActionUsed { used: String, should_be: String },
InvalidFrontmatterLine(String),
}
impl PreprocessorError {
impl Error {
fn new_for_not_found_action(action_name: String) -> Self {
for action in &*ALL_ACTIONS {
for alias in action.deprecated_aliases {
if alias == &action_name {
return PreprocessorError::DeprecatedActionUsed {
return Error::DeprecatedActionUsed {
used: action_name.clone(),
should_be: action.name.to_string(),
};
}
}
}
PreprocessorError::ActionNotFound {
Error::ActionNotFound {
action_name: action_name.to_string(),
}
}
}
impl std::fmt::Display for PreprocessorError {
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PreprocessorError::InvalidFrontmatterLine(line) => {
write!(f, "Invalid frontmatter line: {}", line)
}
PreprocessorError::ActionNotFound { action_name } => {
write!(f, "Action not found: {}", action_name)
}
PreprocessorError::DeprecatedActionUsed { used, should_be } => write!(
Error::ActionNotFound { action_name } => write!(f, "Action not found: {}", action_name),
Error::DeprecatedActionUsed { used, should_be } => write!(
f,
"Deprecated action used: {} should be {}",
used, should_be
@@ -98,9 +89,8 @@ fn handle_preprocessing() -> Result<()> {
let (_ctx, mut book) = CmdPreprocessor::parse_input(input.as_bytes())?;
let mut errors = HashSet::<PreprocessorError>::new();
let mut errors = HashSet::<Error>::new();
handle_frontmatter(&mut book, &mut errors);
template_and_validate_keybindings(&mut book, &mut errors);
template_and_validate_actions(&mut book, &mut errors);
@@ -118,41 +108,19 @@ fn handle_preprocessing() -> Result<()> {
Ok(())
}
fn handle_frontmatter(book: &mut Book, errors: &mut HashSet<PreprocessorError>) {
let frontmatter_regex = Regex::new(r"(?s)^\s*---(.*?)---").unwrap();
for_each_chapter_mut(book, |chapter| {
let new_content = frontmatter_regex.replace(&chapter.content, |caps: &regex::Captures| {
let frontmatter = caps[1].trim();
let frontmatter = frontmatter.trim_matches(&[' ', '-', '\n']);
let mut metadata = HashMap::<String, String>::default();
for line in frontmatter.lines() {
let Some((name, value)) = line.split_once(':') else {
errors.insert(PreprocessorError::InvalidFrontmatterLine(format!(
"{}: {}",
chapter_breadcrumbs(&chapter),
line
)));
continue;
};
let name = name.trim();
let value = value.trim();
metadata.insert(name.to_string(), value.to_string());
}
FRONT_MATTER_COMMENT.replace(
"{}",
&serde_json::to_string(&metadata).expect("Failed to serialize metadata"),
)
});
match new_content {
Cow::Owned(content) => {
chapter.content = content;
}
Cow::Borrowed(_) => {}
}
});
fn handle_supports(sub_args: &ArgMatches) -> ! {
let renderer = sub_args
.get_one::<String>("renderer")
.expect("Required argument");
let supported = renderer != "not-supported";
if supported {
process::exit(0);
} else {
process::exit(1);
}
}
fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<PreprocessorError>) {
fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Error>) {
let regex = Regex::new(r"\{#kb (.*?)\}").unwrap();
for_each_chapter_mut(book, |chapter| {
@@ -160,9 +128,7 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Prepr
.replace_all(&chapter.content, |caps: &regex::Captures| {
let action = caps[1].trim();
if find_action_by_name(action).is_none() {
errors.insert(PreprocessorError::new_for_not_found_action(
action.to_string(),
));
errors.insert(Error::new_for_not_found_action(action.to_string()));
return String::new();
}
let macos_binding = find_binding("macos", action).unwrap_or_default();
@@ -178,7 +144,7 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Prepr
});
}
fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet<PreprocessorError>) {
fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet<Error>) {
let regex = Regex::new(r"\{#action (.*?)\}").unwrap();
for_each_chapter_mut(book, |chapter| {
@@ -186,9 +152,7 @@ fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet<Preproces
.replace_all(&chapter.content, |caps: &regex::Captures| {
let name = caps[1].trim();
let Some(action) = find_action_by_name(name) else {
errors.insert(PreprocessorError::new_for_not_found_action(
name.to_string(),
));
errors.insert(Error::new_for_not_found_action(name.to_string()));
return String::new();
};
format!("<code class=\"hljs\">{}</code>", &action.human_name)
@@ -253,13 +217,6 @@ fn name_for_action(action_as_str: String) -> String {
.unwrap_or(action_as_str)
}
fn chapter_breadcrumbs(chapter: &Chapter) -> String {
let mut breadcrumbs = Vec::with_capacity(chapter.parent_names.len() + 1);
breadcrumbs.extend(chapter.parent_names.iter().map(String::as_str));
breadcrumbs.push(chapter.name.as_str());
format!("[{:?}] {}", chapter.source_path, breadcrumbs.join(" > "))
}
fn load_keymap(asset_path: &str) -> Result<KeymapFile> {
let content = util::asset_str::<settings::SettingsAssets>(asset_path);
KeymapFile::parse(content.as_ref())
@@ -297,126 +254,3 @@ fn dump_all_gpui_actions() -> Vec<ActionDef> {
return actions;
}
fn handle_postprocessing() -> Result<()> {
let logger = zlog::scoped!("render");
let mut ctx = mdbook::renderer::RenderContext::from_json(io::stdin())?;
let output = ctx
.config
.get_mut("output")
.expect("has output")
.as_table_mut()
.expect("output is table");
let zed_html = output.remove("zed-html").expect("zed-html output defined");
let default_description = zed_html
.get("default-description")
.expect("Default description not found")
.as_str()
.expect("Default description not a string")
.to_string();
let default_title = zed_html
.get("default-title")
.expect("Default title not found")
.as_str()
.expect("Default title not a string")
.to_string();
output.insert("html".to_string(), zed_html);
mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?;
let ignore_list = ["toc.html"];
let root_dir = ctx.destination.clone();
let mut files = Vec::with_capacity(128);
let mut queue = Vec::with_capacity(64);
queue.push(root_dir.clone());
while let Some(dir) = queue.pop() {
for entry in std::fs::read_dir(&dir).context(dir.to_sanitized_string())? {
let Ok(entry) = entry else {
continue;
};
let file_type = entry.file_type().context("Failed to determine file type")?;
if file_type.is_dir() {
queue.push(entry.path());
}
if file_type.is_file()
&& matches!(
entry.path().extension().and_then(std::ffi::OsStr::to_str),
Some("html")
)
{
if ignore_list.contains(&&*entry.file_name().to_string_lossy()) {
zlog::info!(logger => "Ignoring {}", entry.path().to_string_lossy());
} else {
files.push(entry.path());
}
}
}
}
zlog::info!(logger => "Processing {} `.html` files", files.len());
let meta_regex = Regex::new(&FRONT_MATTER_COMMENT.replace("{}", "(.*)")).unwrap();
for file in files {
let contents = std::fs::read_to_string(&file)?;
let mut meta_description = None;
let mut meta_title = None;
let contents = meta_regex.replace(&contents, |caps: &regex::Captures| {
let metadata: HashMap<String, String> = serde_json::from_str(&caps[1]).with_context(|| format!("JSON Metadata: {:?}", &caps[1])).expect("Failed to deserialize metadata");
for (kind, content) in metadata {
match kind.as_str() {
"description" => {
meta_description = Some(content);
}
"title" => {
meta_title = Some(content);
}
_ => {
zlog::warn!(logger => "Unrecognized frontmatter key: {} in {:?}", kind, pretty_path(&file, &root_dir));
}
}
}
String::new()
});
let meta_description = meta_description.as_ref().unwrap_or_else(|| {
zlog::warn!(logger => "No meta description found for {:?}", pretty_path(&file, &root_dir));
&default_description
});
let page_title = extract_title_from_page(&contents, pretty_path(&file, &root_dir));
let meta_title = meta_title.as_ref().unwrap_or_else(|| {
zlog::debug!(logger => "No meta title found for {:?}", pretty_path(&file, &root_dir));
&default_title
});
let meta_title = format!("{} | {}", page_title, meta_title);
zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir));
let contents = contents.replace("#description#", meta_description);
let contents = TITLE_REGEX
.replace(&contents, |_: &regex::Captures| {
format!("<title>{}</title>", meta_title)
})
.to_string();
// let contents = contents.replace("#title#", &meta_title);
std::fs::write(file, contents)?;
}
return Ok(());
fn pretty_path<'a>(
path: &'a std::path::PathBuf,
root: &'a std::path::PathBuf,
) -> &'a std::path::Path {
&path.strip_prefix(&root).unwrap_or(&path)
}
const TITLE_REGEX: std::cell::LazyCell<Regex> =
std::cell::LazyCell::new(|| Regex::new(r"<title>\s*(.*?)\s*</title>").unwrap());
fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String {
let title_tag_contents = &TITLE_REGEX
.captures(&contents)
.with_context(|| format!("Failed to find title in {:?}", pretty_path))
.expect("Page has <title> element")[1];
let title = title_tag_contents
.trim()
.strip_suffix("- Zed")
.unwrap_or(title_tag_contents)
.trim()
.to_string();
title
}
}

View File

@@ -56,7 +56,7 @@ use aho_corasick::AhoCorasick;
use anyhow::{Context as _, Result, anyhow};
use blink_manager::BlinkManager;
use buffer_diff::DiffHunkStatus;
use client::{Collaborator, DisableAiSettings, ParticipantIndex};
use client::{Collaborator, ParticipantIndex};
use clock::{AGENT_REPLICA_ID, ReplicaId};
use collections::{BTreeMap, HashMap, HashSet, VecDeque};
use convert_case::{Case, Casing};
@@ -65,7 +65,7 @@ use display_map::*;
pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder};
pub use editor_settings::{
CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode,
ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, ShowScrollbar,
ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowScrollbar,
};
use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings};
pub use editor_settings_controls::*;
@@ -7048,7 +7048,7 @@ impl Editor {
}
pub fn update_edit_prediction_settings(&mut self, cx: &mut Context<Self>) {
if self.edit_prediction_provider.is_none() || DisableAiSettings::get_global(cx).disable_ai {
if self.edit_prediction_provider.is_none() {
self.edit_prediction_settings = EditPredictionSettings::Disabled;
} else {
let selection = self.selections.newest_anchor();

View File

@@ -19,8 +19,8 @@ path = "src/explorer.rs"
[dependencies]
agent.workspace = true
agent_settings.workspace = true
agent_ui.workspace = true
agent_settings.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
assistant_tools.workspace = true
@@ -29,7 +29,6 @@ buffer_diff.workspace = true
chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
dirs.workspace = true
@@ -69,3 +68,4 @@ util.workspace = true
uuid.workspace = true
watch.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true

View File

@@ -15,11 +15,11 @@ use agent_settings::AgentProfileId;
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use buffer_diff::DiffHunkStatus;
use cloud_llm_client::CompletionIntent;
use collections::HashMap;
use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
use gpui::{App, AppContext, AsyncApp, Entity};
use language_model::{LanguageModel, Role, StopReason};
use zed_llm_client::CompletionIntent;
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);

View File

@@ -106,7 +106,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn language_server_initialization_options(
@@ -131,7 +131,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn language_server_workspace_configuration(
@@ -154,7 +154,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn language_server_additional_initialization_options(
@@ -179,7 +179,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn language_server_additional_workspace_configuration(
@@ -204,7 +204,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn labels_for_completions(
@@ -230,7 +230,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn labels_for_symbols(
@@ -256,7 +256,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn complete_slash_command_argument(
@@ -275,7 +275,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn run_slash_command(
@@ -301,7 +301,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn context_server_command(
@@ -320,7 +320,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn context_server_configuration(
@@ -347,7 +347,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> {
@@ -362,7 +362,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn index_docs(
@@ -388,7 +388,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn get_dap_binary(
@@ -410,7 +410,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn dap_request_kind(
&self,
@@ -427,7 +427,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn dap_config_to_scenario(&self, config: ZedDebugConfig) -> Result<DebugScenario> {
@@ -441,7 +441,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn dap_locator_create_scenario(
@@ -465,7 +465,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
async fn run_dap_locator(
&self,
@@ -481,7 +481,7 @@ impl extension::Extension for WasmExtension {
}
.boxed()
})
.await?
.await
}
}
@@ -761,7 +761,7 @@ impl WasmExtension {
.with_context(|| format!("failed to load wasm extension {}", manifest.id))
}
pub async fn call<T, Fn>(&self, f: Fn) -> Result<T>
pub async fn call<T, Fn>(&self, f: Fn) -> T
where
T: 'static + Send,
Fn: 'static
@@ -777,15 +777,14 @@ impl WasmExtension {
}
.boxed()
}))
.map_err(|_| {
anyhow!(
.unwrap_or_else(|_| {
panic!(
"wasm extension channel should not be closed yet, extension {} (id {})",
self.manifest.name,
self.manifest.id,
self.manifest.name, self.manifest.id,
)
})?;
return_rx.await.with_context(|| {
format!(
});
return_rx.await.unwrap_or_else(|_| {
panic!(
"wasm extension channel, extension {} (id {})",
self.manifest.name, self.manifest.id,
)

View File

@@ -24,7 +24,6 @@ buffer_diff.workspace = true
call.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
component.workspace = true
@@ -63,6 +62,7 @@ watch.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
[target.'cfg(windows)'.dependencies]
windows.workspace = true

View File

@@ -71,12 +71,12 @@ use ui::{
use util::{ResultExt, TryFutureExt, maybe};
use workspace::SERIALIZATION_THROTTLE_TIME;
use cloud_llm_client::CompletionIntent;
use workspace::{
Workspace,
dock::{DockPosition, Panel, PanelEvent},
notifications::{DetachAndPromptErr, ErrorMessagePrompt, NotificationId},
};
use zed_llm_client::CompletionIntent;
actions!(
git_panel,

View File

@@ -216,6 +216,10 @@ xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf
x11-clipboard = { version = "0.9.3", optional = true }
[target.'cfg(target_os = "windows")'.dependencies]
blade-util.workspace = true
bytemuck = "1"
blade-graphics.workspace = true
blade-macros.workspace = true
flume = "0.11"
rand.workspace = true
windows.workspace = true
@@ -236,6 +240,7 @@ util = { workspace = true, features = ["test-support"] }
[target.'cfg(target_os = "windows")'.build-dependencies]
embed-resource = "3.0"
naga.workspace = true
[target.'cfg(target_os = "macos")'.build-dependencies]
bindgen = "0.71"

View File

@@ -9,10 +9,7 @@ fn main() {
let target = env::var("CARGO_CFG_TARGET_OS");
println!("cargo::rustc-check-cfg=cfg(gles)");
#[cfg(any(
not(any(target_os = "macos", target_os = "windows")),
all(target_os = "macos", feature = "macos-blade")
))]
#[cfg(any(not(target_os = "macos"), feature = "macos-blade"))]
check_wgsl_shaders();
match target.as_deref() {
@@ -20,18 +17,21 @@ fn main() {
#[cfg(target_os = "macos")]
macos::build();
}
#[cfg(all(target_os = "windows", feature = "windows-manifest"))]
Ok("windows") => {
#[cfg(target_os = "windows")]
windows::build();
let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml");
let rc_file = std::path::Path::new("resources/windows/gpui.rc");
println!("cargo:rerun-if-changed={}", manifest.display());
println!("cargo:rerun-if-changed={}", rc_file.display());
embed_resource::compile(rc_file, embed_resource::NONE)
.manifest_required()
.unwrap();
}
_ => (),
};
}
#[cfg(any(
not(any(target_os = "macos", target_os = "windows")),
all(target_os = "macos", feature = "macos-blade")
))]
#[allow(dead_code)]
fn check_wgsl_shaders() {
use std::path::PathBuf;
use std::process;
@@ -243,215 +243,3 @@ mod macos {
}
}
}
#[cfg(target_os = "windows")]
mod windows {
use std::{
fs,
io::Write,
path::{Path, PathBuf},
process::{self, Command},
};
pub(super) fn build() {
// Compile HLSL shaders
#[cfg(not(debug_assertions))]
compile_shaders();
// Embed the Windows manifest and resource file
#[cfg(feature = "windows-manifest")]
embed_resource();
}
#[cfg(feature = "windows-manifest")]
fn embed_resource() {
let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml");
let rc_file = std::path::Path::new("resources/windows/gpui.rc");
println!("cargo:rerun-if-changed={}", manifest.display());
println!("cargo:rerun-if-changed={}", rc_file.display());
embed_resource::compile(rc_file, embed_resource::NONE)
.manifest_required()
.unwrap();
}
/// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler.
fn compile_shaders() {
let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
.join("src/platform/windows/shaders.hlsl");
let out_dir = std::env::var("OUT_DIR").unwrap();
println!("cargo:rerun-if-changed={}", shader_path.display());
// Check if fxc.exe is available
let fxc_path = find_fxc_compiler();
// Define all modules
let modules = [
"quad",
"shadow",
"path_rasterization",
"path_sprite",
"underline",
"monochrome_sprite",
"polychrome_sprite",
];
let rust_binding_path = format!("{}/shaders_bytes.rs", out_dir);
if Path::new(&rust_binding_path).exists() {
fs::remove_file(&rust_binding_path)
.expect("Failed to remove existing Rust binding file");
}
for module in modules {
compile_shader_for_module(
module,
&out_dir,
&fxc_path,
shader_path.to_str().unwrap(),
&rust_binding_path,
);
}
{
let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
.join("src/platform/windows/color_text_raster.hlsl");
compile_shader_for_module(
"emoji_rasterization",
&out_dir,
&fxc_path,
shader_path.to_str().unwrap(),
&rust_binding_path,
);
}
}
/// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler.
fn find_fxc_compiler() -> String {
// Check environment variable
if let Ok(path) = std::env::var("GPUI_FXC_PATH") {
if Path::new(&path).exists() {
return path;
}
}
// Try to find in PATH
// NOTE: This has to be `where.exe` on Windows, not `where`, it must be ended with `.exe`
if let Ok(output) = std::process::Command::new("where.exe")
.arg("fxc.exe")
.output()
{
if output.status.success() {
let path = String::from_utf8_lossy(&output.stdout);
return path.trim().to_string();
}
}
// Check the default path
if Path::new(r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe")
.exists()
{
return r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe"
.to_string();
}
panic!("Failed to find fxc.exe");
}
fn compile_shader_for_module(
module: &str,
out_dir: &str,
fxc_path: &str,
shader_path: &str,
rust_binding_path: &str,
) {
// Compile vertex shader
let output_file = format!("{}/{}_vs.h", out_dir, module);
let const_name = format!("{}_VERTEX_BYTES", module.to_uppercase());
compile_shader_impl(
fxc_path,
&format!("{module}_vertex"),
&output_file,
&const_name,
shader_path,
"vs_4_1",
);
generate_rust_binding(&const_name, &output_file, &rust_binding_path);
// Compile fragment shader
let output_file = format!("{}/{}_ps.h", out_dir, module);
let const_name = format!("{}_FRAGMENT_BYTES", module.to_uppercase());
compile_shader_impl(
fxc_path,
&format!("{module}_fragment"),
&output_file,
&const_name,
shader_path,
"ps_4_1",
);
generate_rust_binding(&const_name, &output_file, &rust_binding_path);
}
fn compile_shader_impl(
fxc_path: &str,
entry_point: &str,
output_path: &str,
var_name: &str,
shader_path: &str,
target: &str,
) {
let output = Command::new(fxc_path)
.args([
"/T",
target,
"/E",
entry_point,
"/Fh",
output_path,
"/Vn",
var_name,
"/O3",
shader_path,
])
.output();
match output {
Ok(result) => {
if result.status.success() {
return;
}
eprintln!(
"Shader compilation failed for {}:\n{}",
entry_point,
String::from_utf8_lossy(&result.stderr)
);
process::exit(1);
}
Err(e) => {
eprintln!("Failed to run fxc for {}: {}", entry_point, e);
process::exit(1);
}
}
}
fn generate_rust_binding(const_name: &str, head_file: &str, output_path: &str) {
let header_content = fs::read_to_string(head_file).expect("Failed to read header file");
let const_definition = {
let global_var_start = header_content.find("const BYTE").unwrap();
let global_var = &header_content[global_var_start..];
let equal = global_var.find('=').unwrap();
global_var[equal + 1..].trim()
};
let rust_binding = format!(
"const {}: &[u8] = &{}\n",
const_name,
const_definition.replace('{', "[").replace('}', "]")
);
let mut options = fs::OpenOptions::new()
.create(true)
.append(true)
.open(output_path)
.expect("Failed to open Rust binding file");
options
.write_all(rust_binding.as_bytes())
.expect("Failed to write Rust binding file");
}
}

View File

@@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid {
"χ", "ψ", "", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р",
"У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*",
"_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "", "µ",
"", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎",
"", "<=", "!=", "==", "--", "++", "=>", "->",
];
let columns = 11;

View File

@@ -35,7 +35,6 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) {
/// An RGBA color
#[derive(PartialEq, Clone, Copy, Default)]
#[repr(C)]
pub struct Rgba {
/// The red component of the color, in the range 0.0 to 1.0
pub r: f32,

View File

@@ -13,7 +13,8 @@ mod mac;
any(target_os = "linux", target_os = "freebsd"),
any(feature = "x11", feature = "wayland")
),
all(target_os = "macos", feature = "macos-blade")
target_os = "windows",
feature = "macos-blade"
))]
mod blade;
@@ -447,8 +448,6 @@ impl Tiling {
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
pub(crate) struct RequestFrameOptions {
pub(crate) require_presentation: bool,
/// Force refresh of all rendering states when true
pub(crate) force_render: bool,
}
pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {

View File

@@ -1004,13 +1004,12 @@ impl X11Client {
let mut keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code);
let keysym = state.xkb.key_get_one_sym(code);
if keysym.is_modifier_key() {
return Some(());
}
// should be called after key_get_one_sym
state.xkb.update_key(code, xkbc::KeyDirection::Down);
if keysym.is_modifier_key() {
return Some(());
}
if let Some(mut compose_state) = state.compose_state.take() {
compose_state.feed(keysym);
match compose_state.status() {
@@ -1068,13 +1067,12 @@ impl X11Client {
let keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code);
let keysym = state.xkb.key_get_one_sym(code);
if keysym.is_modifier_key() {
return Some(());
}
// should be called after key_get_one_sym
state.xkb.update_key(code, xkbc::KeyDirection::Up);
if keysym.is_modifier_key() {
return Some(());
}
keystroke
};
drop(state);
@@ -1795,7 +1793,6 @@ impl X11ClientState {
drop(state);
window.refresh(RequestFrameOptions {
require_presentation: expose_event_received,
force_render: false,
});
}
xcb_connection

View File

@@ -1,8 +1,6 @@
mod clipboard;
mod destination_list;
mod direct_write;
mod directx_atlas;
mod directx_renderer;
mod dispatcher;
mod display;
mod events;
@@ -16,8 +14,6 @@ mod wrapper;
pub(crate) use clipboard::*;
pub(crate) use destination_list::*;
pub(crate) use direct_write::*;
pub(crate) use directx_atlas::*;
pub(crate) use directx_renderer::*;
pub(crate) use dispatcher::*;
pub(crate) use display::*;
pub(crate) use events::*;

View File

@@ -1,39 +0,0 @@
struct RasterVertexOutput {
float4 position : SV_Position;
float2 texcoord : TEXCOORD0;
};
RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID)
{
RasterVertexOutput output;
output.texcoord = float2((vertexID << 1) & 2, vertexID & 2);
output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f);
output.position.y = -output.position.y;
return output;
}
struct PixelInput {
float4 position: SV_Position;
float2 texcoord : TEXCOORD0;
};
struct Bounds {
int2 origin;
int2 size;
};
Texture2D<float4> t_layer : register(t0);
SamplerState s_layer : register(s0);
cbuffer GlyphLayerTextureParams : register(b0) {
Bounds bounds;
float4 run_color;
};
float4 emoji_rasterization_fragment(PixelInput input): SV_Target {
float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb;
float alpha = (sampled.r + sampled.g + sampled.b) / 3;
return float4(run_color.rgb, alpha);
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,309 +0,0 @@
use collections::FxHashMap;
use etagere::BucketedAtlasAllocator;
use parking_lot::Mutex;
use windows::Win32::Graphics::{
Direct3D11::{
D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC,
D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView,
ID3D11Texture2D,
},
Dxgi::Common::*,
};
use crate::{
AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, DevicePixels, PlatformAtlas,
Point, Size, platform::AtlasTextureList,
};
pub(crate) struct DirectXAtlas(Mutex<DirectXAtlasState>);
struct DirectXAtlasState {
device: ID3D11Device,
device_context: ID3D11DeviceContext,
monochrome_textures: AtlasTextureList<DirectXAtlasTexture>,
polychrome_textures: AtlasTextureList<DirectXAtlasTexture>,
tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
}
struct DirectXAtlasTexture {
id: AtlasTextureId,
bytes_per_pixel: u32,
allocator: BucketedAtlasAllocator,
texture: ID3D11Texture2D,
view: [Option<ID3D11ShaderResourceView>; 1],
live_atlas_keys: u32,
}
impl DirectXAtlas {
pub(crate) fn new(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Self {
DirectXAtlas(Mutex::new(DirectXAtlasState {
device: device.clone(),
device_context: device_context.clone(),
monochrome_textures: Default::default(),
polychrome_textures: Default::default(),
tiles_by_key: Default::default(),
}))
}
pub(crate) fn get_texture_view(
&self,
id: AtlasTextureId,
) -> [Option<ID3D11ShaderResourceView>; 1] {
let lock = self.0.lock();
let tex = lock.texture(id);
tex.view.clone()
}
pub(crate) fn handle_device_lost(
&self,
device: &ID3D11Device,
device_context: &ID3D11DeviceContext,
) {
let mut lock = self.0.lock();
lock.device = device.clone();
lock.device_context = device_context.clone();
lock.monochrome_textures = AtlasTextureList::default();
lock.polychrome_textures = AtlasTextureList::default();
lock.tiles_by_key.clear();
}
}
impl PlatformAtlas for DirectXAtlas {
fn get_or_insert_with<'a>(
&self,
key: &AtlasKey,
build: &mut dyn FnMut() -> anyhow::Result<
Option<(Size<DevicePixels>, std::borrow::Cow<'a, [u8]>)>,
>,
) -> anyhow::Result<Option<AtlasTile>> {
let mut lock = self.0.lock();
if let Some(tile) = lock.tiles_by_key.get(key) {
Ok(Some(tile.clone()))
} else {
let Some((size, bytes)) = build()? else {
return Ok(None);
};
let tile = lock
.allocate(size, key.texture_kind())
.ok_or_else(|| anyhow::anyhow!("failed to allocate"))?;
let texture = lock.texture(tile.texture_id);
texture.upload(&lock.device_context, tile.bounds, &bytes);
lock.tiles_by_key.insert(key.clone(), tile.clone());
Ok(Some(tile))
}
}
fn remove(&self, key: &AtlasKey) {
let mut lock = self.0.lock();
let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else {
return;
};
let textures = match id.kind {
AtlasTextureKind::Monochrome => &mut lock.monochrome_textures,
AtlasTextureKind::Polychrome => &mut lock.polychrome_textures,
};
let Some(texture_slot) = textures.textures.get_mut(id.index as usize) else {
return;
};
if let Some(mut texture) = texture_slot.take() {
texture.decrement_ref_count();
if texture.is_unreferenced() {
textures.free_list.push(texture.id.index as usize);
lock.tiles_by_key.remove(key);
} else {
*texture_slot = Some(texture);
}
}
}
}
impl DirectXAtlasState {
fn allocate(
&mut self,
size: Size<DevicePixels>,
texture_kind: AtlasTextureKind,
) -> Option<AtlasTile> {
{
let textures = match texture_kind {
AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
};
if let Some(tile) = textures
.iter_mut()
.rev()
.find_map(|texture| texture.allocate(size))
{
return Some(tile);
}
}
let texture = self.push_texture(size, texture_kind)?;
texture.allocate(size)
}
fn push_texture(
&mut self,
min_size: Size<DevicePixels>,
kind: AtlasTextureKind,
) -> Option<&mut DirectXAtlasTexture> {
const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size {
width: DevicePixels(1024),
height: DevicePixels(1024),
};
// Max texture size for DirectX. See:
// https://learn.microsoft.com/en-us/windows/win32/direct3d11/overviews-direct3d-11-resources-limits
const MAX_ATLAS_SIZE: Size<DevicePixels> = Size {
width: DevicePixels(16384),
height: DevicePixels(16384),
};
let size = min_size.min(&MAX_ATLAS_SIZE).max(&DEFAULT_ATLAS_SIZE);
let pixel_format;
let bind_flag;
let bytes_per_pixel;
match kind {
AtlasTextureKind::Monochrome => {
pixel_format = DXGI_FORMAT_R8_UNORM;
bind_flag = D3D11_BIND_SHADER_RESOURCE;
bytes_per_pixel = 1;
}
AtlasTextureKind::Polychrome => {
pixel_format = DXGI_FORMAT_B8G8R8A8_UNORM;
bind_flag = D3D11_BIND_SHADER_RESOURCE;
bytes_per_pixel = 4;
}
}
let texture_desc = D3D11_TEXTURE2D_DESC {
Width: size.width.0 as u32,
Height: size.height.0 as u32,
MipLevels: 1,
ArraySize: 1,
Format: pixel_format,
SampleDesc: DXGI_SAMPLE_DESC {
Count: 1,
Quality: 0,
},
Usage: D3D11_USAGE_DEFAULT,
BindFlags: bind_flag.0 as u32,
CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32,
MiscFlags: 0,
};
let mut texture: Option<ID3D11Texture2D> = None;
unsafe {
// This only returns None if the device is lost, which we will recreate later.
// So it's ok to return None here.
self.device
.CreateTexture2D(&texture_desc, None, Some(&mut texture))
.ok()?;
}
let texture = texture.unwrap();
let texture_list = match kind {
AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
};
let index = texture_list.free_list.pop();
let view = unsafe {
let mut view = None;
self.device
.CreateShaderResourceView(&texture, None, Some(&mut view))
.ok()?;
[view]
};
let atlas_texture = DirectXAtlasTexture {
id: AtlasTextureId {
index: index.unwrap_or(texture_list.textures.len()) as u32,
kind,
},
bytes_per_pixel,
allocator: etagere::BucketedAtlasAllocator::new(size.into()),
texture,
view,
live_atlas_keys: 0,
};
if let Some(ix) = index {
texture_list.textures[ix] = Some(atlas_texture);
texture_list.textures.get_mut(ix).unwrap().as_mut()
} else {
texture_list.textures.push(Some(atlas_texture));
texture_list.textures.last_mut().unwrap().as_mut()
}
}
fn texture(&self, id: AtlasTextureId) -> &DirectXAtlasTexture {
let textures = match id.kind {
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
};
textures[id.index as usize].as_ref().unwrap()
}
}
impl DirectXAtlasTexture {
fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> {
let allocation = self.allocator.allocate(size.into())?;
let tile = AtlasTile {
texture_id: self.id,
tile_id: allocation.id.into(),
bounds: Bounds {
origin: allocation.rectangle.min.into(),
size,
},
padding: 0,
};
self.live_atlas_keys += 1;
Some(tile)
}
fn upload(
&self,
device_context: &ID3D11DeviceContext,
bounds: Bounds<DevicePixels>,
bytes: &[u8],
) {
unsafe {
device_context.UpdateSubresource(
&self.texture,
0,
Some(&D3D11_BOX {
left: bounds.left().0 as u32,
top: bounds.top().0 as u32,
front: 0,
right: bounds.right().0 as u32,
bottom: bounds.bottom().0 as u32,
back: 1,
}),
bytes.as_ptr() as _,
bounds.size.width.to_bytes(self.bytes_per_pixel as u8),
0,
);
}
}
fn decrement_ref_count(&mut self) {
self.live_atlas_keys -= 1;
}
fn is_unreferenced(&mut self) -> bool {
self.live_atlas_keys == 0
}
}
impl From<Size<DevicePixels>> for etagere::Size {
fn from(size: Size<DevicePixels>) -> Self {
etagere::Size::new(size.width.into(), size.height.into())
}
}
impl From<etagere::Point> for Point<DevicePixels> {
fn from(value: etagere::Point) -> Self {
Point {
x: DevicePixels::from(value.x),
y: DevicePixels::from(value.y),
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,6 @@ pub(crate) const WM_GPUI_CURSOR_STYLE_CHANGED: u32 = WM_USER + 1;
pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2;
pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3;
pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4;
pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5;
const SIZE_MOVE_LOOP_TIMER_ID: usize = 1;
const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1;
@@ -38,7 +37,6 @@ pub(crate) fn handle_msg(
let handled = match msg {
WM_ACTIVATE => handle_activate_msg(wparam, state_ptr),
WM_CREATE => handle_create_msg(handle, state_ptr),
WM_DEVICECHANGE => handle_device_change_msg(handle, wparam, state_ptr),
WM_MOVE => handle_move_msg(handle, lparam, state_ptr),
WM_SIZE => handle_size_msg(wparam, lparam, state_ptr),
WM_GETMINMAXINFO => handle_get_min_max_info_msg(lparam, state_ptr),
@@ -50,7 +48,7 @@ pub(crate) fn handle_msg(
WM_DISPLAYCHANGE => handle_display_change_msg(handle, state_ptr),
WM_NCHITTEST => handle_hit_test_msg(handle, msg, wparam, lparam, state_ptr),
WM_PAINT => handle_paint_msg(handle, state_ptr),
WM_CLOSE => handle_close_msg(state_ptr),
WM_CLOSE => handle_close_msg(handle, state_ptr),
WM_DESTROY => handle_destroy_msg(handle, state_ptr),
WM_MOUSEMOVE => handle_mouse_move_msg(handle, lparam, wparam, state_ptr),
WM_MOUSELEAVE | WM_NCMOUSELEAVE => handle_mouse_leave_msg(state_ptr),
@@ -98,7 +96,6 @@ pub(crate) fn handle_msg(
WM_SETTINGCHANGE => handle_system_settings_changed(handle, wparam, lparam, state_ptr),
WM_INPUTLANGCHANGE => handle_input_language_changed(lparam, state_ptr),
WM_GPUI_CURSOR_STYLE_CHANGED => handle_cursor_changed(lparam, state_ptr),
WM_GPUI_FORCE_UPDATE_WINDOW => draw_window(handle, true, state_ptr),
_ => None,
};
if let Some(n) = handled {
@@ -184,9 +181,11 @@ fn handle_size_msg(
let new_size = size(DevicePixels(width), DevicePixels(height));
let scale_factor = lock.scale_factor;
if lock.restore_from_minimized.is_some() {
lock.renderer
.update_drawable_size_even_if_unchanged(new_size);
lock.callbacks.request_frame = lock.restore_from_minimized.take();
} else {
lock.renderer.resize(new_size).log_err();
lock.renderer.update_drawable_size(new_size);
}
let new_size = new_size.to_pixels(scale_factor);
lock.logical_size = new_size;
@@ -239,14 +238,40 @@ fn handle_timer_msg(
}
fn handle_paint_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
draw_window(handle, false, state_ptr)
let mut lock = state_ptr.state.borrow_mut();
if let Some(mut request_frame) = lock.callbacks.request_frame.take() {
drop(lock);
request_frame(Default::default());
state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame);
}
unsafe { ValidateRect(Some(handle), None).ok().log_err() };
Some(0)
}
fn handle_close_msg(state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
let mut callback = state_ptr.state.borrow_mut().callbacks.should_close.take()?;
let should_close = callback();
state_ptr.state.borrow_mut().callbacks.should_close = Some(callback);
if should_close { None } else { Some(0) }
fn handle_close_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
let mut lock = state_ptr.state.borrow_mut();
let output = if let Some(mut callback) = lock.callbacks.should_close.take() {
drop(lock);
let should_close = callback();
state_ptr.state.borrow_mut().callbacks.should_close = Some(callback);
if should_close { None } else { Some(0) }
} else {
None
};
// Workaround as window close animation is not played with `WS_EX_LAYERED` enabled.
if output.is_none() {
unsafe {
let current_style = get_window_long(handle, GWL_EXSTYLE);
set_window_long(
handle,
GWL_EXSTYLE,
current_style & !WS_EX_LAYERED.0 as isize,
);
}
}
output
}
fn handle_destroy_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
@@ -1198,53 +1223,6 @@ fn handle_input_language_changed(
Some(0)
}
fn handle_device_change_msg(
handle: HWND,
wparam: WPARAM,
state_ptr: Rc<WindowsWindowStatePtr>,
) -> Option<isize> {
if wparam.0 == DBT_DEVNODES_CHANGED as usize {
// The reason for sending this message is to actually trigger a redraw of the window.
unsafe {
PostMessageW(
Some(handle),
WM_GPUI_FORCE_UPDATE_WINDOW,
WPARAM(0),
LPARAM(0),
)
.log_err();
}
// If the GPU device is lost, this redraw will take care of recreating the device context.
// The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after
// the device context has been recreated.
draw_window(handle, true, state_ptr)
} else {
// Other device change messages are not handled.
None
}
}
#[inline]
fn draw_window(
handle: HWND,
force_render: bool,
state_ptr: Rc<WindowsWindowStatePtr>,
) -> Option<isize> {
let mut request_frame = state_ptr
.state
.borrow_mut()
.callbacks
.request_frame
.take()?;
request_frame(RequestFrameOptions {
require_presentation: false,
force_render,
});
state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame);
unsafe { ValidateRect(Some(handle), None).ok().log_err() };
Some(0)
}
#[inline]
fn parse_char_message(wparam: WPARAM, state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<String> {
let code_point = wparam.loword();

View File

@@ -28,12 +28,13 @@ use windows::{
core::*,
};
use crate::*;
use crate::{platform::blade::BladeContext, *};
pub(crate) struct WindowsPlatform {
state: RefCell<WindowsPlatformState>,
raw_window_handles: RwLock<SmallVec<[HWND; 4]>>,
// The below members will never change throughout the entire lifecycle of the app.
gpu_context: BladeContext,
icon: HICON,
main_receiver: flume::Receiver<Runnable>,
background_executor: BackgroundExecutor,
@@ -44,7 +45,6 @@ pub(crate) struct WindowsPlatform {
drop_target_helper: IDropTargetHelper,
validation_number: usize,
main_thread_id_win32: u32,
disable_direct_composition: bool,
}
pub(crate) struct WindowsPlatformState {
@@ -94,18 +94,14 @@ impl WindowsPlatform {
main_thread_id_win32,
validation_number,
));
let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION)
.is_ok_and(|value| value == "true" || value == "1");
let background_executor = BackgroundExecutor::new(dispatcher.clone());
let foreground_executor = ForegroundExecutor::new(dispatcher);
let directx_devices = DirectXDevices::new(disable_direct_composition)
.context("Unable to init directx devices.")?;
let bitmap_factory = ManuallyDrop::new(unsafe {
CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER)
.context("Error creating bitmap factory.")?
});
let text_system = Arc::new(
DirectWriteTextSystem::new(&directx_devices, &bitmap_factory)
DirectWriteTextSystem::new(&bitmap_factory)
.context("Error creating DirectWriteTextSystem")?,
);
let drop_target_helper: IDropTargetHelper = unsafe {
@@ -115,17 +111,18 @@ impl WindowsPlatform {
let icon = load_icon().unwrap_or_default();
let state = RefCell::new(WindowsPlatformState::new());
let raw_window_handles = RwLock::new(SmallVec::new());
let gpu_context = BladeContext::new().context("Unable to init GPU context")?;
let windows_version = WindowsVersion::new().context("Error retrieve windows version")?;
Ok(Self {
state,
raw_window_handles,
gpu_context,
icon,
main_receiver,
background_executor,
foreground_executor,
text_system,
disable_direct_composition,
windows_version,
bitmap_factory,
drop_target_helper,
@@ -190,7 +187,6 @@ impl WindowsPlatform {
validation_number: self.validation_number,
main_receiver: self.main_receiver.clone(),
main_thread_id_win32: self.main_thread_id_win32,
disable_direct_composition: self.disable_direct_composition,
}
}
@@ -347,11 +343,27 @@ impl Platform for WindowsPlatform {
fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) {
on_finish_launching();
loop {
if self.handle_events() {
break;
let vsync_event = unsafe { Owned::new(CreateEventW(None, false, false, None).unwrap()) };
begin_vsync(*vsync_event);
'a: loop {
let wait_result = unsafe {
MsgWaitForMultipleObjects(Some(&[*vsync_event]), false, INFINITE, QS_ALLINPUT)
};
match wait_result {
// compositor clock ticked so we should draw a frame
WAIT_EVENT(0) => self.redraw_all(),
// Windows thread messages are posted
WAIT_EVENT(1) => {
if self.handle_events() {
break 'a;
}
}
_ => {
log::error!("Something went wrong while waiting {:?}", wait_result);
break;
}
}
self.redraw_all();
}
if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit {
@@ -443,7 +455,12 @@ impl Platform for WindowsPlatform {
handle: AnyWindowHandle,
options: WindowParams,
) -> Result<Box<dyn PlatformWindow>> {
let window = WindowsWindow::new(handle, options, self.generate_creation_info())?;
let window = WindowsWindow::new(
handle,
options,
self.generate_creation_info(),
&self.gpu_context,
)?;
let handle = window.get_raw_handle();
self.raw_window_handles.write().push(handle);
@@ -722,7 +739,6 @@ pub(crate) struct WindowCreationInfo {
pub(crate) validation_number: usize,
pub(crate) main_receiver: flume::Receiver<Runnable>,
pub(crate) main_thread_id_win32: u32,
pub(crate) disable_direct_composition: bool,
}
fn open_target(target: &str) {
@@ -830,6 +846,16 @@ fn file_save_dialog(directory: PathBuf, window: Option<HWND>) -> Result<Option<P
Ok(Some(PathBuf::from(file_path_string)))
}
fn begin_vsync(vsync_event: HANDLE) {
let event: SafeHandle = vsync_event.into();
std::thread::spawn(move || unsafe {
loop {
windows::Win32::Graphics::Dwm::DwmFlush().log_err();
SetEvent(*event).log_err();
}
});
}
fn load_icon() -> Result<HICON> {
let module = unsafe { GetModuleHandleW(None).context("unable to get module handle")? };
let handle = unsafe {

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,7 @@ use windows::{
core::*,
};
use crate::platform::blade::{BladeContext, BladeRenderer};
use crate::*;
pub(crate) struct WindowsWindow(pub Rc<WindowsWindowStatePtr>);
@@ -48,7 +49,7 @@ pub struct WindowsWindowState {
pub system_key_handled: bool,
pub hovered: bool,
pub renderer: DirectXRenderer,
pub renderer: BladeRenderer,
pub click_state: ClickState,
pub system_settings: WindowsSystemSettings,
@@ -79,12 +80,13 @@ pub(crate) struct WindowsWindowStatePtr {
impl WindowsWindowState {
fn new(
hwnd: HWND,
transparent: bool,
cs: &CREATESTRUCTW,
current_cursor: Option<HCURSOR>,
display: WindowsDisplay,
gpu_context: &BladeContext,
min_size: Option<Size<Pixels>>,
appearance: WindowAppearance,
disable_direct_composition: bool,
) -> Result<Self> {
let scale_factor = {
let monitor_dpi = unsafe { GetDpiForWindow(hwnd) } as f32;
@@ -101,8 +103,7 @@ impl WindowsWindowState {
};
let border_offset = WindowBorderOffset::default();
let restore_from_minimized = None;
let renderer = DirectXRenderer::new(hwnd, disable_direct_composition)
.context("Creating DirectX renderer")?;
let renderer = windows_renderer::init(gpu_context, hwnd, transparent)?;
let callbacks = Callbacks::default();
let input_handler = None;
let pending_surrogate = None;
@@ -205,12 +206,13 @@ impl WindowsWindowStatePtr {
fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result<Rc<Self>> {
let state = RefCell::new(WindowsWindowState::new(
hwnd,
context.transparent,
cs,
context.current_cursor,
context.display,
context.gpu_context,
context.min_size,
context.appearance,
context.disable_direct_composition,
)?);
Ok(Rc::new_cyclic(|this| Self {
@@ -327,11 +329,12 @@ pub(crate) struct Callbacks {
pub(crate) appearance_changed: Option<Box<dyn FnMut()>>,
}
struct WindowCreateContext {
struct WindowCreateContext<'a> {
inner: Option<Result<Rc<WindowsWindowStatePtr>>>,
handle: AnyWindowHandle,
hide_title_bar: bool,
display: WindowsDisplay,
transparent: bool,
is_movable: bool,
min_size: Option<Size<Pixels>>,
executor: ForegroundExecutor,
@@ -340,9 +343,9 @@ struct WindowCreateContext {
drop_target_helper: IDropTargetHelper,
validation_number: usize,
main_receiver: flume::Receiver<Runnable>,
gpu_context: &'a BladeContext,
main_thread_id_win32: u32,
appearance: WindowAppearance,
disable_direct_composition: bool,
}
impl WindowsWindow {
@@ -350,6 +353,7 @@ impl WindowsWindow {
handle: AnyWindowHandle,
params: WindowParams,
creation_info: WindowCreationInfo,
gpu_context: &BladeContext,
) -> Result<Self> {
let WindowCreationInfo {
icon,
@@ -360,7 +364,6 @@ impl WindowsWindow {
validation_number,
main_receiver,
main_thread_id_win32,
disable_direct_composition,
} = creation_info;
let classname = register_wnd_class(icon);
let hide_title_bar = params
@@ -376,18 +379,14 @@ impl WindowsWindow {
.map(|title| title.as_ref())
.unwrap_or(""),
);
let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp {
(WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0))
let (dwexstyle, mut dwstyle) = if params.kind == WindowKind::PopUp {
(WS_EX_TOOLWINDOW | WS_EX_LAYERED, WINDOW_STYLE(0x0))
} else {
(
WS_EX_APPWINDOW,
WS_EX_APPWINDOW | WS_EX_LAYERED,
WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX,
)
};
if !disable_direct_composition {
dwexstyle |= WS_EX_NOREDIRECTIONBITMAP;
}
let hinstance = get_module_handle();
let display = if let Some(display_id) = params.display_id {
@@ -402,6 +401,7 @@ impl WindowsWindow {
handle,
hide_title_bar,
display,
transparent: true,
is_movable: params.is_movable,
min_size: params.window_min_size,
executor,
@@ -410,9 +410,9 @@ impl WindowsWindow {
drop_target_helper,
validation_number,
main_receiver,
gpu_context,
main_thread_id_win32,
appearance,
disable_direct_composition,
};
let lpparam = Some(&context as *const _ as *const _);
let creation_result = unsafe {
@@ -453,6 +453,14 @@ impl WindowsWindow {
state: WindowOpenState::Windowed,
});
}
// The render pipeline will perform compositing on the GPU when the
// swapchain is configured correctly (see downstream of
// update_transparency).
// The following configuration is a one-time setup to ensure that the
// window is going to be composited with per-pixel alpha, but the render
// pipeline is responsible for effectively calling UpdateLayeredWindow
// at the appropriate time.
unsafe { SetLayeredWindowAttributes(hwnd, COLORREF(0), 255, LWA_ALPHA)? };
Ok(Self(state_ptr))
}
@@ -477,6 +485,7 @@ impl rwh::HasDisplayHandle for WindowsWindow {
impl Drop for WindowsWindow {
fn drop(&mut self) {
self.0.state.borrow_mut().renderer.destroy();
// clone this `Rc` to prevent early release of the pointer
let this = self.0.clone();
self.0
@@ -696,21 +705,24 @@ impl PlatformWindow for WindowsWindow {
}
fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) {
let hwnd = self.0.hwnd;
let mut window_state = self.0.state.borrow_mut();
window_state
.renderer
.update_transparency(background_appearance != WindowBackgroundAppearance::Opaque);
match background_appearance {
WindowBackgroundAppearance::Opaque => {
// ACCENT_DISABLED
set_window_composition_attribute(hwnd, None, 0);
set_window_composition_attribute(window_state.hwnd, None, 0);
}
WindowBackgroundAppearance::Transparent => {
// Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background
set_window_composition_attribute(hwnd, None, 2);
set_window_composition_attribute(window_state.hwnd, None, 2);
}
WindowBackgroundAppearance::Blurred => {
// Enable acrylic blur
// ACCENT_ENABLE_ACRYLICBLURBEHIND
set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4);
set_window_composition_attribute(window_state.hwnd, Some((0, 0, 0, 0)), 4);
}
}
}
@@ -782,11 +794,11 @@ impl PlatformWindow for WindowsWindow {
}
fn draw(&self, scene: &Scene) {
self.0.state.borrow_mut().renderer.draw(scene).log_err();
self.0.state.borrow_mut().renderer.draw(scene)
}
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> {
self.0.state.borrow().renderer.sprite_atlas()
self.0.state.borrow().renderer.sprite_atlas().clone()
}
fn get_raw_handle(&self) -> HWND {
@@ -794,11 +806,11 @@ impl PlatformWindow for WindowsWindow {
}
fn gpu_specs(&self) -> Option<GpuSpecs> {
self.0.state.borrow().renderer.gpu_specs().log_err()
Some(self.0.state.borrow().renderer.gpu_specs())
}
fn update_ime_position(&self, _bounds: Bounds<ScaledPixels>) {
// There is no such thing on Windows.
// todo(windows)
}
}
@@ -1294,6 +1306,52 @@ fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32
}
}
mod windows_renderer {
use crate::platform::blade::{BladeContext, BladeRenderer, BladeSurfaceConfig};
use raw_window_handle as rwh;
use std::num::NonZeroIsize;
use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::GWLP_HINSTANCE};
use crate::{get_window_long, show_error};
pub(super) fn init(
context: &BladeContext,
hwnd: HWND,
transparent: bool,
) -> anyhow::Result<BladeRenderer> {
let raw = RawWindow { hwnd };
let config = BladeSurfaceConfig {
size: Default::default(),
transparent,
};
BladeRenderer::new(context, &raw, config)
.inspect_err(|err| show_error("Failed to initialize BladeRenderer", err.to_string()))
}
struct RawWindow {
hwnd: HWND,
}
impl rwh::HasWindowHandle for RawWindow {
fn window_handle(&self) -> Result<rwh::WindowHandle<'_>, rwh::HandleError> {
Ok(unsafe {
let hwnd = NonZeroIsize::new_unchecked(self.hwnd.0 as isize);
let mut handle = rwh::Win32WindowHandle::new(hwnd);
let hinstance = get_window_long(self.hwnd, GWLP_HINSTANCE);
handle.hinstance = NonZeroIsize::new(hinstance);
rwh::WindowHandle::borrow_raw(handle.into())
})
}
}
impl rwh::HasDisplayHandle for RawWindow {
fn display_handle(&self) -> Result<rwh::DisplayHandle<'_>, rwh::HandleError> {
let handle = rwh::WindowsDisplayHandle::new();
Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) })
}
}
}
#[cfg(test)]
mod tests {
use super::ClickState;

View File

@@ -1020,7 +1020,7 @@ impl Window {
|| (active.get()
&& last_input_timestamp.get().elapsed() < Duration::from_secs(1));
if invalidator.is_dirty() || request_frame_options.force_render {
if invalidator.is_dirty() {
measure("frame duration", || {
handle
.update(&mut cx, |_, window, cx| {

View File

@@ -236,22 +236,6 @@ impl HttpClientWithUrl {
)?)
}
/// Builds a Zed Cloud URL using the given path.
pub fn build_zed_cloud_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
let base_url = self.base_url();
let base_api_url = match base_url.as_ref() {
"https://zed.dev" => "https://cloud.zed.dev",
"https://staging.zed.dev" => "https://cloud.zed.dev",
"http://localhost:3000" => "http://localhost:8787",
other => other,
};
Ok(Url::parse_with_params(
&format!("{}{}", base_api_url, path),
query,
)?)
}
/// Builds a Zed LLM URL using the given path.
pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
let base_url = self.base_url();

View File

@@ -15,7 +15,6 @@ doctest = false
[dependencies]
anyhow.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
copilot.workspace = true
editor.workspace = true
feature_flags.workspace = true
@@ -33,6 +32,7 @@ ui.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_llm_client.workspace = true
zeta.workspace = true
[dev-dependencies]

View File

@@ -1,6 +1,5 @@
use anyhow::Result;
use client::{DisableAiSettings, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
use copilot::{Copilot, Status};
use editor::{
Editor, SelectionEffects,
@@ -35,6 +34,7 @@ use workspace::{
notifications::NotificationId,
};
use zed_actions::OpenBrowser;
use zed_llm_client::UsageLimit;
use zeta::RateCompletions;
actions!(

View File

@@ -20,7 +20,6 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
base64.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -38,6 +37,7 @@ telemetry_events.workspace = true
thiserror.workspace = true
util.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }

View File

@@ -11,7 +11,6 @@ pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -27,6 +26,7 @@ use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{CompletionMode, CompletionRequestStatus};
pub use crate::model::*;
pub use crate::rate_limiter::*;

View File

@@ -1,9 +1,10 @@
use std::io::{Cursor, Write};
use std::sync::Arc;
use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
use anyhow::Result;
use base64::write::EncoderWriter;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use gpui::{
App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
point, px, size,
@@ -11,9 +12,7 @@ use gpui::{
use image::codecs::png::PngEncoder;
use serde::{Deserialize, Serialize};
use util::ResultExt;
use crate::role::Role;
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
use zed_llm_client::{CompletionIntent, CompletionMode};
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct LanguageModelImage {

View File

@@ -15,18 +15,20 @@ path = "src/language_models.rs"
ai_onboarding.workspace = true
anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
aws-config = { workspace = true, features = ["behavior-version-latest"] }
aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] }
aws-credential-types = { workspace = true, features = [
"hardcoded-credentials",
] }
aws_http_client.workspace = true
bedrock.workspace = true
chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
credentials_provider.workspace = true
convert_case.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
editor.workspace = true
futures.workspace = true
@@ -34,7 +36,6 @@ google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
gpui_tokio.workspace = true
http_client.workspace = true
language.workspace = true
language_model.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
@@ -43,6 +44,8 @@ mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
vercel = { workspace = true, features = ["schemars"] }
x_ai = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
proto.workspace = true
release_channel.workspace = true
@@ -59,9 +62,10 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
ui_input.workspace = true
util.workspace = true
vercel = { workspace = true, features = ["schemars"] }
workspace-hack.workspace = true
x_ai = { workspace = true, features = ["schemars"] }
zed_llm_client.workspace = true
language.workspace = true
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs", tag = "v0.6.0", features = [] }
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View File

@@ -17,6 +17,7 @@ use crate::provider::cloud::CloudLanguageModelProvider;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
use crate::provider::local::LocalLanguageModelProvider;
use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
@@ -150,4 +151,8 @@ fn register_language_model_providers(
);
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
registry.register_provider(
LocalLanguageModelProvider::new(client.http_client(), cx),
cx,
);
}

View File

@@ -5,6 +5,7 @@ pub mod copilot_chat;
pub mod deepseek;
pub mod google;
pub mod lmstudio;
pub mod local;
pub mod mistral;
pub mod ollama;
pub mod open_ai;

View File

@@ -3,13 +3,6 @@ use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
@@ -40,6 +33,13 @@ use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
@@ -120,10 +120,10 @@ pub struct State {
user_store: Entity<UserStore>,
status: client::Status,
accept_terms_of_service_task: Option<Task<Result<()>>>,
models: Vec<Arc<cloud_llm_client::LanguageModel>>,
default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
models: Vec<Arc<zed_llm_client::LanguageModel>>,
default_model: Option<Arc<zed_llm_client::LanguageModel>>,
default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
_fetch_models_task: Task<()>,
_settings_subscription: Subscription,
_llm_token_subscription: Subscription,
@@ -238,8 +238,8 @@ impl State {
// Right now we represent thinking variants of models as separate models on the client,
// so we need to insert variants for any model that supports thinking.
if model.supports_thinking {
models.push(Arc::new(cloud_llm_client::LanguageModel {
id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
models.push(Arc::new(zed_llm_client::LanguageModel {
id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
display_name: format!("{} Thinking", model.display_name),
..model
}));
@@ -328,7 +328,7 @@ impl CloudLanguageModelProvider {
fn create_language_model(
&self,
model: Arc<cloud_llm_client::LanguageModel>,
model: Arc<zed_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
) -> Arc<dyn LanguageModel> {
Arc::new(CloudLanguageModel {
@@ -518,7 +518,7 @@ fn render_accept_terms(
pub struct CloudLanguageModel {
id: LanguageModelId,
model: Arc<cloud_llm_client::LanguageModel>,
model: Arc<zed_llm_client::LanguageModel>,
llm_api_token: LlmApiToken,
client: Arc<Client>,
request_limiter: RateLimiter,
@@ -611,12 +611,12 @@ impl CloudLanguageModel {
.headers()
.get(CURRENT_PLAN_HEADER_NAME)
.and_then(|plan| plan.to_str().ok())
.and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok())
.and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
{
let plan = match plan {
cloud_llm_client::Plan::ZedFree => Plan::Free,
cloud_llm_client::Plan::ZedPro => Plan::ZedPro,
cloud_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
zed_llm_client::Plan::ZedFree => Plan::Free,
zed_llm_client::Plan::ZedPro => Plan::ZedPro,
zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
};
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
@@ -729,7 +729,7 @@ impl LanguageModel for CloudLanguageModel {
}
fn upstream_provider_id(&self) -> LanguageModelProviderId {
use cloud_llm_client::LanguageModelProvider::*;
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
@@ -738,7 +738,7 @@ impl LanguageModel for CloudLanguageModel {
}
fn upstream_provider_name(&self) -> LanguageModelProviderName {
use cloud_llm_client::LanguageModelProvider::*;
use zed_llm_client::LanguageModelProvider::*;
match self.model.provider {
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
@@ -772,11 +772,11 @@ impl LanguageModel for CloudLanguageModel {
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
match self.model.provider {
cloud_llm_client::LanguageModelProvider::Anthropic
| cloud_llm_client::LanguageModelProvider::OpenAi => {
zed_llm_client::LanguageModelProvider::Anthropic
| zed_llm_client::LanguageModelProvider::OpenAi => {
LanguageModelToolSchemaFormat::JsonSchema
}
cloud_llm_client::LanguageModelProvider::Google => {
zed_llm_client::LanguageModelProvider::Google => {
LanguageModelToolSchemaFormat::JsonSchemaSubset
}
}
@@ -795,15 +795,15 @@ impl LanguageModel for CloudLanguageModel {
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
match &self.model.provider {
cloud_llm_client::LanguageModelProvider::Anthropic => {
zed_llm_client::LanguageModelProvider::Anthropic => {
Some(LanguageModelCacheConfiguration {
min_total_token: 2_048,
should_speculate: true,
max_cache_anchors: 4,
})
}
cloud_llm_client::LanguageModelProvider::OpenAi
| cloud_llm_client::LanguageModelProvider::Google => None,
zed_llm_client::LanguageModelProvider::OpenAi
| zed_llm_client::LanguageModelProvider::Google => None,
}
}
@@ -813,17 +813,15 @@ impl LanguageModel for CloudLanguageModel {
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
match self.model.provider {
cloud_llm_client::LanguageModelProvider::Anthropic => {
count_anthropic_tokens(request, cx)
}
cloud_llm_client::LanguageModelProvider::OpenAi => {
zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
zed_llm_client::LanguageModelProvider::OpenAi => {
let model = match open_ai::Model::from_id(&self.model.id.0) {
Ok(model) => model,
Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
};
count_open_ai_tokens(request, model, cx)
}
cloud_llm_client::LanguageModelProvider::Google => {
zed_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let model_id = self.model.id.to_string();
@@ -834,7 +832,7 @@ impl LanguageModel for CloudLanguageModel {
let token = llm_api_token.acquire(&client).await?;
let request_body = CountTokensBody {
provider: cloud_llm_client::LanguageModelProvider::Google,
provider: zed_llm_client::LanguageModelProvider::Google,
model: model_id,
provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
generate_content_request,
@@ -895,7 +893,7 @@ impl LanguageModel for CloudLanguageModel {
let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
let thinking_allowed = request.thinking_allowed;
match self.model.provider {
cloud_llm_client::LanguageModelProvider::Anthropic => {
zed_llm_client::LanguageModelProvider::Anthropic => {
let request = into_anthropic(
request,
self.model.id.to_string(),
@@ -926,7 +924,7 @@ impl LanguageModel for CloudLanguageModel {
prompt_id,
intent,
mode,
provider: cloud_llm_client::LanguageModelProvider::Anthropic,
provider: zed_llm_client::LanguageModelProvider::Anthropic,
model: request.model.clone(),
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
@@ -950,7 +948,7 @@ impl LanguageModel for CloudLanguageModel {
});
async move { Ok(future.await?.boxed()) }.boxed()
}
cloud_llm_client::LanguageModelProvider::OpenAi => {
zed_llm_client::LanguageModelProvider::OpenAi => {
let client = self.client.clone();
let model = match open_ai::Model::from_id(&self.model.id.0) {
Ok(model) => model,
@@ -978,7 +976,7 @@ impl LanguageModel for CloudLanguageModel {
prompt_id,
intent,
mode,
provider: cloud_llm_client::LanguageModelProvider::OpenAi,
provider: zed_llm_client::LanguageModelProvider::OpenAi,
model: request.model.clone(),
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,
@@ -998,7 +996,7 @@ impl LanguageModel for CloudLanguageModel {
});
async move { Ok(future.await?.boxed()) }.boxed()
}
cloud_llm_client::LanguageModelProvider::Google => {
zed_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
let request =
into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
@@ -1018,7 +1016,7 @@ impl LanguageModel for CloudLanguageModel {
prompt_id,
intent,
mode,
provider: cloud_llm_client::LanguageModelProvider::Google,
provider: zed_llm_client::LanguageModelProvider::Google,
model: request.model.model_id.clone(),
provider_request: serde_json::to_value(&request)
.map_err(|e| anyhow!(e))?,

View File

@@ -3,7 +3,6 @@ use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Result, anyhow};
use cloud_llm_client::CompletionIntent;
use collections::HashMap;
use copilot::copilot_chat::{
ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
@@ -31,6 +30,7 @@ use settings::SettingsStore;
use std::time::Duration;
use ui::prelude::*;
use util::debug_panic;
use zed_llm_client::CompletionIntent;
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;

View File

@@ -0,0 +1,474 @@
use anyhow::{Result, anyhow};
use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
};
use mistralrs::{
IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
TextModelBuilder,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ui::{ButtonLike, IconName, Indicator, prelude::*};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: u64,
}
pub struct LocalLanguageModelProvider {
state: Entity<State>,
}
pub struct State {
model: Option<Arc<MistralModel>>,
status: ModelStatus,
}
#[derive(Clone, Debug, PartialEq)]
enum ModelStatus {
NotLoaded,
Loading,
Loaded,
Error(String),
}
impl State {
fn new(_cx: &mut Context<Self>) -> Self {
Self {
model: None,
status: ModelStatus::NotLoaded,
}
}
fn is_authenticated(&self) -> bool {
// Local models don't require authentication
true
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
// Skip if already loaded or currently loading
if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
return Task::ready(Ok(()));
}
self.status = ModelStatus::Loading;
cx.notify();
let background_executor = cx.background_executor().clone();
cx.spawn(async move |this, cx| {
eprintln!("Local model: Starting to load model");
// Move the model loading to a background thread
let model_result = background_executor
.spawn(async move { load_mistral_model().await })
.await;
match model_result {
Ok(model) => {
eprintln!("Local model: Model loaded successfully");
this.update(cx, |state, cx| {
state.model = Some(model);
state.status = ModelStatus::Loaded;
cx.notify();
eprintln!("Local model: Status updated to Loaded");
})?;
Ok(())
}
Err(e) => {
let error_msg = e.to_string();
eprintln!("Local model: Failed to load model - {}", error_msg);
this.update(cx, |state, cx| {
state.status = ModelStatus::Error(error_msg.clone());
cx.notify();
eprintln!("Local model: Status updated to Failed");
})?;
Err(AuthenticateError::Other(anyhow!(
"Failed to load model: {}",
error_msg
)))
}
}
})
}
}
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
println!("\n\n\n\nLoading mistral model...\n\n\n");
eprintln!("Starting to load model: {}", DEFAULT_MODEL);
// Configure the model builder to use background threads for downloads
eprintln!("Creating TextModelBuilder...");
let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
eprintln!("Building model (this should be quick for a 0.5B model)...");
let start_time = std::time::Instant::now();
match builder.build().await {
Ok(model) => {
let elapsed = start_time.elapsed();
eprintln!("Model loaded successfully in {:?}", elapsed);
Ok(Arc::new(model))
}
Err(e) => {
eprintln!("Failed to load model: {:?}", e);
Err(e)
}
}
}
impl LocalLanguageModelProvider {
pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(State::new);
Self { state }
}
}
impl LanguageModelProviderState for LocalLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for LocalLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::Ai
}
fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
vec![Arc::new(LocalLanguageModel {
state: self.state.clone(),
request_limiter: RateLimiter::new(4),
})]
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.provided_models(cx).into_iter().next()
}
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.default_model(cx)
}
fn is_authenticated(&self, _cx: &App) -> bool {
// Local models don't require authentication
true
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
cx.new(|_cx| ConfigurationView {
state: self.state.clone(),
})
.into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| {
state.model = None;
state.status = ModelStatus::NotLoaded;
cx.notify();
});
Task::ready(Ok(()))
}
}
pub struct LocalLanguageModel {
state: Entity<State>,
request_limiter: RateLimiter,
}
impl LocalLanguageModel {
fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
let mut messages = TextMessages::new();
for message in &request.messages {
let mut text_content = String::new();
for content in &message.content {
match content {
MessageContent::Text(text) => {
text_content.push_str(text);
}
MessageContent::Image { .. } => {
// For now, skip image content
continue;
}
MessageContent::ToolResult { .. } => {
// Skip tool results for now
continue;
}
MessageContent::Thinking { .. } => {
// Skip thinking content
continue;
}
MessageContent::RedactedThinking(_) => {
// Skip redacted thinking
continue;
}
MessageContent::ToolUse(_) => {
// Skip tool use
continue;
}
}
}
if text_content.is_empty() {
continue;
}
let role = match message.role {
Role::User => TextMessageRole::User,
Role::Assistant => TextMessageRole::Assistant,
Role::System => TextMessageRole::System,
};
messages = messages.add_message(role, text_content);
}
messages
}
}
impl LanguageModel for LocalLanguageModel {
fn id(&self) -> LanguageModelId {
LanguageModelId(DEFAULT_MODEL.into())
}
fn name(&self) -> LanguageModelName {
LanguageModelName(DEFAULT_MODEL.into())
}
fn provider_id(&self) -> LanguageModelProviderId {
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
PROVIDER_NAME
}
fn telemetry_id(&self) -> String {
format!("local/{}", DEFAULT_MODEL)
}
fn supports_tools(&self) -> bool {
true
}
fn supports_images(&self) -> bool {
false
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
true
}
fn max_token_count(&self) -> u64 {
128000 // Qwen2.5 supports 128k context
}
fn count_tokens(
&self,
request: LanguageModelRequest,
_cx: &App,
) -> BoxFuture<'static, Result<u64>> {
// Rough estimation: 1 token ≈ 4 characters
let mut total_chars = 0;
for message in request.messages {
for content in message.content {
match content {
MessageContent::Text(text) => total_chars += text.len(),
_ => {}
}
}
}
let tokens = (total_chars / 4) as u64;
futures::future::ready(Ok(tokens)).boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let messages = self.to_mistral_messages(&request);
let state = self.state.clone();
let limiter = self.request_limiter.clone();
cx.spawn(async move |cx| {
let result: Result<
BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
> = limiter
.run(async move {
let model = cx
.read_entity(&state, |state, _| {
eprintln!(
"Local model: Checking if model is loaded: {:?}",
state.status
);
state.model.clone()
})
.map_err(|_| {
LanguageModelCompletionError::Other(anyhow!("App state dropped"))
})?
.ok_or_else(|| {
eprintln!("Local model: Model is not loaded!");
LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
})?;
let (mut tx, rx) = mpsc::channel(32);
// Spawn a task to handle the stream
let _ = smol::spawn(async move {
let mut stream = match model.stream_chat_request(messages).await {
Ok(stream) => stream,
Err(e) => {
let _ = tx
.send(Err(LanguageModelCompletionError::Other(anyhow!(
"Failed to start stream: {}",
e
))))
.await;
return;
}
};
while let Some(response) = stream.next().await {
let event = match response {
MistralResponse::Chunk(chunk) => {
if let Some(choice) = chunk.choices.first() {
if let Some(content) = &choice.delta.content {
Some(Ok(LanguageModelCompletionEvent::Text(
content.clone(),
)))
} else if let Some(finish_reason) = &choice.finish_reason {
let stop_reason = match finish_reason.as_str() {
"stop" => StopReason::EndTurn,
"length" => StopReason::MaxTokens,
_ => StopReason::EndTurn,
};
Some(Ok(LanguageModelCompletionEvent::Stop(
stop_reason,
)))
} else {
None
}
} else {
None
}
}
MistralResponse::Done(_response) => {
// For now, we don't emit usage events since the format doesn't match
None
}
_ => None,
};
if let Some(event) = event {
if tx.send(event).await.is_err() {
break;
}
}
}
})
.detach();
Ok(rx.boxed())
})
.await;
result
})
.boxed()
}
}
struct ConfigurationView {
state: Entity<State>,
}
impl Render for ConfigurationView {
fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
let status = self.state.read(cx).status.clone();
div().size_full().child(
div()
.p_4()
.child(
div()
.flex()
.gap_2()
.items_center()
.child(match &status {
ModelStatus::NotLoaded => Label::new("Model not loaded"),
ModelStatus::Loading => Label::new("Loading model..."),
ModelStatus::Loaded => Label::new("Model loaded"),
ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
})
.child(match &status {
ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
ModelStatus::Loading => Indicator::dot().color(Color::Modified),
ModelStatus::Loaded => Indicator::dot().color(Color::Success),
ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
}),
)
.when(!matches!(status, ModelStatus::Loading), |this| {
this.child(
ButtonLike::new("load_model")
.child(Label::new(if matches!(status, ModelStatus::Loaded) {
"Reload Model"
} else {
"Load Model"
}))
.on_click(cx.listener(|this, _, _window, cx| {
this.state.update(cx, |state, cx| {
state.authenticate(cx).detach();
});
})),
)
}),
)
}
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,259 @@
use super::*;
use gpui::TestAppContext;
use http_client::FakeHttpClient;
use language_model::{LanguageModelRequest, MessageContent, Role};
#[gpui::test]
fn test_local_provider_creation(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::with_200_response();
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
cx.read(|cx| {
assert_eq!(provider.id(), PROVIDER_ID);
assert_eq!(provider.name(), PROVIDER_NAME);
assert!(!provider.is_authenticated(cx));
assert_eq!(provider.provided_models(cx).len(), 1);
});
}
#[gpui::test]
fn test_state_initialization(cx: &mut TestAppContext) {
cx.update(|cx| {
let state = cx.new(State::new);
assert!(!state.read(cx).is_authenticated());
assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
assert!(state.read(cx).model.is_none());
});
}
#[gpui::test]
fn test_model_properties(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::with_200_response();
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
// Create a model directly for testing (bypassing authentication)
let model = LocalLanguageModel {
state: provider.state.clone(),
request_limiter: RateLimiter::new(4),
};
assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
assert_eq!(model.provider_id(), PROVIDER_ID);
assert_eq!(model.provider_name(), PROVIDER_NAME);
assert_eq!(model.max_token_count(), 128000);
assert!(!model.supports_tools());
assert!(!model.supports_images());
}
#[gpui::test]
async fn test_token_counting(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::with_200_response();
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
let model = LocalLanguageModel {
state: provider.state.clone(),
request_limiter: RateLimiter::new(4),
};
let request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: None,
mode: None,
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text("Hello, world!".to_string())],
cache: false,
}],
tools: Vec::new(),
tool_choice: None,
stop: Vec::new(),
temperature: None,
thinking_allowed: false,
};
let count = cx
.update(|cx| model.count_tokens(request, cx))
.await
.unwrap();
// "Hello, world!" is 13 characters, so ~3 tokens
assert!(count > 0);
assert!(count < 10);
}
#[gpui::test]
async fn test_message_conversion(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::with_200_response();
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
let model = LocalLanguageModel {
state: provider.state.clone(),
request_limiter: RateLimiter::new(4),
};
let request = LanguageModelRequest {
thread_id: None,
prompt_id: None,
intent: None,
mode: None,
messages: vec![
language_model::LanguageModelRequestMessage {
role: Role::System,
content: vec![MessageContent::Text(
"You are a helpful assistant.".to_string(),
)],
cache: false,
},
language_model::LanguageModelRequestMessage {
role: Role::User,
content: vec![MessageContent::Text("Hello!".to_string())],
cache: false,
},
language_model::LanguageModelRequestMessage {
role: Role::Assistant,
content: vec![MessageContent::Text("Hi there!".to_string())],
cache: false,
},
],
tools: Vec::new(),
tool_choice: None,
stop: Vec::new(),
temperature: None,
thinking_allowed: false,
};
let _messages = model.to_mistral_messages(&request);
// We can't directly inspect TextMessages, but we can verify it doesn't panic
assert!(true); // Placeholder assertion
}
#[gpui::test]
async fn test_reset_credentials(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::with_200_response();
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
// Simulate loading a model by just setting the status
cx.update(|cx| {
provider.state.update(cx, |state, cx| {
state.status = ModelStatus::Loaded;
// We don't actually set a model since we can't mock it safely
cx.notify();
});
});
cx.read(|cx| {
// Since is_authenticated checks for model presence, we need to check status directly
assert_eq!(provider.state.read(cx).status, ModelStatus::Loaded);
});
// Reset credentials
let task = cx.update(|cx| provider.reset_credentials(cx));
task.await.unwrap();
cx.read(|cx| {
assert!(!provider.is_authenticated(cx));
assert_eq!(provider.state.read(cx).status, ModelStatus::NotLoaded);
assert!(provider.state.read(cx).model.is_none());
});
}
// TODO: Fix this test - need to handle window creation in tests
// #[gpui::test]
// async fn test_configuration_view_rendering(cx: &mut TestAppContext) {
// let http_client = FakeHttpClient::with_200_response();
// let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
// let view = cx.update(|cx| provider.configuration_view(cx.window(), cx));
// // Basic test to ensure the view can be created without panicking
// assert!(view.entity_type() == std::any::TypeId::of::<ConfigurationView>());
// }
#[gpui::test]
fn test_status_transitions(cx: &mut TestAppContext) {
cx.update(|cx| {
let state = cx.new(State::new);
// Initial state
assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
// Transition to loading
state.update(cx, |state, cx| {
state.status = ModelStatus::Loading;
cx.notify();
});
assert_eq!(state.read(cx).status, ModelStatus::Loading);
// Transition to loaded
state.update(cx, |state, cx| {
state.status = ModelStatus::Loaded;
cx.notify();
});
assert_eq!(state.read(cx).status, ModelStatus::Loaded);
// Transition to error
state.update(cx, |state, cx| {
state.status = ModelStatus::Error("Test error".to_string());
cx.notify();
});
match &state.read(cx).status {
ModelStatus::Error(msg) => assert_eq!(msg, "Test error"),
_ => panic!("Expected error status"),
}
});
}
#[gpui::test]
fn test_provider_shows_models_without_authentication(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::with_200_response();
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
cx.read(|cx| {
// Provider should show models even when not authenticated
let models = provider.provided_models(cx);
assert_eq!(models.len(), 1);
let model = &models[0];
assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
assert_eq!(model.provider_id(), PROVIDER_ID);
assert_eq!(model.provider_name(), PROVIDER_NAME);
});
}
#[gpui::test]
fn test_provider_has_icon(cx: &mut TestAppContext) {
let http_client = FakeHttpClient::with_200_response();
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
assert_eq!(provider.icon(), IconName::Ai);
}
#[gpui::test]
fn test_provider_appears_in_registry(cx: &mut TestAppContext) {
use language_model::LanguageModelRegistry;
cx.update(|cx| {
let registry = cx.new(|_| LanguageModelRegistry::default());
let http_client = FakeHttpClient::with_200_response();
// Register the local provider
registry.update(cx, |registry, cx| {
let provider = LocalLanguageModelProvider::new(Arc::new(http_client), cx);
registry.register_provider(provider, cx);
});
// Verify the provider is registered
let provider = registry.read(cx).provider(&PROVIDER_ID).unwrap();
assert_eq!(provider.name(), PROVIDER_NAME);
assert_eq!(provider.icon(), IconName::Ai);
// Verify it provides models even without authentication
let models = provider.provided_models(cx);
assert_eq!(models.len(), 1);
assert_eq!(models[0].id(), LanguageModelId(DEFAULT_MODEL.into()));
});
}

View File

@@ -421,14 +421,14 @@ impl LanguageServer {
.map(|stderr| {
let io_handlers = io_handlers.clone();
let stderr_captures = stderr_capture.clone();
cx.background_spawn(async move {
cx.spawn(async move |_| {
Self::handle_stderr(stderr, io_handlers, stderr_captures)
.log_err()
.await
})
})
.unwrap_or_else(|| Task::ready(None));
let input_task = cx.background_spawn(async move {
let input_task = cx.spawn(async move |_| {
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
stdout.or(stderr)
});
@@ -846,7 +846,7 @@ impl LanguageServer {
configuration: Arc<DidChangeConfigurationParams>,
cx: &App,
) -> Task<Result<Arc<Self>>> {
cx.background_spawn(async move {
cx.spawn(async move |_| {
let response = self
.request::<request::Initialize>(params)
.await

View File

@@ -18,19 +18,12 @@ default = []
anyhow.workspace = true
command_palette_hooks.workspace = true
db.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
gpui.workspace = true
language.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
settings.workspace = true
theme.workspace = true
ui.workspace = true
util.workspace = true
workspace-hack.workspace = true
workspace.workspace = true
workspace-hack.workspace = true
zed_actions.workspace = true
zlog.workspace = true

Some files were not shown because too many files have changed in this diff Show More