Compare commits
6 Commits
remove-d2d
...
local-hugg
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bfc8954f3 | ||
|
|
18ca69f07f | ||
|
|
f90459656f | ||
|
|
5830628568 | ||
|
|
f62e693b8f | ||
|
|
4abdec044f |
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
@@ -771,8 +771,7 @@ jobs:
|
||||
timeout-minutes: 120
|
||||
name: Create a Windows installer
|
||||
runs-on: [self-hosted, Windows, X64]
|
||||
if: contains(github.event.pull_request.labels.*.name, 'run-bundling')
|
||||
# if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
|
||||
if: (startsWith(github.ref, 'refs/tags/v') || contains(github.event.pull_request.labels.*.name, 'run-bundling'))
|
||||
needs: [windows_tests]
|
||||
env:
|
||||
AZURE_TENANT_ID: ${{ secrets.AZURE_SIGNING_TENANT_ID }}
|
||||
|
||||
4824
Cargo.lock
generated
4824
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
31
Cargo.toml
31
Cargo.toml
@@ -1,13 +1,13 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/acp_thread",
|
||||
"crates/activity_indicator",
|
||||
"crates/agent",
|
||||
"crates/agent_servers",
|
||||
"crates/agent_settings",
|
||||
"crates/acp_thread",
|
||||
"crates/agent_ui",
|
||||
"crates/agent",
|
||||
"crates/agent_settings",
|
||||
"crates/ai_onboarding",
|
||||
"crates/agent_servers",
|
||||
"crates/anthropic",
|
||||
"crates/askpass",
|
||||
"crates/assets",
|
||||
@@ -29,9 +29,6 @@ members = [
|
||||
"crates/cli",
|
||||
"crates/client",
|
||||
"crates/clock",
|
||||
"crates/cloud_api_client",
|
||||
"crates/cloud_api_types",
|
||||
"crates/cloud_llm_client",
|
||||
"crates/collab",
|
||||
"crates/collab_ui",
|
||||
"crates/collections",
|
||||
@@ -51,8 +48,8 @@ members = [
|
||||
"crates/diagnostics",
|
||||
"crates/docs_preprocessor",
|
||||
"crates/editor",
|
||||
"crates/eval",
|
||||
"crates/explorer_command_injector",
|
||||
"crates/eval",
|
||||
"crates/extension",
|
||||
"crates/extension_api",
|
||||
"crates/extension_cli",
|
||||
@@ -73,6 +70,7 @@ members = [
|
||||
"crates/gpui",
|
||||
"crates/gpui_macros",
|
||||
"crates/gpui_tokio",
|
||||
|
||||
"crates/html_to_markdown",
|
||||
"crates/http_client",
|
||||
"crates/http_client_tls",
|
||||
@@ -101,6 +99,7 @@ members = [
|
||||
"crates/markdown_preview",
|
||||
"crates/media",
|
||||
"crates/menu",
|
||||
"crates/svg_preview",
|
||||
"crates/migrator",
|
||||
"crates/mistral",
|
||||
"crates/multi_buffer",
|
||||
@@ -141,7 +140,6 @@ members = [
|
||||
"crates/semantic_version",
|
||||
"crates/session",
|
||||
"crates/settings",
|
||||
"crates/settings_profile_selector",
|
||||
"crates/settings_ui",
|
||||
"crates/snippet",
|
||||
"crates/snippet_provider",
|
||||
@@ -154,7 +152,6 @@ members = [
|
||||
"crates/sum_tree",
|
||||
"crates/supermaven",
|
||||
"crates/supermaven_api",
|
||||
"crates/svg_preview",
|
||||
"crates/tab_switcher",
|
||||
"crates/task",
|
||||
"crates/tasks_ui",
|
||||
@@ -254,9 +251,6 @@ channel = { path = "crates/channel" }
|
||||
cli = { path = "crates/cli" }
|
||||
client = { path = "crates/client" }
|
||||
clock = { path = "crates/clock" }
|
||||
cloud_api_client = { path = "crates/cloud_api_client" }
|
||||
cloud_api_types = { path = "crates/cloud_api_types" }
|
||||
cloud_llm_client = { path = "crates/cloud_llm_client" }
|
||||
collab = { path = "crates/collab" }
|
||||
collab_ui = { path = "crates/collab_ui" }
|
||||
collections = { path = "crates/collections" }
|
||||
@@ -343,7 +337,6 @@ picker = { path = "crates/picker" }
|
||||
plugin = { path = "crates/plugin" }
|
||||
plugin_macros = { path = "crates/plugin_macros" }
|
||||
prettier = { path = "crates/prettier" }
|
||||
settings_profile_selector = { path = "crates/settings_profile_selector" }
|
||||
project = { path = "crates/project" }
|
||||
project_panel = { path = "crates/project_panel" }
|
||||
project_symbols = { path = "crates/project_symbols" }
|
||||
@@ -435,7 +428,7 @@ async-pipe = { git = "https://github.com/zed-industries/async-pipe-rs", rev = "8
|
||||
async-recursion = "1.0.0"
|
||||
async-tar = "0.5.0"
|
||||
async-trait = "0.1"
|
||||
async-tungstenite = "0.29.1"
|
||||
async-tungstenite = "0.30.0"
|
||||
async_zip = { version = "0.0.17", features = ["deflate", "deflate64"] }
|
||||
aws-config = { version = "1.6.1", features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { version = "1.2.2", features = [
|
||||
@@ -522,7 +515,7 @@ objc = "0.2"
|
||||
open = "5.0.0"
|
||||
ordered-float = "2.1.1"
|
||||
palette = { version = "0.7.5", default-features = false, features = ["std"] }
|
||||
parking_lot = "0.12.1"
|
||||
parking_lot = "0.12.4"
|
||||
partial-json-fixer = "0.5.3"
|
||||
parse_int = "0.9"
|
||||
pathdiff = "0.2"
|
||||
@@ -652,6 +645,7 @@ which = "6.0.0"
|
||||
windows-core = "0.61"
|
||||
wit-component = "0.221"
|
||||
workspace-hack = "0.1.0"
|
||||
zed_llm_client = "= 0.8.6"
|
||||
zstd = "0.11"
|
||||
|
||||
[workspace.dependencies.async-stripe]
|
||||
@@ -680,13 +674,8 @@ features = [
|
||||
"Win32_Globalization",
|
||||
"Win32_Graphics_Direct2D",
|
||||
"Win32_Graphics_Direct2D_Common",
|
||||
"Win32_Graphics_Direct3D",
|
||||
"Win32_Graphics_Direct3D11",
|
||||
"Win32_Graphics_Direct3D_Fxc",
|
||||
"Win32_Graphics_DirectComposition",
|
||||
"Win32_Graphics_DirectWrite",
|
||||
"Win32_Graphics_Dwm",
|
||||
"Win32_Graphics_Dxgi",
|
||||
"Win32_Graphics_Dxgi_Common",
|
||||
"Win32_Graphics_Gdi",
|
||||
"Win32_Graphics_Imaging",
|
||||
|
||||
@@ -232,7 +232,7 @@
|
||||
"ctrl-n": "agent::NewThread",
|
||||
"ctrl-alt-n": "agent::NewTextThread",
|
||||
"ctrl-shift-h": "agent::OpenHistory",
|
||||
"ctrl-alt-c": "agent::OpenSettings",
|
||||
"ctrl-alt-c": "agent::OpenConfiguration",
|
||||
"ctrl-alt-p": "agent::OpenRulesLibrary",
|
||||
"ctrl-i": "agent::ToggleProfileSelector",
|
||||
"ctrl-alt-/": "agent::ToggleModelSelector",
|
||||
|
||||
@@ -272,7 +272,7 @@
|
||||
"cmd-n": "agent::NewThread",
|
||||
"cmd-alt-n": "agent::NewTextThread",
|
||||
"cmd-shift-h": "agent::OpenHistory",
|
||||
"cmd-alt-c": "agent::OpenSettings",
|
||||
"cmd-alt-c": "agent::OpenConfiguration",
|
||||
"cmd-alt-p": "agent::OpenRulesLibrary",
|
||||
"cmd-i": "agent::ToggleProfileSelector",
|
||||
"cmd-alt-/": "agent::ToggleModelSelector",
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"ctrl-shift-i": "agent::ToggleFocus",
|
||||
"ctrl-l": "agent::ToggleFocus",
|
||||
"ctrl-shift-l": "agent::ToggleFocus",
|
||||
"ctrl-shift-j": "agent::OpenSettings"
|
||||
"ctrl-shift-j": "agent::OpenConfiguration"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"cmd-shift-i": "agent::ToggleFocus",
|
||||
"cmd-l": "agent::ToggleFocus",
|
||||
"cmd-shift-l": "agent::ToggleFocus",
|
||||
"cmd-shift-j": "agent::OpenSettings"
|
||||
"cmd-shift-j": "agent::OpenConfiguration"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1877,8 +1877,5 @@
|
||||
"save_breakpoints": true,
|
||||
"dock": "bottom",
|
||||
"button": true
|
||||
},
|
||||
// Configures any number of settings profiles that are temporarily applied
|
||||
// when selected from `settings profile selector: toggle`.
|
||||
"profiles": []
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ assistant_context.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
context_server.workspace = true
|
||||
@@ -36,9 +35,9 @@ futures.workspace = true
|
||||
git.workspace = true
|
||||
gpui.workspace = true
|
||||
heed.workspace = true
|
||||
http_client.workspace = true
|
||||
icons.workspace = true
|
||||
indoc.workspace = true
|
||||
http_client.workspace = true
|
||||
itertools.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
@@ -64,6 +63,7 @@ time.workspace = true
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
zstd.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -13,7 +13,6 @@ use anyhow::{Result, anyhow};
|
||||
use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
use collections::HashMap;
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use futures::{FutureExt, StreamExt as _, future::Shared};
|
||||
@@ -50,6 +49,7 @@ use std::{
|
||||
use thiserror::Error;
|
||||
use util::{ResultExt as _, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
||||
|
||||
const MAX_RETRY_ATTEMPTS: u8 = 4;
|
||||
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
|
||||
@@ -1681,7 +1681,7 @@ impl Thread {
|
||||
|
||||
let completion_mode = request
|
||||
.mode
|
||||
.unwrap_or(cloud_llm_client::CompletionMode::Normal);
|
||||
.unwrap_or(zed_llm_client::CompletionMode::Normal);
|
||||
|
||||
self.last_received_chunk_at = Some(Instant::now());
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ doctest = false
|
||||
[dependencies]
|
||||
acp_thread.workspace = true
|
||||
agent-client-protocol.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agentic-coding-protocol.workspace = true
|
||||
anyhow.workspace = true
|
||||
collections.workspace = true
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::path::PathBuf;
|
||||
use crate::claude::tools::{ClaudeTool, EditToolParams, ReadToolParams};
|
||||
use acp_thread::AcpThread;
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context, Result};
|
||||
use collections::HashMap;
|
||||
use context_server::listener::{McpServerTool, ToolResponse};
|
||||
@@ -13,6 +14,7 @@ use context_server::types::{
|
||||
use gpui::{App, AsyncApp, Task, WeakEntity};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::Settings;
|
||||
|
||||
pub struct ClaudeZedMcpServer {
|
||||
server: context_server::listener::McpServer,
|
||||
@@ -114,6 +116,7 @@ pub struct PermissionToolParams {
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[cfg_attr(test, derive(serde::Deserialize))]
|
||||
pub struct PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior,
|
||||
updated_input: serde_json::Value,
|
||||
@@ -121,7 +124,8 @@ pub struct PermissionToolResponse {
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum PermissionToolBehavior {
|
||||
#[cfg_attr(test, derive(serde::Deserialize))]
|
||||
pub enum PermissionToolBehavior {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
@@ -141,6 +145,26 @@ impl McpServerTool for PermissionTool {
|
||||
input: Self::Input,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<ToolResponse<Self::Output>> {
|
||||
// Check if we should automatically allow tool actions
|
||||
let always_allow =
|
||||
cx.update(|cx| AgentSettings::get_global(cx).always_allow_tool_actions)?;
|
||||
|
||||
if always_allow {
|
||||
// If always_allow_tool_actions is true, immediately return Allow without prompting
|
||||
let response = PermissionToolResponse {
|
||||
behavior: PermissionToolBehavior::Allow,
|
||||
updated_input: input.input,
|
||||
};
|
||||
|
||||
return Ok(ToolResponse {
|
||||
content: vec![ToolResponseContent::Text {
|
||||
text: serde_json::to_string(&response)?,
|
||||
}],
|
||||
structured_content: (),
|
||||
});
|
||||
}
|
||||
|
||||
// Otherwise, proceed with the normal permission flow
|
||||
let mut thread_rx = self.thread_rx.clone();
|
||||
let Some(thread) = thread_rx.recv().await?.upgrade() else {
|
||||
anyhow::bail!("Thread closed");
|
||||
@@ -300,3 +324,78 @@ impl McpServerTool for EditTool {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use project::Project;
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_permission_tool_respects_always_allow_setting(cx: &mut TestAppContext) {
|
||||
// Initialize settings
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
agent_settings::init(cx);
|
||||
});
|
||||
|
||||
// Create a test thread
|
||||
let project = cx.update(|cx| gpui::Entity::new(cx, |_cx| Project::local()));
|
||||
let thread = cx.update(|cx| {
|
||||
gpui::Entity::new(cx, |_cx| {
|
||||
acp_thread::AcpThread::new(
|
||||
acp::ConnectionId("test".into()),
|
||||
project,
|
||||
std::path::Path::new("/tmp"),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
let (tx, rx) = watch::channel(thread.downgrade());
|
||||
let tool = PermissionTool { thread_rx: rx };
|
||||
|
||||
// Test with always_allow_tool_actions = true
|
||||
cx.update(|cx| {
|
||||
AgentSettings::override_global(
|
||||
AgentSettings {
|
||||
always_allow_tool_actions: true,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
let input = PermissionToolParams {
|
||||
tool_name: "test_tool".to_string(),
|
||||
input: serde_json::json!({"test": "data"}),
|
||||
tool_use_id: Some("test_id".to_string()),
|
||||
};
|
||||
|
||||
let result = tool.run(input.clone(), &mut cx.to_async()).await.unwrap();
|
||||
|
||||
// Should return Allow without prompting
|
||||
assert_eq!(result.content.len(), 1);
|
||||
if let ToolResponseContent::Text { text } = &result.content[0] {
|
||||
let response: PermissionToolResponse = serde_json::from_str(text).unwrap();
|
||||
assert!(matches!(response.behavior, PermissionToolBehavior::Allow));
|
||||
} else {
|
||||
panic!("Expected text response");
|
||||
}
|
||||
|
||||
// Test with always_allow_tool_actions = false
|
||||
cx.update(|cx| {
|
||||
AgentSettings::override_global(
|
||||
AgentSettings {
|
||||
always_allow_tool_actions: false,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
// This test would require mocking the permission prompt response
|
||||
// In the real scenario, it would wait for user input
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::{
|
||||
use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
|
||||
use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
|
||||
use agent_client_protocol as acp;
|
||||
use agent_settings::AgentSettings;
|
||||
|
||||
use futures::{FutureExt, StreamExt, channel::mpsc, select};
|
||||
use gpui::{Entity, TestAppContext};
|
||||
@@ -241,6 +242,57 @@ pub async fn test_tool_call_with_confirmation(
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn test_tool_call_always_allow(
|
||||
server: impl AgentServer + 'static,
|
||||
cx: &mut TestAppContext,
|
||||
) {
|
||||
let fs = init_test(cx).await;
|
||||
let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
|
||||
|
||||
// Enable always_allow_tool_actions
|
||||
cx.update(|cx| {
|
||||
AgentSettings::override_global(
|
||||
AgentSettings {
|
||||
always_allow_tool_actions: true,
|
||||
..Default::default()
|
||||
},
|
||||
cx,
|
||||
);
|
||||
});
|
||||
|
||||
let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
|
||||
let full_turn = thread.update(cx, |thread, cx| {
|
||||
thread.send_raw(
|
||||
r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Wait for the tool call to complete
|
||||
full_turn.await.unwrap();
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
// With always_allow_tool_actions enabled, the tool call should be immediately allowed
|
||||
// without waiting for confirmation
|
||||
let tool_call_entry = thread
|
||||
.entries()
|
||||
.iter()
|
||||
.find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
|
||||
.expect("Expected a tool call entry");
|
||||
|
||||
let AgentThreadEntry::ToolCall(tool_call) = tool_call_entry else {
|
||||
panic!("Expected tool call entry");
|
||||
};
|
||||
|
||||
// Should be allowed, not waiting for confirmation
|
||||
assert!(
|
||||
matches!(tool_call.status, ToolCallStatus::Allowed { .. }),
|
||||
"Expected tool call to be allowed automatically, but got {:?}",
|
||||
tool_call.status
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx).await;
|
||||
|
||||
@@ -351,6 +403,12 @@ macro_rules! common_e2e_tests {
|
||||
async fn cancel(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_cancel($server, cx).await;
|
||||
}
|
||||
|
||||
#[::gpui::test]
|
||||
#[cfg_attr(not(feature = "e2e"), ignore)]
|
||||
async fn tool_call_always_allow(cx: &mut ::gpui::TestAppContext) {
|
||||
$crate::e2e_tests::test_tool_call_always_allow($server, cx).await;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ path = "src/agent_settings.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
gpui.workspace = true
|
||||
language_model.workspace = true
|
||||
@@ -21,6 +20,7 @@ schemars.workspace = true
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
fs.workspace = true
|
||||
|
||||
@@ -321,11 +321,11 @@ pub enum CompletionMode {
|
||||
Burn,
|
||||
}
|
||||
|
||||
impl From<CompletionMode> for cloud_llm_client::CompletionMode {
|
||||
impl From<CompletionMode> for zed_llm_client::CompletionMode {
|
||||
fn from(value: CompletionMode) -> Self {
|
||||
match value {
|
||||
CompletionMode::Normal => cloud_llm_client::CompletionMode::Normal,
|
||||
CompletionMode::Burn => cloud_llm_client::CompletionMode::Max,
|
||||
CompletionMode::Normal => zed_llm_client::CompletionMode::Normal,
|
||||
CompletionMode::Burn => zed_llm_client::CompletionMode::Max,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ audio.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
@@ -47,9 +46,9 @@ futures.workspace = true
|
||||
fuzzy.workspace = true
|
||||
gpui.workspace = true
|
||||
html_to_markdown.workspace = true
|
||||
indoc.workspace = true
|
||||
http_client.workspace = true
|
||||
indexed_docs.workspace = true
|
||||
indoc.workspace = true
|
||||
inventory.workspace = true
|
||||
itertools.workspace = true
|
||||
jsonschema.workspace = true
|
||||
@@ -98,6 +97,7 @@ watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
assistant_tools.workspace = true
|
||||
|
||||
@@ -14,7 +14,6 @@ use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
|
||||
use anyhow::Context as _;
|
||||
use assistant_tool::ToolUseStatus;
|
||||
use audio::{Audio, Sound};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
use editor::scroll::Autoscroll;
|
||||
@@ -53,6 +52,7 @@ use util::ResultExt as _;
|
||||
use util::markdown::MarkdownCodeBlock;
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::assistant::OpenRulesLibrary;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
const CODEBLOCK_CONTAINER_GROUP: &str = "codeblock_container";
|
||||
const EDIT_PREVIOUS_MESSAGE_MIN_LINES: usize = 1;
|
||||
|
||||
@@ -38,6 +38,13 @@ impl AgentModelSelector {
|
||||
move |model, cx| {
|
||||
let provider = model.provider_id().0.to_string();
|
||||
let model_id = model.id().0.to_string();
|
||||
|
||||
// Authenticate the provider when a model is selected
|
||||
let registry = LanguageModelRegistry::read_global(cx);
|
||||
if let Some(provider) = registry.provider(&model.provider_id()) {
|
||||
provider.authenticate(cx).detach();
|
||||
}
|
||||
|
||||
match &model_usage_context {
|
||||
ModelUsageContext::Thread(thread) => {
|
||||
thread.update(cx, |thread, cx| {
|
||||
|
||||
@@ -44,7 +44,6 @@ use assistant_context::{AssistantContext, ContextEvent, ContextSummary};
|
||||
use assistant_slash_command::SlashCommandWorkingSet;
|
||||
use assistant_tool::ToolWorkingSet;
|
||||
use client::{DisableAiSettings, UserStore, zed_urls};
|
||||
use cloud_llm_client::{CompletionIntent, UsageLimit};
|
||||
use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer};
|
||||
use feature_flags::{self, FeatureFlagAppExt};
|
||||
use fs::Fs;
|
||||
@@ -78,9 +77,10 @@ use workspace::{
|
||||
};
|
||||
use zed_actions::{
|
||||
DecreaseBufferFontSize, IncreaseBufferFontSize, ResetBufferFontSize,
|
||||
agent::{OpenOnboardingModal, OpenSettings, ResetOnboarding, ToggleModelSelector},
|
||||
agent::{OpenConfiguration, OpenOnboardingModal, ResetOnboarding, ToggleModelSelector},
|
||||
assistant::{OpenRulesLibrary, ToggleFocus},
|
||||
};
|
||||
use zed_llm_client::{CompletionIntent, UsageLimit};
|
||||
|
||||
const AGENT_PANEL_KEY: &str = "agent_panel";
|
||||
|
||||
@@ -105,7 +105,7 @@ pub fn init(cx: &mut App) {
|
||||
panel.update(cx, |panel, cx| panel.open_history(window, cx));
|
||||
}
|
||||
})
|
||||
.register_action(|workspace, _: &OpenSettings, window, cx| {
|
||||
.register_action(|workspace, _: &OpenConfiguration, window, cx| {
|
||||
if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
|
||||
workspace.focus_panel::<AgentPanel>(window, cx);
|
||||
panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
|
||||
@@ -2088,7 +2088,7 @@ impl AgentPanel {
|
||||
|
||||
menu = menu
|
||||
.action("Rules…", Box::new(OpenRulesLibrary::default()))
|
||||
.action("Settings", Box::new(OpenSettings))
|
||||
.action("Settings", Box::new(OpenConfiguration))
|
||||
.action(zoom_in_label, Box::new(ToggleZoom));
|
||||
menu
|
||||
}))
|
||||
@@ -2482,14 +2482,14 @@ impl AgentPanel {
|
||||
.icon_color(Color::Muted)
|
||||
.full_width()
|
||||
.key_binding(KeyBinding::for_action_in(
|
||||
&OpenSettings,
|
||||
&OpenConfiguration,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
))
|
||||
.on_click(|_event, window, cx| {
|
||||
window.dispatch_action(
|
||||
OpenSettings.boxed_clone(),
|
||||
OpenConfiguration.boxed_clone(),
|
||||
cx,
|
||||
)
|
||||
}),
|
||||
@@ -2713,11 +2713,16 @@ impl AgentPanel {
|
||||
.style(ButtonStyle::Tinted(ui::TintColor::Warning))
|
||||
.label_size(LabelSize::Small)
|
||||
.key_binding(
|
||||
KeyBinding::for_action_in(&OpenSettings, &focus_handle, window, cx)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
KeyBinding::for_action_in(
|
||||
&OpenConfiguration,
|
||||
&focus_handle,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.map(|kb| kb.size(rems_from_px(12.))),
|
||||
)
|
||||
.on_click(|_event, window, cx| {
|
||||
window.dispatch_action(OpenSettings.boxed_clone(), cx)
|
||||
window.dispatch_action(OpenConfiguration.boxed_clone(), cx)
|
||||
}),
|
||||
),
|
||||
ConfigurationError::ProviderPendingTermsAcceptance(provider) => {
|
||||
@@ -3221,7 +3226,7 @@ impl Render for AgentPanel {
|
||||
.on_action(cx.listener(|this, _: &OpenHistory, window, cx| {
|
||||
this.open_history(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(|this, _: &OpenSettings, window, cx| {
|
||||
.on_action(cx.listener(|this, _: &OpenConfiguration, window, cx| {
|
||||
this.open_configuration(window, cx);
|
||||
}))
|
||||
.on_action(cx.listener(Self::open_active_thread_as_markdown))
|
||||
|
||||
@@ -265,8 +265,8 @@ fn update_command_palette_filter(cx: &mut App) {
|
||||
filter.hide_namespace("agent");
|
||||
filter.hide_namespace("assistant");
|
||||
filter.hide_namespace("copilot");
|
||||
filter.hide_namespace("supermaven");
|
||||
filter.hide_namespace("zed_predict_onboarding");
|
||||
|
||||
filter.hide_namespace("edit_prediction");
|
||||
|
||||
use editor::actions::{
|
||||
|
||||
@@ -6,7 +6,6 @@ use agent::{
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
|
||||
use futures::{
|
||||
@@ -36,6 +35,7 @@ use std::{
|
||||
};
|
||||
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
|
||||
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
pub struct BufferCodegen {
|
||||
alternatives: Vec<Entity<CodegenAlternative>>,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#![allow(unused, dead_code)]
|
||||
|
||||
use client::{ModelRequestUsage, RequestUsage};
|
||||
use cloud_llm_client::{Plan, UsageLimit};
|
||||
use gpui::Global;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use ui::prelude::*;
|
||||
use zed_llm_client::{Plan, UsageLimit};
|
||||
|
||||
/// Debug only: Used for testing various account states
|
||||
///
|
||||
|
||||
@@ -48,7 +48,7 @@ use text::{OffsetRangeExt, ToPoint as _};
|
||||
use ui::prelude::*;
|
||||
use util::{RangeExt, ResultExt, maybe};
|
||||
use workspace::{ItemHandle, Toast, Workspace, dock::Panel, notifications::NotificationId};
|
||||
use zed_actions::agent::OpenSettings;
|
||||
use zed_actions::agent::OpenConfiguration;
|
||||
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
@@ -345,7 +345,7 @@ impl InlineAssistant {
|
||||
if let Some(answer) = answer {
|
||||
if answer == 0 {
|
||||
cx.update(|window, cx| {
|
||||
window.dispatch_action(Box::new(OpenSettings), cx)
|
||||
window.dispatch_action(Box::new(OpenConfiguration), cx)
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
|
||||
@@ -576,7 +576,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
.icon_position(IconPosition::Start)
|
||||
.on_click(|_, window, cx| {
|
||||
window.dispatch_action(
|
||||
zed_actions::agent::OpenSettings.boxed_clone(),
|
||||
zed_actions::agent::OpenConfiguration.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
|
||||
@@ -18,7 +18,6 @@ use agent_settings::{AgentSettings, CompletionMode};
|
||||
use ai_onboarding::ApiKeysWithProviders;
|
||||
use buffer_diff::BufferDiff;
|
||||
use client::UserStore;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, HashSet};
|
||||
use editor::actions::{MoveUp, Paste};
|
||||
use editor::display_map::CreaseId;
|
||||
@@ -54,6 +53,7 @@ use util::ResultExt as _;
|
||||
use workspace::{CollaboratorId, Workspace};
|
||||
use zed_actions::agent::Chat;
|
||||
use zed_actions::agent::ToggleModelSelector;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider, crease_for_mention};
|
||||
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
|
||||
@@ -1300,11 +1300,11 @@ impl MessageEditor {
|
||||
let plan = user_store
|
||||
.current_plan()
|
||||
.map(|plan| match plan {
|
||||
Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
})
|
||||
.unwrap_or(cloud_llm_client::Plan::ZedFree);
|
||||
.unwrap_or(zed_llm_client::Plan::ZedFree);
|
||||
|
||||
let usage = user_store.model_request_usage()?;
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ use agent::{
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, VecDeque};
|
||||
use editor::{MultiBuffer, actions::SelectAll};
|
||||
use fs::Fs;
|
||||
@@ -28,6 +27,7 @@ use terminal_view::TerminalView;
|
||||
use ui::prelude::*;
|
||||
use util::ResultExt;
|
||||
use workspace::{Toast, Workspace, notifications::NotificationId};
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
pub fn init(
|
||||
fs: Arc<dyn Fs>,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use client::{ModelRequestUsage, RequestUsage, zed_urls};
|
||||
use cloud_llm_client::{Plan, UsageLimit};
|
||||
use component::{empty_example, example_group_with_title, single_example};
|
||||
use gpui::{AnyElement, App, IntoElement, RenderOnce, Window};
|
||||
use ui::{Callout, prelude::*};
|
||||
use zed_llm_client::{Plan, UsageLimit};
|
||||
|
||||
#[derive(IntoElement, RegisterComponent)]
|
||||
pub struct UsageCallout {
|
||||
|
||||
@@ -136,7 +136,10 @@ impl RenderOnce for ApiKeysWithoutProviders {
|
||||
.full_width()
|
||||
.style(ButtonStyle::Outlined)
|
||||
.on_click(move |_, window, cx| {
|
||||
window.dispatch_action(zed_actions::agent::OpenSettings.boxed_clone(), cx);
|
||||
window.dispatch_action(
|
||||
zed_actions::agent::OpenConfiguration.boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ assistant_slash_commands.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
context_server.workspace = true
|
||||
fs.workspace = true
|
||||
@@ -49,6 +48,7 @@ util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
indoc.workspace = true
|
||||
|
||||
@@ -11,7 +11,6 @@ use assistant_slash_command::{
|
||||
use assistant_slash_commands::FileCommandMetadata;
|
||||
use client::{self, Client, proto, telemetry::Telemetry};
|
||||
use clock::ReplicaId;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::{HashMap, HashSet};
|
||||
use fs::{Fs, RenameOptions};
|
||||
use futures::{FutureExt, StreamExt, future::Shared};
|
||||
@@ -47,6 +46,7 @@ use text::{BufferSnapshot, ToPoint};
|
||||
use ui::IconName;
|
||||
use util::{ResultExt, TryFutureExt, post_inc};
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
pub use crate::context_store::*;
|
||||
|
||||
|
||||
@@ -21,11 +21,9 @@ assistant_tool.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
derive_more.workspace = true
|
||||
diffy = "0.4.2"
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -65,6 +63,8 @@ web_search.workspace = true
|
||||
which.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
diffy = "0.4.2"
|
||||
|
||||
[dev-dependencies]
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -7,7 +7,6 @@ mod streaming_fuzzy_matcher;
|
||||
use crate::{Template, Templates};
|
||||
use anyhow::Result;
|
||||
use assistant_tool::ActionLog;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use create_file_parser::{CreateFileParser, CreateFileParserEvent};
|
||||
pub use edit_parser::EditFormat;
|
||||
use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
|
||||
@@ -30,6 +29,7 @@ use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::
|
||||
use streaming_diff::{CharOperation, StreamingDiff};
|
||||
use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
|
||||
use util::debug_panic;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct CreateFilePromptTemplate {
|
||||
|
||||
@@ -6,7 +6,6 @@ use anyhow::{Context as _, Result, anyhow};
|
||||
use assistant_tool::{
|
||||
ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
|
||||
};
|
||||
use cloud_llm_client::{WebSearchResponse, WebSearchResult};
|
||||
use futures::{Future, FutureExt, TryFutureExt};
|
||||
use gpui::{
|
||||
AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
|
||||
@@ -18,6 +17,7 @@ use serde::{Deserialize, Serialize};
|
||||
use ui::{IconName, Tooltip, prelude::*};
|
||||
use web_search::WebSearchRegistry;
|
||||
use workspace::Workspace;
|
||||
use zed_llm_client::{WebSearchResponse, WebSearchResult};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct WebSearchToolInput {
|
||||
|
||||
@@ -18,6 +18,6 @@ collections.workspace = true
|
||||
derive_more.workspace = true
|
||||
gpui.workspace = true
|
||||
parking_lot.workspace = true
|
||||
rodio = { version = "0.21.1", default-features = false, features = ["wav", "playback", "tracing"] }
|
||||
rodio = { version = "0.20.0", default-features = false, features = ["wav"] }
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
|
||||
@@ -3,9 +3,12 @@ use std::{io::Cursor, sync::Arc};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, AssetSource, Global};
|
||||
use rodio::{Decoder, Source, source::Buffered};
|
||||
use rodio::{
|
||||
Decoder, Source,
|
||||
source::{Buffered, SamplesConverter},
|
||||
};
|
||||
|
||||
type Sound = Buffered<Decoder<Cursor<Vec<u8>>>>;
|
||||
type Sound = Buffered<SamplesConverter<Decoder<Cursor<Vec<u8>>>, f32>>;
|
||||
|
||||
pub struct SoundRegistry {
|
||||
cache: Arc<parking_lot::Mutex<HashMap<String, Sound>>>,
|
||||
@@ -45,7 +48,7 @@ impl SoundRegistry {
|
||||
.with_context(|| format!("No asset available for path {path}"))??
|
||||
.into_owned();
|
||||
let cursor = Cursor::new(bytes);
|
||||
let source = Decoder::new(cursor)?.buffered();
|
||||
let source = Decoder::new(cursor)?.convert_samples::<f32>().buffered();
|
||||
|
||||
self.cache.lock().insert(name.to_string(), source.clone());
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use assets::SoundRegistry;
|
||||
use derive_more::{Deref, DerefMut};
|
||||
use gpui::{App, AssetSource, BorrowAppContext, Global};
|
||||
use rodio::{OutputStream, OutputStreamBuilder};
|
||||
use rodio::{OutputStream, OutputStreamHandle};
|
||||
use util::ResultExt;
|
||||
|
||||
mod assets;
|
||||
@@ -37,7 +37,8 @@ impl Sound {
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Audio {
|
||||
output_handle: Option<OutputStream>,
|
||||
_output_stream: Option<OutputStream>,
|
||||
output_handle: Option<OutputStreamHandle>,
|
||||
}
|
||||
|
||||
#[derive(Deref, DerefMut)]
|
||||
@@ -50,9 +51,11 @@ impl Audio {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn ensure_output_exists(&mut self) -> Option<&OutputStream> {
|
||||
fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> {
|
||||
if self.output_handle.is_none() {
|
||||
self.output_handle = OutputStreamBuilder::open_default_stream().log_err();
|
||||
let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip();
|
||||
self.output_handle = output_handle;
|
||||
self._output_stream = _output_stream;
|
||||
}
|
||||
|
||||
self.output_handle.as_ref()
|
||||
@@ -66,7 +69,7 @@ impl Audio {
|
||||
cx.update_global::<GlobalAudio, _>(|this, cx| {
|
||||
let output_handle = this.ensure_output_exists()?;
|
||||
let source = SoundRegistry::global(cx).get(sound.file()).log_err()?;
|
||||
output_handle.mixer().add(source);
|
||||
output_handle.play_raw(source).log_err()?;
|
||||
Some(())
|
||||
});
|
||||
}
|
||||
@@ -77,6 +80,7 @@ impl Audio {
|
||||
}
|
||||
|
||||
cx.update_global::<GlobalAudio, _>(|this, _| {
|
||||
this._output_stream.take();
|
||||
this.output_handle.take();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ async-tungstenite = { workspace = true, features = ["tokio", "tokio-rustls-manua
|
||||
base64.workspace = true
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
clock.workspace = true
|
||||
cloud_api_client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
derive_more.workspace = true
|
||||
@@ -35,8 +33,8 @@ http_client.workspace = true
|
||||
http_client_tls.workspace = true
|
||||
httparse = "1.10"
|
||||
log.workspace = true
|
||||
parking_lot.workspace = true
|
||||
paths.workspace = true
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
@@ -48,18 +46,19 @@ serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
sha2.workspace = true
|
||||
smol.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tiny_http.workspace = true
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
|
||||
tokio.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
worktree.workspace = true
|
||||
telemetry.workspace = true
|
||||
tokio.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub mod test;
|
||||
|
||||
mod cloud;
|
||||
mod proxy;
|
||||
pub mod telemetry;
|
||||
pub mod user;
|
||||
@@ -16,7 +15,6 @@ use async_tungstenite::tungstenite::{
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use clock::SystemClock;
|
||||
use cloud_api_client::CloudApiClient;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{
|
||||
AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt,
|
||||
@@ -33,6 +31,7 @@ use rpc::proto::{AnyTypedEnvelope, EnvelopedMessage, PeerId, RequestMessage};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources};
|
||||
use std::pin::Pin;
|
||||
use std::{
|
||||
any::TypeId,
|
||||
convert::TryFrom,
|
||||
@@ -46,14 +45,12 @@ use std::{
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{cmp, pin::Pin};
|
||||
use telemetry::Telemetry;
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use url::Url;
|
||||
use util::{ConnectionResult, ResultExt};
|
||||
|
||||
pub use cloud::*;
|
||||
pub use rpc::*;
|
||||
pub use telemetry_events::Event;
|
||||
pub use user::*;
|
||||
@@ -81,7 +78,7 @@ pub static ZED_ALWAYS_ACTIVE: LazyLock<bool> =
|
||||
LazyLock::new(|| std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| !e.is_empty()));
|
||||
|
||||
pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(500);
|
||||
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(30);
|
||||
pub const MAX_RECONNECTION_DELAY: Duration = Duration::from_secs(10);
|
||||
pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
|
||||
|
||||
actions!(
|
||||
@@ -216,7 +213,6 @@ pub struct Client {
|
||||
id: AtomicU64,
|
||||
peer: Arc<Peer>,
|
||||
http: Arc<HttpClientWithUrl>,
|
||||
cloud_client: Arc<CloudApiClient>,
|
||||
telemetry: Arc<Telemetry>,
|
||||
credentials_provider: ClientCredentialsProvider,
|
||||
state: RwLock<ClientState>,
|
||||
@@ -590,7 +586,6 @@ impl Client {
|
||||
id: AtomicU64::new(0),
|
||||
peer: Peer::new(0),
|
||||
telemetry: Telemetry::new(clock, http.clone(), cx),
|
||||
cloud_client: Arc::new(CloudApiClient::new(http.clone())),
|
||||
http,
|
||||
credentials_provider: ClientCredentialsProvider::new(cx),
|
||||
state: Default::default(),
|
||||
@@ -623,10 +618,6 @@ impl Client {
|
||||
self.http.clone()
|
||||
}
|
||||
|
||||
pub fn cloud_client(&self) -> Arc<CloudApiClient> {
|
||||
self.cloud_client.clone()
|
||||
}
|
||||
|
||||
pub fn set_id(&self, id: u64) -> &Self {
|
||||
self.id.store(id, Ordering::SeqCst);
|
||||
self
|
||||
@@ -736,10 +727,11 @@ impl Client {
|
||||
},
|
||||
&cx,
|
||||
);
|
||||
let jitter =
|
||||
Duration::from_millis(rng.gen_range(0..delay.as_millis() as u64));
|
||||
cx.background_executor().timer(delay + jitter).await;
|
||||
delay = cmp::min(delay * 2, MAX_RECONNECTION_DELAY);
|
||||
cx.background_executor().timer(delay).await;
|
||||
delay = delay
|
||||
.mul_f32(rng.gen_range(0.5..=2.5))
|
||||
.max(INITIAL_RECONNECTION_DELAY)
|
||||
.min(MAX_RECONNECTION_DELAY);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
@@ -939,8 +931,6 @@ impl Client {
|
||||
}
|
||||
let credentials = credentials.unwrap();
|
||||
self.set_id(credentials.user_id);
|
||||
self.cloud_client
|
||||
.set_credentials(credentials.user_id as u32, credentials.access_token.clone());
|
||||
|
||||
if was_disconnected {
|
||||
self.set_status(Status::Connecting, cx);
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
mod user_store;
|
||||
|
||||
pub use user_store::*;
|
||||
@@ -1,41 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use cloud_api_client::{AuthenticatedUser, CloudApiClient};
|
||||
use gpui::{Context, Task};
|
||||
use util::{ResultExt as _, maybe};
|
||||
|
||||
pub struct CloudUserStore {
|
||||
authenticated_user: Option<AuthenticatedUser>,
|
||||
_fetch_authenticated_user_task: Task<()>,
|
||||
}
|
||||
|
||||
impl CloudUserStore {
|
||||
pub fn new(cloud_client: Arc<CloudApiClient>, cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
authenticated_user: None,
|
||||
_fetch_authenticated_user_task: cx.spawn(async move |this, cx| {
|
||||
maybe!(async move {
|
||||
loop {
|
||||
if cloud_client.has_credentials() {
|
||||
break;
|
||||
}
|
||||
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(100))
|
||||
.await;
|
||||
}
|
||||
|
||||
let response = cloud_client.get_authenticated_user().await?;
|
||||
this.update(cx, |this, _cx| {
|
||||
this.authenticated_user = Some(response.user);
|
||||
})
|
||||
})
|
||||
.await
|
||||
.context("failed to fetch authenticated user")
|
||||
.log_err();
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,6 @@
|
||||
use super::{Client, Status, TypedEnvelope, proto};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_llm_client::{
|
||||
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
|
||||
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||
};
|
||||
use collections::{HashMap, HashSet, hash_map::Entry};
|
||||
use derive_more::Deref;
|
||||
use feature_flags::FeatureFlagAppExt;
|
||||
@@ -21,6 +17,10 @@ use std::{
|
||||
};
|
||||
use text::ReplicaId;
|
||||
use util::{TryFutureExt as _, maybe};
|
||||
use zed_llm_client::{
|
||||
EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME, EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME,
|
||||
MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
|
||||
};
|
||||
|
||||
pub type UserId = u64;
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_api_client"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "Apache-2.0"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_api_client.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cloud_api_types.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
parking_lot.workspace = true
|
||||
serde_json.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1,79 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
pub use cloud_api_types::*;
|
||||
use futures::AsyncReadExt as _;
|
||||
use http_client::{AsyncBody, HttpClientWithUrl, Method, Request};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
struct Credentials {
|
||||
user_id: u32,
|
||||
access_token: String,
|
||||
}
|
||||
|
||||
pub struct CloudApiClient {
|
||||
credentials: RwLock<Option<Credentials>>,
|
||||
http_client: Arc<HttpClientWithUrl>,
|
||||
}
|
||||
|
||||
impl CloudApiClient {
|
||||
pub fn new(http_client: Arc<HttpClientWithUrl>) -> Self {
|
||||
Self {
|
||||
credentials: RwLock::new(None),
|
||||
http_client,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_credentials(&self) -> bool {
|
||||
self.credentials.read().is_some()
|
||||
}
|
||||
|
||||
pub fn set_credentials(&self, user_id: u32, access_token: String) {
|
||||
*self.credentials.write() = Some(Credentials {
|
||||
user_id,
|
||||
access_token,
|
||||
});
|
||||
}
|
||||
|
||||
fn authorization_header(&self) -> Result<String> {
|
||||
let guard = self.credentials.read();
|
||||
let credentials = guard
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("No credentials provided"))?;
|
||||
|
||||
Ok(format!(
|
||||
"{} {}",
|
||||
credentials.user_id, credentials.access_token
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
|
||||
let request = Request::builder()
|
||||
.method(Method::GET)
|
||||
.uri(
|
||||
self.http_client
|
||||
.build_zed_cloud_url("/client/users/me", &[])?
|
||||
.as_ref(),
|
||||
)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", self.authorization_header()?)
|
||||
.body(AsyncBody::default())?;
|
||||
|
||||
let mut response = self.http_client.send(request).await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
|
||||
response.status()
|
||||
)
|
||||
}
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_api_types"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "Apache-2.0"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_api_types.rs"
|
||||
|
||||
[dependencies]
|
||||
serde.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1,14 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct GetAuthenticatedUserResponse {
|
||||
pub user: AuthenticatedUser,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AuthenticatedUser {
|
||||
pub id: i32,
|
||||
pub avatar_url: String,
|
||||
pub github_login: String,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_llm_client"
|
||||
version = "0.1.0"
|
||||
publish.workspace = true
|
||||
edition.workspace = true
|
||||
license = "Apache-2.0"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_llm_client.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
serde = { workspace = true, features = ["derive", "rc"] }
|
||||
serde_json.workspace = true
|
||||
strum = { workspace = true, features = ["derive"] }
|
||||
uuid = { workspace = true, features = ["serde"] }
|
||||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions.workspace = true
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1,370 +0,0 @@
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use strum::{Display, EnumIter, EnumString};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// The name of the header used to indicate which version of Zed the client is running.
|
||||
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
|
||||
|
||||
/// The name of the header used to indicate when a request failed due to an
|
||||
/// expired LLM token.
|
||||
///
|
||||
/// The client may use this as a signal to refresh the token.
|
||||
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
||||
|
||||
/// The name of the header used to indicate what plan the user is currently on.
|
||||
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
|
||||
|
||||
/// The name of the header used to indicate the usage limit for model requests.
|
||||
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";
|
||||
|
||||
/// The name of the header used to indicate the usage amount for model requests.
|
||||
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";
|
||||
|
||||
/// The name of the header used to indicate the usage limit for edit predictions.
|
||||
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";
|
||||
|
||||
/// The name of the header used to indicate the usage amount for edit predictions.
|
||||
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
|
||||
|
||||
/// The name of the header used to indicate the resource for which the subscription limit has been reached.
|
||||
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";
|
||||
|
||||
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";
|
||||
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
|
||||
|
||||
/// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached.
|
||||
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
|
||||
|
||||
/// The name of the header used to indicate the the minimum required Zed version.
|
||||
///
|
||||
/// This can be used to force a Zed upgrade in order to continue communicating
|
||||
/// with the LLM service.
|
||||
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
|
||||
|
||||
/// The name of the header used by the client to indicate to the server that it supports receiving status messages.
|
||||
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
|
||||
"x-zed-client-supports-status-messages";
|
||||
|
||||
/// The name of the header used by the server to indicate to the client that it supports sending status messages.
|
||||
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
|
||||
"x-zed-server-supports-status-messages";
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum UsageLimit {
|
||||
Limited(i32),
|
||||
Unlimited,
|
||||
}
|
||||
|
||||
impl FromStr for UsageLimit {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value {
|
||||
"unlimited" => Ok(Self::Unlimited),
|
||||
limit => limit
|
||||
.parse::<i32>()
|
||||
.map(Self::Limited)
|
||||
.context("failed to parse limit"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Plan {
|
||||
#[default]
|
||||
#[serde(alias = "Free")]
|
||||
ZedFree,
|
||||
#[serde(alias = "ZedPro")]
|
||||
ZedPro,
|
||||
#[serde(alias = "ZedProTrial")]
|
||||
ZedProTrial,
|
||||
}
|
||||
|
||||
impl Plan {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Plan::ZedFree => "zed_free",
|
||||
Plan::ZedPro => "zed_pro",
|
||||
Plan::ZedProTrial => "zed_pro_trial",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn model_requests_limit(&self) -> UsageLimit {
|
||||
match self {
|
||||
Plan::ZedPro => UsageLimit::Limited(500),
|
||||
Plan::ZedProTrial => UsageLimit::Limited(150),
|
||||
Plan::ZedFree => UsageLimit::Limited(50),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edit_predictions_limit(&self) -> UsageLimit {
|
||||
match self {
|
||||
Plan::ZedPro => UsageLimit::Unlimited,
|
||||
Plan::ZedProTrial => UsageLimit::Unlimited,
|
||||
Plan::ZedFree => UsageLimit::Limited(2_000),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Plan {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
||||
match value {
|
||||
"zed_free" => Ok(Plan::ZedFree),
|
||||
"zed_pro" => Ok(Plan::ZedPro),
|
||||
"zed_pro_trial" => Ok(Plan::ZedProTrial),
|
||||
plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum LanguageModelProvider {
|
||||
Anthropic,
|
||||
OpenAi,
|
||||
Google,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictEditsBody {
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub outline: Option<String>,
|
||||
pub input_events: String,
|
||||
pub input_excerpt: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub speculated_output: Option<String>,
|
||||
/// Whether the user provided consent for sampling this interaction.
|
||||
#[serde(default, alias = "data_collection_permission")]
|
||||
pub can_collect_data: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PredictEditsResponse {
|
||||
pub request_id: Uuid,
|
||||
pub output_excerpt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AcceptEditPredictionBody {
|
||||
pub request_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionMode {
|
||||
Normal,
|
||||
Max,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionIntent {
|
||||
UserPrompt,
|
||||
ToolResults,
|
||||
ThreadSummarization,
|
||||
ThreadContextSummarization,
|
||||
CreateFile,
|
||||
EditFile,
|
||||
InlineAssist,
|
||||
TerminalInlineAssist,
|
||||
GenerateGitCommitMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CompletionBody {
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub thread_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub prompt_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub intent: Option<CompletionIntent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none", default)]
|
||||
pub mode: Option<CompletionMode>,
|
||||
pub provider: LanguageModelProvider,
|
||||
pub model: String,
|
||||
pub provider_request: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionRequestStatus {
|
||||
Queued {
|
||||
position: usize,
|
||||
},
|
||||
Started,
|
||||
Failed {
|
||||
code: String,
|
||||
message: String,
|
||||
request_id: Uuid,
|
||||
/// Retry duration in seconds.
|
||||
retry_after: Option<f64>,
|
||||
},
|
||||
UsageUpdated {
|
||||
amount: usize,
|
||||
limit: UsageLimit,
|
||||
},
|
||||
ToolUseLimitReached,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CompletionEvent<T> {
|
||||
Status(CompletionRequestStatus),
|
||||
Event(T),
|
||||
}
|
||||
|
||||
impl<T> CompletionEvent<T> {
|
||||
pub fn into_status(self) -> Option<CompletionRequestStatus> {
|
||||
match self {
|
||||
Self::Status(status) => Some(status),
|
||||
Self::Event(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_event(self) -> Option<T> {
|
||||
match self {
|
||||
Self::Event(event) => Some(event),
|
||||
Self::Status(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct WebSearchBody {
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct WebSearchResponse {
|
||||
pub results: Vec<WebSearchResult>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct WebSearchResult {
|
||||
pub title: String,
|
||||
pub url: String,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CountTokensBody {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub model: String,
|
||||
pub provider_request: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CountTokensResponse {
|
||||
pub tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageModelId(pub Arc<str>);
|
||||
|
||||
impl std::fmt::Display for LanguageModelId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct LanguageModel {
|
||||
pub provider: LanguageModelProvider,
|
||||
pub id: LanguageModelId,
|
||||
pub display_name: String,
|
||||
pub max_token_count: usize,
|
||||
pub max_token_count_in_max_mode: Option<usize>,
|
||||
pub max_output_tokens: usize,
|
||||
pub supports_tools: bool,
|
||||
pub supports_images: bool,
|
||||
pub supports_thinking: bool,
|
||||
pub supports_max_mode: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ListModelsResponse {
|
||||
pub models: Vec<LanguageModel>,
|
||||
pub default_model: LanguageModelId,
|
||||
pub default_fast_model: LanguageModelId,
|
||||
pub recommended_models: Vec<LanguageModelId>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GetSubscriptionResponse {
|
||||
pub plan: Plan,
|
||||
pub usage: Option<CurrentUsage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CurrentUsage {
|
||||
pub model_requests: UsageData,
|
||||
pub edit_predictions: UsageData,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UsageData {
|
||||
pub used: u32,
|
||||
pub limit: UsageLimit,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_plan_deserialize_snake_case() {
|
||||
let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedFree);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedPro);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedProTrial);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plan_deserialize_aliases() {
|
||||
let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedFree);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedPro);
|
||||
|
||||
let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
|
||||
assert_eq!(plan, Plan::ZedProTrial);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_usage_limit_from_str() {
|
||||
let limit = UsageLimit::from_str("unlimited").unwrap();
|
||||
assert!(matches!(limit, UsageLimit::Unlimited));
|
||||
|
||||
let limit = UsageLimit::from_str(&0.to_string()).unwrap();
|
||||
assert!(matches!(limit, UsageLimit::Limited(0)));
|
||||
|
||||
let limit = UsageLimit::from_str(&50.to_string()).unwrap();
|
||||
assert!(matches!(limit, UsageLimit::Limited(50)));
|
||||
|
||||
for value in ["not_a_number", "50xyz"] {
|
||||
let limit = UsageLimit::from_str(value);
|
||||
assert!(limit.is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,14 +23,13 @@ async-stripe.workspace = true
|
||||
async-trait.workspace = true
|
||||
async-tungstenite.workspace = true
|
||||
aws-config = { version = "1.1.5" }
|
||||
aws-sdk-kinesis = "1.51.0"
|
||||
aws-sdk-s3 = { version = "1.15.0" }
|
||||
aws-sdk-kinesis = "1.51.0"
|
||||
axum = { version = "0.6", features = ["json", "headers", "ws"] }
|
||||
axum-extra = { version = "0.4", features = ["erased-json"] }
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
clock.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
dashmap.workspace = true
|
||||
derive_more.workspace = true
|
||||
@@ -76,6 +75,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "re
|
||||
util.workspace = true
|
||||
uuid.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
agent_settings.workspace = true
|
||||
|
||||
@@ -100,7 +100,7 @@ impl std::fmt::Display for SystemIdHeader {
|
||||
|
||||
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
|
||||
Router::new()
|
||||
.route("/user", get(legacy_update_or_create_authenticated_user))
|
||||
.route("/user", get(update_or_create_authenticated_user))
|
||||
.route("/users/look_up", get(look_up_user))
|
||||
.route("/users/:id/access_tokens", post(create_access_token))
|
||||
.route("/users/:id/refresh_llm_tokens", post(refresh_llm_tokens))
|
||||
@@ -161,10 +161,7 @@ struct AuthenticatedUserResponse {
|
||||
feature_flags: Vec<String>,
|
||||
}
|
||||
|
||||
/// This is a legacy endpoint that is no longer used in production.
|
||||
///
|
||||
/// It currently only exists to be used when developing Collab locally.
|
||||
async fn legacy_update_or_create_authenticated_user(
|
||||
async fn update_or_create_authenticated_user(
|
||||
Query(params): Query<AuthenticatedUserParams>,
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
) -> Result<Json<AuthenticatedUserResponse>> {
|
||||
@@ -356,9 +353,9 @@ async fn refresh_llm_tokens(
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UpdatePlanBody {
|
||||
pub plan: cloud_llm_client::Plan,
|
||||
pub plan: zed_llm_client::Plan,
|
||||
pub subscription_period: SubscriptionPeriod,
|
||||
pub usage: cloud_llm_client::CurrentUsage,
|
||||
pub usage: zed_llm_client::CurrentUsage,
|
||||
pub trial_started_at: Option<DateTime<Utc>>,
|
||||
pub is_usage_based_billing_enabled: bool,
|
||||
pub is_account_too_young: bool,
|
||||
@@ -380,9 +377,9 @@ async fn update_plan(
|
||||
extract::Json(body): extract::Json<UpdatePlanBody>,
|
||||
) -> Result<Json<UpdatePlanResponse>> {
|
||||
let plan = match body.plan {
|
||||
cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
|
||||
cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
|
||||
cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
zed_llm_client::Plan::ZedFree => proto::Plan::Free,
|
||||
zed_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
|
||||
zed_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
let update_user_plan = proto::UpdateUserPlan {
|
||||
@@ -414,15 +411,15 @@ async fn update_plan(
|
||||
Ok(Json(UpdatePlanResponse {}))
|
||||
}
|
||||
|
||||
fn usage_limit_to_proto(limit: cloud_llm_client::UsageLimit) -> proto::UsageLimit {
|
||||
fn usage_limit_to_proto(limit: zed_llm_client::UsageLimit) -> proto::UsageLimit {
|
||||
proto::UsageLimit {
|
||||
variant: Some(match limit {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use anyhow::{Context as _, bail};
|
||||
use chrono::{DateTime, Utc};
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use collections::{HashMap, HashSet};
|
||||
use sea_orm::ActiveValue;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use stripe::{CancellationDetailsReason, EventObject, EventType, ListEvents, SubscriptionStatus};
|
||||
use util::{ResultExt, maybe};
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
use crate::AppState;
|
||||
use crate::db::billing_subscription::{
|
||||
@@ -87,14 +87,6 @@ async fn poll_stripe_events(
|
||||
stripe_client: &Arc<dyn StripeClient>,
|
||||
real_stripe_client: &stripe::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
let sync_events_using_cloud = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag.flag == "cloud-stripe-events-polling" && flag.enabled_for_all);
|
||||
if sync_events_using_cloud {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn event_type_to_string(event_type: EventType) -> String {
|
||||
// Calling `to_string` on `stripe::EventType` members gives us a quoted string,
|
||||
// so we need to unquote it.
|
||||
@@ -577,14 +569,6 @@ async fn sync_model_request_usage_with_stripe(
|
||||
llm_db: &Arc<LlmDatabase>,
|
||||
stripe_billing: &Arc<StripeBilling>,
|
||||
) -> anyhow::Result<()> {
|
||||
let feature_flags = app.db.list_feature_flags().await?;
|
||||
let sync_model_request_usage_using_cloud = feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag.flag == "cloud-stripe-usage-meters-sync" && flag.enabled_for_all);
|
||||
if sync_model_request_usage_using_cloud {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
log::info!("Stripe usage sync: Starting");
|
||||
let started_at = Utc::now();
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use axum::{
|
||||
use chrono::{NaiveDateTime, SecondsFormat};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::AuthenticatedUserParams;
|
||||
use crate::db::ContributorSelector;
|
||||
use crate::{AppState, Result};
|
||||
|
||||
@@ -103,18 +104,9 @@ impl RenovateBot {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AddContributorBody {
|
||||
github_user_id: i32,
|
||||
github_login: String,
|
||||
github_email: Option<String>,
|
||||
github_name: Option<String>,
|
||||
github_user_created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
async fn add_contributor(
|
||||
Extension(app): Extension<Arc<AppState>>,
|
||||
extract::Json(params): extract::Json<AddContributorBody>,
|
||||
extract::Json(params): extract::Json<AuthenticatedUserParams>,
|
||||
) -> Result<()> {
|
||||
let initial_channel_id = app.config.auto_join_channel_id;
|
||||
app.db
|
||||
|
||||
@@ -95,7 +95,7 @@ pub enum SubscriptionKind {
|
||||
ZedFree,
|
||||
}
|
||||
|
||||
impl From<SubscriptionKind> for cloud_llm_client::Plan {
|
||||
impl From<SubscriptionKind> for zed_llm_client::Plan {
|
||||
fn from(value: SubscriptionKind) -> Self {
|
||||
match value {
|
||||
SubscriptionKind::ZedPro => Self::ZedPro,
|
||||
|
||||
@@ -6,11 +6,11 @@ mod tables;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use collections::HashMap;
|
||||
pub use ids::*;
|
||||
pub use seed::*;
|
||||
pub use tables::*;
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use tests::TestLlmDb;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use cloud_llm_client::LanguageModelProvider;
|
||||
use pretty_assertions::assert_eq;
|
||||
use zed_llm_client::LanguageModelProvider;
|
||||
|
||||
use crate::llm::db::LlmDatabase;
|
||||
use crate::test_llm_db;
|
||||
|
||||
@@ -4,12 +4,12 @@ use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEA
|
||||
use crate::{Config, db::billing_preference};
|
||||
use anyhow::{Context as _, Result};
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use cloud_llm_client::Plan;
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
use zed_llm_client::Plan;
|
||||
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
||||
@@ -2868,12 +2868,12 @@ async fn make_update_user_plan_message(
|
||||
}
|
||||
|
||||
fn model_requests_limit(
|
||||
plan: cloud_llm_client::Plan,
|
||||
plan: zed_llm_client::Plan,
|
||||
feature_flags: &Vec<String>,
|
||||
) -> cloud_llm_client::UsageLimit {
|
||||
) -> zed_llm_client::UsageLimit {
|
||||
match plan.model_requests_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == cloud_llm_client::Plan::ZedProTrial
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
let limit = if plan == zed_llm_client::Plan::ZedProTrial
|
||||
&& feature_flags
|
||||
.iter()
|
||||
.any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
|
||||
@@ -2883,9 +2883,9 @@ fn model_requests_limit(
|
||||
limit
|
||||
};
|
||||
|
||||
cloud_llm_client::UsageLimit::Limited(limit)
|
||||
zed_llm_client::UsageLimit::Limited(limit)
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
|
||||
zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2895,21 +2895,21 @@ fn subscription_usage_to_proto(
|
||||
feature_flags: &Vec<String>,
|
||||
) -> proto::SubscriptionUsage {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: usage.model_requests as u32,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit(plan, feature_flags) {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
@@ -2917,12 +2917,12 @@ fn subscription_usage_to_proto(
|
||||
edit_predictions_usage_amount: usage.edit_predictions as u32,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
@@ -2935,21 +2935,21 @@ fn make_default_subscription_usage(
|
||||
feature_flags: &Vec<String>,
|
||||
) -> proto::SubscriptionUsage {
|
||||
let plan = match plan {
|
||||
proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
|
||||
proto::Plan::Free => zed_llm_client::Plan::ZedFree,
|
||||
proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
|
||||
proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
|
||||
};
|
||||
|
||||
proto::SubscriptionUsage {
|
||||
model_requests_usage_amount: 0,
|
||||
model_requests_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match model_requests_limit(plan, feature_flags) {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
@@ -2957,12 +2957,12 @@ fn make_default_subscription_usage(
|
||||
edit_predictions_usage_amount: 0,
|
||||
edit_predictions_usage_limit: Some(proto::UsageLimit {
|
||||
variant: Some(match plan.edit_predictions_limit() {
|
||||
cloud_llm_client::UsageLimit::Limited(limit) => {
|
||||
zed_llm_client::UsageLimit::Limited(limit) => {
|
||||
proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
|
||||
limit: limit as u32,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::UsageLimit::Unlimited => {
|
||||
zed_llm_client::UsageLimit::Unlimited => {
|
||||
proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -842,7 +842,7 @@ async fn test_client_disconnecting_from_room(
|
||||
|
||||
// Allow user A to reconnect to the server.
|
||||
server.allow_connections();
|
||||
executor.advance_clock(RECONNECT_TIMEOUT);
|
||||
executor.advance_clock(RECEIVE_TIMEOUT);
|
||||
|
||||
// Call user B again from client A.
|
||||
active_call_a
|
||||
@@ -1358,7 +1358,7 @@ async fn test_calls_on_multiple_connections(
|
||||
|
||||
// User A reconnects automatically, then calls user B again.
|
||||
server.allow_connections();
|
||||
executor.advance_clock(RECONNECT_TIMEOUT);
|
||||
executor.advance_clock(RECEIVE_TIMEOUT);
|
||||
active_call_a
|
||||
.update(cx_a, |call, cx| {
|
||||
call.invite(client_b1.user_id().unwrap(), None, cx)
|
||||
|
||||
@@ -8,7 +8,6 @@ use crate::{
|
||||
use anyhow::anyhow;
|
||||
use call::ActiveCall;
|
||||
use channel::{ChannelBuffer, ChannelStore};
|
||||
use client::CloudUserStore;
|
||||
use client::{
|
||||
self, ChannelId, Client, Connection, Credentials, EstablishConnectionError, UserStore,
|
||||
proto::PeerId,
|
||||
@@ -282,14 +281,12 @@ impl TestServer {
|
||||
.register_hosting_provider(Arc::new(git_hosting_providers::Github::public_instance()));
|
||||
|
||||
let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
|
||||
let cloud_user_store = cx.new(|cx| CloudUserStore::new(client.cloud_client(), cx));
|
||||
let workspace_store = cx.new(|cx| WorkspaceStore::new(client.clone(), cx));
|
||||
let language_registry = Arc::new(LanguageRegistry::test(cx.executor()));
|
||||
let session = cx.new(|cx| AppSession::new(Session::test(), cx));
|
||||
let app_state = Arc::new(workspace::AppState {
|
||||
client: client.clone(),
|
||||
user_store: user_store.clone(),
|
||||
cloud_user_store,
|
||||
workspace_store,
|
||||
languages: language_registry,
|
||||
fs: fs.clone(),
|
||||
|
||||
@@ -7,17 +7,17 @@ license = "GPL-3.0-or-later"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
command_palette.workspace = true
|
||||
gpui.workspace = true
|
||||
clap.workspace = true
|
||||
mdbook = "0.4.40"
|
||||
regex.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
regex.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed.workspace = true
|
||||
zlog.workspace = true
|
||||
gpui.workspace = true
|
||||
command_palette.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::Result;
|
||||
use clap::{Arg, ArgMatches, Command};
|
||||
use mdbook::BookItem;
|
||||
use mdbook::book::{Book, Chapter};
|
||||
use mdbook::preprocess::CmdPreprocessor;
|
||||
use regex::Regex;
|
||||
use settings::KeymapFile;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashSet;
|
||||
use std::io::{self, Read};
|
||||
use std::process;
|
||||
use std::sync::LazyLock;
|
||||
use util::paths::PathExt;
|
||||
|
||||
static KEYMAP_MACOS: LazyLock<KeymapFile> = LazyLock::new(|| {
|
||||
load_keymap("keymaps/default-macos.json").expect("Failed to load MacOS keymap")
|
||||
@@ -21,68 +20,60 @@ static KEYMAP_LINUX: LazyLock<KeymapFile> = LazyLock::new(|| {
|
||||
|
||||
static ALL_ACTIONS: LazyLock<Vec<ActionDef>> = LazyLock::new(dump_all_gpui_actions);
|
||||
|
||||
const FRONT_MATTER_COMMENT: &'static str = "<!-- ZED_META {} -->";
|
||||
pub fn make_app() -> Command {
|
||||
Command::new("zed-docs-preprocessor")
|
||||
.about("Preprocesses Zed Docs content to provide rich action & keybinding support and more")
|
||||
.subcommand(
|
||||
Command::new("supports")
|
||||
.arg(Arg::new("renderer").required(true))
|
||||
.about("Check whether a renderer is supported by this preprocessor"),
|
||||
)
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
zlog::init();
|
||||
zlog::init_output_stderr();
|
||||
let matches = make_app().get_matches();
|
||||
// call a zed:: function so everything in `zed` crate is linked and
|
||||
// all actions in the actual app are registered
|
||||
zed::stdout_is_a_pty();
|
||||
let args = std::env::args().skip(1).collect::<Vec<_>>();
|
||||
|
||||
match args.get(0).map(String::as_str) {
|
||||
Some("supports") => {
|
||||
let renderer = args.get(1).expect("Required argument");
|
||||
let supported = renderer != "not-supported";
|
||||
if supported {
|
||||
process::exit(0);
|
||||
} else {
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
Some("postprocess") => handle_postprocessing()?,
|
||||
_ => handle_preprocessing()?,
|
||||
if let Some(sub_args) = matches.subcommand_matches("supports") {
|
||||
handle_supports(sub_args);
|
||||
} else {
|
||||
handle_preprocessing()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
enum PreprocessorError {
|
||||
enum Error {
|
||||
ActionNotFound { action_name: String },
|
||||
DeprecatedActionUsed { used: String, should_be: String },
|
||||
InvalidFrontmatterLine(String),
|
||||
}
|
||||
|
||||
impl PreprocessorError {
|
||||
impl Error {
|
||||
fn new_for_not_found_action(action_name: String) -> Self {
|
||||
for action in &*ALL_ACTIONS {
|
||||
for alias in action.deprecated_aliases {
|
||||
if alias == &action_name {
|
||||
return PreprocessorError::DeprecatedActionUsed {
|
||||
return Error::DeprecatedActionUsed {
|
||||
used: action_name.clone(),
|
||||
should_be: action.name.to_string(),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
PreprocessorError::ActionNotFound {
|
||||
Error::ActionNotFound {
|
||||
action_name: action_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PreprocessorError {
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PreprocessorError::InvalidFrontmatterLine(line) => {
|
||||
write!(f, "Invalid frontmatter line: {}", line)
|
||||
}
|
||||
PreprocessorError::ActionNotFound { action_name } => {
|
||||
write!(f, "Action not found: {}", action_name)
|
||||
}
|
||||
PreprocessorError::DeprecatedActionUsed { used, should_be } => write!(
|
||||
Error::ActionNotFound { action_name } => write!(f, "Action not found: {}", action_name),
|
||||
Error::DeprecatedActionUsed { used, should_be } => write!(
|
||||
f,
|
||||
"Deprecated action used: {} should be {}",
|
||||
used, should_be
|
||||
@@ -98,9 +89,8 @@ fn handle_preprocessing() -> Result<()> {
|
||||
|
||||
let (_ctx, mut book) = CmdPreprocessor::parse_input(input.as_bytes())?;
|
||||
|
||||
let mut errors = HashSet::<PreprocessorError>::new();
|
||||
let mut errors = HashSet::<Error>::new();
|
||||
|
||||
handle_frontmatter(&mut book, &mut errors);
|
||||
template_and_validate_keybindings(&mut book, &mut errors);
|
||||
template_and_validate_actions(&mut book, &mut errors);
|
||||
|
||||
@@ -118,41 +108,19 @@ fn handle_preprocessing() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_frontmatter(book: &mut Book, errors: &mut HashSet<PreprocessorError>) {
|
||||
let frontmatter_regex = Regex::new(r"(?s)^\s*---(.*?)---").unwrap();
|
||||
for_each_chapter_mut(book, |chapter| {
|
||||
let new_content = frontmatter_regex.replace(&chapter.content, |caps: ®ex::Captures| {
|
||||
let frontmatter = caps[1].trim();
|
||||
let frontmatter = frontmatter.trim_matches(&[' ', '-', '\n']);
|
||||
let mut metadata = HashMap::<String, String>::default();
|
||||
for line in frontmatter.lines() {
|
||||
let Some((name, value)) = line.split_once(':') else {
|
||||
errors.insert(PreprocessorError::InvalidFrontmatterLine(format!(
|
||||
"{}: {}",
|
||||
chapter_breadcrumbs(&chapter),
|
||||
line
|
||||
)));
|
||||
continue;
|
||||
};
|
||||
let name = name.trim();
|
||||
let value = value.trim();
|
||||
metadata.insert(name.to_string(), value.to_string());
|
||||
}
|
||||
FRONT_MATTER_COMMENT.replace(
|
||||
"{}",
|
||||
&serde_json::to_string(&metadata).expect("Failed to serialize metadata"),
|
||||
)
|
||||
});
|
||||
match new_content {
|
||||
Cow::Owned(content) => {
|
||||
chapter.content = content;
|
||||
}
|
||||
Cow::Borrowed(_) => {}
|
||||
}
|
||||
});
|
||||
fn handle_supports(sub_args: &ArgMatches) -> ! {
|
||||
let renderer = sub_args
|
||||
.get_one::<String>("renderer")
|
||||
.expect("Required argument");
|
||||
let supported = renderer != "not-supported";
|
||||
if supported {
|
||||
process::exit(0);
|
||||
} else {
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<PreprocessorError>) {
|
||||
fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Error>) {
|
||||
let regex = Regex::new(r"\{#kb (.*?)\}").unwrap();
|
||||
|
||||
for_each_chapter_mut(book, |chapter| {
|
||||
@@ -160,9 +128,7 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Prepr
|
||||
.replace_all(&chapter.content, |caps: ®ex::Captures| {
|
||||
let action = caps[1].trim();
|
||||
if find_action_by_name(action).is_none() {
|
||||
errors.insert(PreprocessorError::new_for_not_found_action(
|
||||
action.to_string(),
|
||||
));
|
||||
errors.insert(Error::new_for_not_found_action(action.to_string()));
|
||||
return String::new();
|
||||
}
|
||||
let macos_binding = find_binding("macos", action).unwrap_or_default();
|
||||
@@ -178,7 +144,7 @@ fn template_and_validate_keybindings(book: &mut Book, errors: &mut HashSet<Prepr
|
||||
});
|
||||
}
|
||||
|
||||
fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet<PreprocessorError>) {
|
||||
fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet<Error>) {
|
||||
let regex = Regex::new(r"\{#action (.*?)\}").unwrap();
|
||||
|
||||
for_each_chapter_mut(book, |chapter| {
|
||||
@@ -186,9 +152,7 @@ fn template_and_validate_actions(book: &mut Book, errors: &mut HashSet<Preproces
|
||||
.replace_all(&chapter.content, |caps: ®ex::Captures| {
|
||||
let name = caps[1].trim();
|
||||
let Some(action) = find_action_by_name(name) else {
|
||||
errors.insert(PreprocessorError::new_for_not_found_action(
|
||||
name.to_string(),
|
||||
));
|
||||
errors.insert(Error::new_for_not_found_action(name.to_string()));
|
||||
return String::new();
|
||||
};
|
||||
format!("<code class=\"hljs\">{}</code>", &action.human_name)
|
||||
@@ -253,13 +217,6 @@ fn name_for_action(action_as_str: String) -> String {
|
||||
.unwrap_or(action_as_str)
|
||||
}
|
||||
|
||||
fn chapter_breadcrumbs(chapter: &Chapter) -> String {
|
||||
let mut breadcrumbs = Vec::with_capacity(chapter.parent_names.len() + 1);
|
||||
breadcrumbs.extend(chapter.parent_names.iter().map(String::as_str));
|
||||
breadcrumbs.push(chapter.name.as_str());
|
||||
format!("[{:?}] {}", chapter.source_path, breadcrumbs.join(" > "))
|
||||
}
|
||||
|
||||
fn load_keymap(asset_path: &str) -> Result<KeymapFile> {
|
||||
let content = util::asset_str::<settings::SettingsAssets>(asset_path);
|
||||
KeymapFile::parse(content.as_ref())
|
||||
@@ -297,126 +254,3 @@ fn dump_all_gpui_actions() -> Vec<ActionDef> {
|
||||
|
||||
return actions;
|
||||
}
|
||||
|
||||
fn handle_postprocessing() -> Result<()> {
|
||||
let logger = zlog::scoped!("render");
|
||||
let mut ctx = mdbook::renderer::RenderContext::from_json(io::stdin())?;
|
||||
let output = ctx
|
||||
.config
|
||||
.get_mut("output")
|
||||
.expect("has output")
|
||||
.as_table_mut()
|
||||
.expect("output is table");
|
||||
let zed_html = output.remove("zed-html").expect("zed-html output defined");
|
||||
let default_description = zed_html
|
||||
.get("default-description")
|
||||
.expect("Default description not found")
|
||||
.as_str()
|
||||
.expect("Default description not a string")
|
||||
.to_string();
|
||||
let default_title = zed_html
|
||||
.get("default-title")
|
||||
.expect("Default title not found")
|
||||
.as_str()
|
||||
.expect("Default title not a string")
|
||||
.to_string();
|
||||
|
||||
output.insert("html".to_string(), zed_html);
|
||||
mdbook::Renderer::render(&mdbook::renderer::HtmlHandlebars::new(), &ctx)?;
|
||||
let ignore_list = ["toc.html"];
|
||||
|
||||
let root_dir = ctx.destination.clone();
|
||||
let mut files = Vec::with_capacity(128);
|
||||
let mut queue = Vec::with_capacity(64);
|
||||
queue.push(root_dir.clone());
|
||||
while let Some(dir) = queue.pop() {
|
||||
for entry in std::fs::read_dir(&dir).context(dir.to_sanitized_string())? {
|
||||
let Ok(entry) = entry else {
|
||||
continue;
|
||||
};
|
||||
let file_type = entry.file_type().context("Failed to determine file type")?;
|
||||
if file_type.is_dir() {
|
||||
queue.push(entry.path());
|
||||
}
|
||||
if file_type.is_file()
|
||||
&& matches!(
|
||||
entry.path().extension().and_then(std::ffi::OsStr::to_str),
|
||||
Some("html")
|
||||
)
|
||||
{
|
||||
if ignore_list.contains(&&*entry.file_name().to_string_lossy()) {
|
||||
zlog::info!(logger => "Ignoring {}", entry.path().to_string_lossy());
|
||||
} else {
|
||||
files.push(entry.path());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
zlog::info!(logger => "Processing {} `.html` files", files.len());
|
||||
let meta_regex = Regex::new(&FRONT_MATTER_COMMENT.replace("{}", "(.*)")).unwrap();
|
||||
for file in files {
|
||||
let contents = std::fs::read_to_string(&file)?;
|
||||
let mut meta_description = None;
|
||||
let mut meta_title = None;
|
||||
let contents = meta_regex.replace(&contents, |caps: ®ex::Captures| {
|
||||
let metadata: HashMap<String, String> = serde_json::from_str(&caps[1]).with_context(|| format!("JSON Metadata: {:?}", &caps[1])).expect("Failed to deserialize metadata");
|
||||
for (kind, content) in metadata {
|
||||
match kind.as_str() {
|
||||
"description" => {
|
||||
meta_description = Some(content);
|
||||
}
|
||||
"title" => {
|
||||
meta_title = Some(content);
|
||||
}
|
||||
_ => {
|
||||
zlog::warn!(logger => "Unrecognized frontmatter key: {} in {:?}", kind, pretty_path(&file, &root_dir));
|
||||
}
|
||||
}
|
||||
}
|
||||
String::new()
|
||||
});
|
||||
let meta_description = meta_description.as_ref().unwrap_or_else(|| {
|
||||
zlog::warn!(logger => "No meta description found for {:?}", pretty_path(&file, &root_dir));
|
||||
&default_description
|
||||
});
|
||||
let page_title = extract_title_from_page(&contents, pretty_path(&file, &root_dir));
|
||||
let meta_title = meta_title.as_ref().unwrap_or_else(|| {
|
||||
zlog::debug!(logger => "No meta title found for {:?}", pretty_path(&file, &root_dir));
|
||||
&default_title
|
||||
});
|
||||
let meta_title = format!("{} | {}", page_title, meta_title);
|
||||
zlog::trace!(logger => "Updating {:?}", pretty_path(&file, &root_dir));
|
||||
let contents = contents.replace("#description#", meta_description);
|
||||
let contents = TITLE_REGEX
|
||||
.replace(&contents, |_: ®ex::Captures| {
|
||||
format!("<title>{}</title>", meta_title)
|
||||
})
|
||||
.to_string();
|
||||
// let contents = contents.replace("#title#", &meta_title);
|
||||
std::fs::write(file, contents)?;
|
||||
}
|
||||
return Ok(());
|
||||
|
||||
fn pretty_path<'a>(
|
||||
path: &'a std::path::PathBuf,
|
||||
root: &'a std::path::PathBuf,
|
||||
) -> &'a std::path::Path {
|
||||
&path.strip_prefix(&root).unwrap_or(&path)
|
||||
}
|
||||
const TITLE_REGEX: std::cell::LazyCell<Regex> =
|
||||
std::cell::LazyCell::new(|| Regex::new(r"<title>\s*(.*?)\s*</title>").unwrap());
|
||||
fn extract_title_from_page(contents: &str, pretty_path: &std::path::Path) -> String {
|
||||
let title_tag_contents = &TITLE_REGEX
|
||||
.captures(&contents)
|
||||
.with_context(|| format!("Failed to find title in {:?}", pretty_path))
|
||||
.expect("Page has <title> element")[1];
|
||||
let title = title_tag_contents
|
||||
.trim()
|
||||
.strip_suffix("- Zed")
|
||||
.unwrap_or(title_tag_contents)
|
||||
.trim()
|
||||
.to_string();
|
||||
title
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ use aho_corasick::AhoCorasick;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use blink_manager::BlinkManager;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use client::{Collaborator, DisableAiSettings, ParticipantIndex};
|
||||
use client::{Collaborator, ParticipantIndex};
|
||||
use clock::{AGENT_REPLICA_ID, ReplicaId};
|
||||
use collections::{BTreeMap, HashMap, HashSet, VecDeque};
|
||||
use convert_case::{Case, Casing};
|
||||
@@ -65,7 +65,7 @@ use display_map::*;
|
||||
pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder};
|
||||
pub use editor_settings::{
|
||||
CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode,
|
||||
ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, ShowScrollbar,
|
||||
ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowScrollbar,
|
||||
};
|
||||
use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings};
|
||||
pub use editor_settings_controls::*;
|
||||
@@ -7048,7 +7048,7 @@ impl Editor {
|
||||
}
|
||||
|
||||
pub fn update_edit_prediction_settings(&mut self, cx: &mut Context<Self>) {
|
||||
if self.edit_prediction_provider.is_none() || DisableAiSettings::get_global(cx).disable_ai {
|
||||
if self.edit_prediction_provider.is_none() {
|
||||
self.edit_prediction_settings = EditPredictionSettings::Disabled;
|
||||
} else {
|
||||
let selection = self.selections.newest_anchor();
|
||||
|
||||
@@ -19,8 +19,8 @@ path = "src/explorer.rs"
|
||||
|
||||
[dependencies]
|
||||
agent.workspace = true
|
||||
agent_settings.workspace = true
|
||||
agent_ui.workspace = true
|
||||
agent_settings.workspace = true
|
||||
anyhow.workspace = true
|
||||
assistant_tool.workspace = true
|
||||
assistant_tools.workspace = true
|
||||
@@ -29,7 +29,6 @@ buffer_diff.workspace = true
|
||||
chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
dirs.workspace = true
|
||||
@@ -69,3 +68,4 @@ util.workspace = true
|
||||
uuid.workspace = true
|
||||
watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
@@ -15,11 +15,11 @@ use agent_settings::AgentProfileId;
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use buffer_diff::DiffHunkStatus;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
|
||||
use gpui::{App, AppContext, AsyncApp, Entity};
|
||||
use language_model::{LanguageModel, Role, StopReason};
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn language_server_initialization_options(
|
||||
@@ -131,7 +131,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn language_server_workspace_configuration(
|
||||
@@ -154,7 +154,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn language_server_additional_initialization_options(
|
||||
@@ -179,7 +179,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn language_server_additional_workspace_configuration(
|
||||
@@ -204,7 +204,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn labels_for_completions(
|
||||
@@ -230,7 +230,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn labels_for_symbols(
|
||||
@@ -256,7 +256,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn complete_slash_command_argument(
|
||||
@@ -275,7 +275,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run_slash_command(
|
||||
@@ -301,7 +301,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn context_server_command(
|
||||
@@ -320,7 +320,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn context_server_configuration(
|
||||
@@ -347,7 +347,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn suggest_docs_packages(&self, provider: Arc<str>) -> Result<Vec<String>> {
|
||||
@@ -362,7 +362,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn index_docs(
|
||||
@@ -388,7 +388,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_dap_binary(
|
||||
@@ -410,7 +410,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
async fn dap_request_kind(
|
||||
&self,
|
||||
@@ -427,7 +427,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn dap_config_to_scenario(&self, config: ZedDebugConfig) -> Result<DebugScenario> {
|
||||
@@ -441,7 +441,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
|
||||
async fn dap_locator_create_scenario(
|
||||
@@ -465,7 +465,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
async fn run_dap_locator(
|
||||
&self,
|
||||
@@ -481,7 +481,7 @@ impl extension::Extension for WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -761,7 +761,7 @@ impl WasmExtension {
|
||||
.with_context(|| format!("failed to load wasm extension {}", manifest.id))
|
||||
}
|
||||
|
||||
pub async fn call<T, Fn>(&self, f: Fn) -> Result<T>
|
||||
pub async fn call<T, Fn>(&self, f: Fn) -> T
|
||||
where
|
||||
T: 'static + Send,
|
||||
Fn: 'static
|
||||
@@ -777,15 +777,14 @@ impl WasmExtension {
|
||||
}
|
||||
.boxed()
|
||||
}))
|
||||
.map_err(|_| {
|
||||
anyhow!(
|
||||
.unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"wasm extension channel should not be closed yet, extension {} (id {})",
|
||||
self.manifest.name,
|
||||
self.manifest.id,
|
||||
self.manifest.name, self.manifest.id,
|
||||
)
|
||||
})?;
|
||||
return_rx.await.with_context(|| {
|
||||
format!(
|
||||
});
|
||||
return_rx.await.unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"wasm extension channel, extension {} (id {})",
|
||||
self.manifest.name, self.manifest.id,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,6 @@ buffer_diff.workspace = true
|
||||
call.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
component.workspace = true
|
||||
@@ -63,6 +62,7 @@ watch.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
windows.workspace = true
|
||||
|
||||
@@ -71,12 +71,12 @@ use ui::{
|
||||
use util::{ResultExt, TryFutureExt, maybe};
|
||||
use workspace::SERIALIZATION_THROTTLE_TIME;
|
||||
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use workspace::{
|
||||
Workspace,
|
||||
dock::{DockPosition, Panel, PanelEvent},
|
||||
notifications::{DetachAndPromptErr, ErrorMessagePrompt, NotificationId},
|
||||
};
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
actions!(
|
||||
git_panel,
|
||||
|
||||
@@ -216,6 +216,10 @@ xim = { git = "https://github.com/XDeme1/xim-rs", rev = "d50d461764c2213655cd9cf
|
||||
x11-clipboard = { version = "0.9.3", optional = true }
|
||||
|
||||
[target.'cfg(target_os = "windows")'.dependencies]
|
||||
blade-util.workspace = true
|
||||
bytemuck = "1"
|
||||
blade-graphics.workspace = true
|
||||
blade-macros.workspace = true
|
||||
flume = "0.11"
|
||||
rand.workspace = true
|
||||
windows.workspace = true
|
||||
@@ -236,6 +240,7 @@ util = { workspace = true, features = ["test-support"] }
|
||||
|
||||
[target.'cfg(target_os = "windows")'.build-dependencies]
|
||||
embed-resource = "3.0"
|
||||
naga.workspace = true
|
||||
|
||||
[target.'cfg(target_os = "macos")'.build-dependencies]
|
||||
bindgen = "0.71"
|
||||
|
||||
@@ -9,10 +9,7 @@ fn main() {
|
||||
let target = env::var("CARGO_CFG_TARGET_OS");
|
||||
println!("cargo::rustc-check-cfg=cfg(gles)");
|
||||
|
||||
#[cfg(any(
|
||||
not(any(target_os = "macos", target_os = "windows")),
|
||||
all(target_os = "macos", feature = "macos-blade")
|
||||
))]
|
||||
#[cfg(any(not(target_os = "macos"), feature = "macos-blade"))]
|
||||
check_wgsl_shaders();
|
||||
|
||||
match target.as_deref() {
|
||||
@@ -20,18 +17,21 @@ fn main() {
|
||||
#[cfg(target_os = "macos")]
|
||||
macos::build();
|
||||
}
|
||||
#[cfg(all(target_os = "windows", feature = "windows-manifest"))]
|
||||
Ok("windows") => {
|
||||
#[cfg(target_os = "windows")]
|
||||
windows::build();
|
||||
let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml");
|
||||
let rc_file = std::path::Path::new("resources/windows/gpui.rc");
|
||||
println!("cargo:rerun-if-changed={}", manifest.display());
|
||||
println!("cargo:rerun-if-changed={}", rc_file.display());
|
||||
embed_resource::compile(rc_file, embed_resource::NONE)
|
||||
.manifest_required()
|
||||
.unwrap();
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(any(
|
||||
not(any(target_os = "macos", target_os = "windows")),
|
||||
all(target_os = "macos", feature = "macos-blade")
|
||||
))]
|
||||
#[allow(dead_code)]
|
||||
fn check_wgsl_shaders() {
|
||||
use std::path::PathBuf;
|
||||
use std::process;
|
||||
@@ -243,215 +243,3 @@ mod macos {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
mod windows {
|
||||
use std::{
|
||||
fs,
|
||||
io::Write,
|
||||
path::{Path, PathBuf},
|
||||
process::{self, Command},
|
||||
};
|
||||
|
||||
pub(super) fn build() {
|
||||
// Compile HLSL shaders
|
||||
#[cfg(not(debug_assertions))]
|
||||
compile_shaders();
|
||||
|
||||
// Embed the Windows manifest and resource file
|
||||
#[cfg(feature = "windows-manifest")]
|
||||
embed_resource();
|
||||
}
|
||||
|
||||
#[cfg(feature = "windows-manifest")]
|
||||
fn embed_resource() {
|
||||
let manifest = std::path::Path::new("resources/windows/gpui.manifest.xml");
|
||||
let rc_file = std::path::Path::new("resources/windows/gpui.rc");
|
||||
println!("cargo:rerun-if-changed={}", manifest.display());
|
||||
println!("cargo:rerun-if-changed={}", rc_file.display());
|
||||
embed_resource::compile(rc_file, embed_resource::NONE)
|
||||
.manifest_required()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler.
|
||||
fn compile_shaders() {
|
||||
let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
|
||||
.join("src/platform/windows/shaders.hlsl");
|
||||
let out_dir = std::env::var("OUT_DIR").unwrap();
|
||||
|
||||
println!("cargo:rerun-if-changed={}", shader_path.display());
|
||||
|
||||
// Check if fxc.exe is available
|
||||
let fxc_path = find_fxc_compiler();
|
||||
|
||||
// Define all modules
|
||||
let modules = [
|
||||
"quad",
|
||||
"shadow",
|
||||
"path_rasterization",
|
||||
"path_sprite",
|
||||
"underline",
|
||||
"monochrome_sprite",
|
||||
"polychrome_sprite",
|
||||
];
|
||||
|
||||
let rust_binding_path = format!("{}/shaders_bytes.rs", out_dir);
|
||||
if Path::new(&rust_binding_path).exists() {
|
||||
fs::remove_file(&rust_binding_path)
|
||||
.expect("Failed to remove existing Rust binding file");
|
||||
}
|
||||
for module in modules {
|
||||
compile_shader_for_module(
|
||||
module,
|
||||
&out_dir,
|
||||
&fxc_path,
|
||||
shader_path.to_str().unwrap(),
|
||||
&rust_binding_path,
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let shader_path = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())
|
||||
.join("src/platform/windows/color_text_raster.hlsl");
|
||||
compile_shader_for_module(
|
||||
"emoji_rasterization",
|
||||
&out_dir,
|
||||
&fxc_path,
|
||||
shader_path.to_str().unwrap(),
|
||||
&rust_binding_path,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler.
|
||||
fn find_fxc_compiler() -> String {
|
||||
// Check environment variable
|
||||
if let Ok(path) = std::env::var("GPUI_FXC_PATH") {
|
||||
if Path::new(&path).exists() {
|
||||
return path;
|
||||
}
|
||||
}
|
||||
|
||||
// Try to find in PATH
|
||||
// NOTE: This has to be `where.exe` on Windows, not `where`, it must be ended with `.exe`
|
||||
if let Ok(output) = std::process::Command::new("where.exe")
|
||||
.arg("fxc.exe")
|
||||
.output()
|
||||
{
|
||||
if output.status.success() {
|
||||
let path = String::from_utf8_lossy(&output.stdout);
|
||||
return path.trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Check the default path
|
||||
if Path::new(r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe")
|
||||
.exists()
|
||||
{
|
||||
return r"C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\fxc.exe"
|
||||
.to_string();
|
||||
}
|
||||
|
||||
panic!("Failed to find fxc.exe");
|
||||
}
|
||||
|
||||
fn compile_shader_for_module(
|
||||
module: &str,
|
||||
out_dir: &str,
|
||||
fxc_path: &str,
|
||||
shader_path: &str,
|
||||
rust_binding_path: &str,
|
||||
) {
|
||||
// Compile vertex shader
|
||||
let output_file = format!("{}/{}_vs.h", out_dir, module);
|
||||
let const_name = format!("{}_VERTEX_BYTES", module.to_uppercase());
|
||||
compile_shader_impl(
|
||||
fxc_path,
|
||||
&format!("{module}_vertex"),
|
||||
&output_file,
|
||||
&const_name,
|
||||
shader_path,
|
||||
"vs_4_1",
|
||||
);
|
||||
generate_rust_binding(&const_name, &output_file, &rust_binding_path);
|
||||
|
||||
// Compile fragment shader
|
||||
let output_file = format!("{}/{}_ps.h", out_dir, module);
|
||||
let const_name = format!("{}_FRAGMENT_BYTES", module.to_uppercase());
|
||||
compile_shader_impl(
|
||||
fxc_path,
|
||||
&format!("{module}_fragment"),
|
||||
&output_file,
|
||||
&const_name,
|
||||
shader_path,
|
||||
"ps_4_1",
|
||||
);
|
||||
generate_rust_binding(&const_name, &output_file, &rust_binding_path);
|
||||
}
|
||||
|
||||
fn compile_shader_impl(
|
||||
fxc_path: &str,
|
||||
entry_point: &str,
|
||||
output_path: &str,
|
||||
var_name: &str,
|
||||
shader_path: &str,
|
||||
target: &str,
|
||||
) {
|
||||
let output = Command::new(fxc_path)
|
||||
.args([
|
||||
"/T",
|
||||
target,
|
||||
"/E",
|
||||
entry_point,
|
||||
"/Fh",
|
||||
output_path,
|
||||
"/Vn",
|
||||
var_name,
|
||||
"/O3",
|
||||
shader_path,
|
||||
])
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(result) => {
|
||||
if result.status.success() {
|
||||
return;
|
||||
}
|
||||
eprintln!(
|
||||
"Shader compilation failed for {}:\n{}",
|
||||
entry_point,
|
||||
String::from_utf8_lossy(&result.stderr)
|
||||
);
|
||||
process::exit(1);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to run fxc for {}: {}", entry_point, e);
|
||||
process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_rust_binding(const_name: &str, head_file: &str, output_path: &str) {
|
||||
let header_content = fs::read_to_string(head_file).expect("Failed to read header file");
|
||||
let const_definition = {
|
||||
let global_var_start = header_content.find("const BYTE").unwrap();
|
||||
let global_var = &header_content[global_var_start..];
|
||||
let equal = global_var.find('=').unwrap();
|
||||
global_var[equal + 1..].trim()
|
||||
};
|
||||
let rust_binding = format!(
|
||||
"const {}: &[u8] = &{}\n",
|
||||
const_name,
|
||||
const_definition.replace('{', "[").replace('}', "]")
|
||||
);
|
||||
let mut options = fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(output_path)
|
||||
.expect("Failed to open Rust binding file");
|
||||
options
|
||||
.write_all(rust_binding.as_bytes())
|
||||
.expect("Failed to write Rust binding file");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ impl RenderOnce for CharacterGrid {
|
||||
"χ", "ψ", "∂", "а", "в", "Ж", "ж", "З", "з", "К", "к", "л", "м", "Н", "н", "Р", "р",
|
||||
"У", "у", "ф", "ч", "ь", "ы", "Э", "э", "Я", "я", "ij", "öẋ", ".,", "⣝⣑", "~", "*",
|
||||
"_", "^", "`", "'", "(", "{", "«", "#", "&", "@", "$", "¢", "%", "|", "?", "¶", "µ",
|
||||
"❮", "<=", "!=", "==", "--", "++", "=>", "->", "🏀", "🎊", "😍", "❤️", "👍", "👎",
|
||||
"❮", "<=", "!=", "==", "--", "++", "=>", "->",
|
||||
];
|
||||
|
||||
let columns = 11;
|
||||
|
||||
@@ -35,7 +35,6 @@ pub(crate) fn swap_rgba_pa_to_bgra(color: &mut [u8]) {
|
||||
|
||||
/// An RGBA color
|
||||
#[derive(PartialEq, Clone, Copy, Default)]
|
||||
#[repr(C)]
|
||||
pub struct Rgba {
|
||||
/// The red component of the color, in the range 0.0 to 1.0
|
||||
pub r: f32,
|
||||
|
||||
@@ -13,7 +13,8 @@ mod mac;
|
||||
any(target_os = "linux", target_os = "freebsd"),
|
||||
any(feature = "x11", feature = "wayland")
|
||||
),
|
||||
all(target_os = "macos", feature = "macos-blade")
|
||||
target_os = "windows",
|
||||
feature = "macos-blade"
|
||||
))]
|
||||
mod blade;
|
||||
|
||||
@@ -447,8 +448,6 @@ impl Tiling {
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
|
||||
pub(crate) struct RequestFrameOptions {
|
||||
pub(crate) require_presentation: bool,
|
||||
/// Force refresh of all rendering states when true
|
||||
pub(crate) force_render: bool,
|
||||
}
|
||||
|
||||
pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
|
||||
|
||||
@@ -1004,13 +1004,12 @@ impl X11Client {
|
||||
let mut keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code);
|
||||
let keysym = state.xkb.key_get_one_sym(code);
|
||||
|
||||
if keysym.is_modifier_key() {
|
||||
return Some(());
|
||||
}
|
||||
|
||||
// should be called after key_get_one_sym
|
||||
state.xkb.update_key(code, xkbc::KeyDirection::Down);
|
||||
|
||||
if keysym.is_modifier_key() {
|
||||
return Some(());
|
||||
}
|
||||
if let Some(mut compose_state) = state.compose_state.take() {
|
||||
compose_state.feed(keysym);
|
||||
match compose_state.status() {
|
||||
@@ -1068,13 +1067,12 @@ impl X11Client {
|
||||
let keystroke = crate::Keystroke::from_xkb(&state.xkb, modifiers, code);
|
||||
let keysym = state.xkb.key_get_one_sym(code);
|
||||
|
||||
if keysym.is_modifier_key() {
|
||||
return Some(());
|
||||
}
|
||||
|
||||
// should be called after key_get_one_sym
|
||||
state.xkb.update_key(code, xkbc::KeyDirection::Up);
|
||||
|
||||
if keysym.is_modifier_key() {
|
||||
return Some(());
|
||||
}
|
||||
keystroke
|
||||
};
|
||||
drop(state);
|
||||
@@ -1795,7 +1793,6 @@ impl X11ClientState {
|
||||
drop(state);
|
||||
window.refresh(RequestFrameOptions {
|
||||
require_presentation: expose_event_received,
|
||||
force_render: false,
|
||||
});
|
||||
}
|
||||
xcb_connection
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
mod clipboard;
|
||||
mod destination_list;
|
||||
mod direct_write;
|
||||
mod directx_atlas;
|
||||
mod directx_renderer;
|
||||
mod dispatcher;
|
||||
mod display;
|
||||
mod events;
|
||||
@@ -16,8 +14,6 @@ mod wrapper;
|
||||
pub(crate) use clipboard::*;
|
||||
pub(crate) use destination_list::*;
|
||||
pub(crate) use direct_write::*;
|
||||
pub(crate) use directx_atlas::*;
|
||||
pub(crate) use directx_renderer::*;
|
||||
pub(crate) use dispatcher::*;
|
||||
pub(crate) use display::*;
|
||||
pub(crate) use events::*;
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
struct RasterVertexOutput {
|
||||
float4 position : SV_Position;
|
||||
float2 texcoord : TEXCOORD0;
|
||||
};
|
||||
|
||||
RasterVertexOutput emoji_rasterization_vertex(uint vertexID : SV_VERTEXID)
|
||||
{
|
||||
RasterVertexOutput output;
|
||||
output.texcoord = float2((vertexID << 1) & 2, vertexID & 2);
|
||||
output.position = float4(output.texcoord * 2.0f - 1.0f, 0.0f, 1.0f);
|
||||
output.position.y = -output.position.y;
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
struct PixelInput {
|
||||
float4 position: SV_Position;
|
||||
float2 texcoord : TEXCOORD0;
|
||||
};
|
||||
|
||||
struct Bounds {
|
||||
int2 origin;
|
||||
int2 size;
|
||||
};
|
||||
|
||||
Texture2D<float4> t_layer : register(t0);
|
||||
SamplerState s_layer : register(s0);
|
||||
|
||||
cbuffer GlyphLayerTextureParams : register(b0) {
|
||||
Bounds bounds;
|
||||
float4 run_color;
|
||||
};
|
||||
|
||||
float4 emoji_rasterization_fragment(PixelInput input): SV_Target {
|
||||
float3 sampled = t_layer.Sample(s_layer, input.texcoord.xy).rgb;
|
||||
float alpha = (sampled.r + sampled.g + sampled.b) / 3;
|
||||
|
||||
return float4(run_color.rgb, alpha);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,309 +0,0 @@
|
||||
use collections::FxHashMap;
|
||||
use etagere::BucketedAtlasAllocator;
|
||||
use parking_lot::Mutex;
|
||||
use windows::Win32::Graphics::{
|
||||
Direct3D11::{
|
||||
D3D11_BIND_SHADER_RESOURCE, D3D11_BOX, D3D11_CPU_ACCESS_WRITE, D3D11_TEXTURE2D_DESC,
|
||||
D3D11_USAGE_DEFAULT, ID3D11Device, ID3D11DeviceContext, ID3D11ShaderResourceView,
|
||||
ID3D11Texture2D,
|
||||
},
|
||||
Dxgi::Common::*,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, DevicePixels, PlatformAtlas,
|
||||
Point, Size, platform::AtlasTextureList,
|
||||
};
|
||||
|
||||
pub(crate) struct DirectXAtlas(Mutex<DirectXAtlasState>);
|
||||
|
||||
struct DirectXAtlasState {
|
||||
device: ID3D11Device,
|
||||
device_context: ID3D11DeviceContext,
|
||||
monochrome_textures: AtlasTextureList<DirectXAtlasTexture>,
|
||||
polychrome_textures: AtlasTextureList<DirectXAtlasTexture>,
|
||||
tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
|
||||
}
|
||||
|
||||
struct DirectXAtlasTexture {
|
||||
id: AtlasTextureId,
|
||||
bytes_per_pixel: u32,
|
||||
allocator: BucketedAtlasAllocator,
|
||||
texture: ID3D11Texture2D,
|
||||
view: [Option<ID3D11ShaderResourceView>; 1],
|
||||
live_atlas_keys: u32,
|
||||
}
|
||||
|
||||
impl DirectXAtlas {
|
||||
pub(crate) fn new(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Self {
|
||||
DirectXAtlas(Mutex::new(DirectXAtlasState {
|
||||
device: device.clone(),
|
||||
device_context: device_context.clone(),
|
||||
monochrome_textures: Default::default(),
|
||||
polychrome_textures: Default::default(),
|
||||
tiles_by_key: Default::default(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn get_texture_view(
|
||||
&self,
|
||||
id: AtlasTextureId,
|
||||
) -> [Option<ID3D11ShaderResourceView>; 1] {
|
||||
let lock = self.0.lock();
|
||||
let tex = lock.texture(id);
|
||||
tex.view.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn handle_device_lost(
|
||||
&self,
|
||||
device: &ID3D11Device,
|
||||
device_context: &ID3D11DeviceContext,
|
||||
) {
|
||||
let mut lock = self.0.lock();
|
||||
lock.device = device.clone();
|
||||
lock.device_context = device_context.clone();
|
||||
lock.monochrome_textures = AtlasTextureList::default();
|
||||
lock.polychrome_textures = AtlasTextureList::default();
|
||||
lock.tiles_by_key.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl PlatformAtlas for DirectXAtlas {
|
||||
fn get_or_insert_with<'a>(
|
||||
&self,
|
||||
key: &AtlasKey,
|
||||
build: &mut dyn FnMut() -> anyhow::Result<
|
||||
Option<(Size<DevicePixels>, std::borrow::Cow<'a, [u8]>)>,
|
||||
>,
|
||||
) -> anyhow::Result<Option<AtlasTile>> {
|
||||
let mut lock = self.0.lock();
|
||||
if let Some(tile) = lock.tiles_by_key.get(key) {
|
||||
Ok(Some(tile.clone()))
|
||||
} else {
|
||||
let Some((size, bytes)) = build()? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let tile = lock
|
||||
.allocate(size, key.texture_kind())
|
||||
.ok_or_else(|| anyhow::anyhow!("failed to allocate"))?;
|
||||
let texture = lock.texture(tile.texture_id);
|
||||
texture.upload(&lock.device_context, tile.bounds, &bytes);
|
||||
lock.tiles_by_key.insert(key.clone(), tile.clone());
|
||||
Ok(Some(tile))
|
||||
}
|
||||
}
|
||||
|
||||
fn remove(&self, key: &AtlasKey) {
|
||||
let mut lock = self.0.lock();
|
||||
|
||||
let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let textures = match id.kind {
|
||||
AtlasTextureKind::Monochrome => &mut lock.monochrome_textures,
|
||||
AtlasTextureKind::Polychrome => &mut lock.polychrome_textures,
|
||||
};
|
||||
|
||||
let Some(texture_slot) = textures.textures.get_mut(id.index as usize) else {
|
||||
return;
|
||||
};
|
||||
|
||||
if let Some(mut texture) = texture_slot.take() {
|
||||
texture.decrement_ref_count();
|
||||
if texture.is_unreferenced() {
|
||||
textures.free_list.push(texture.id.index as usize);
|
||||
lock.tiles_by_key.remove(key);
|
||||
} else {
|
||||
*texture_slot = Some(texture);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DirectXAtlasState {
|
||||
fn allocate(
|
||||
&mut self,
|
||||
size: Size<DevicePixels>,
|
||||
texture_kind: AtlasTextureKind,
|
||||
) -> Option<AtlasTile> {
|
||||
{
|
||||
let textures = match texture_kind {
|
||||
AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
|
||||
AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
|
||||
};
|
||||
|
||||
if let Some(tile) = textures
|
||||
.iter_mut()
|
||||
.rev()
|
||||
.find_map(|texture| texture.allocate(size))
|
||||
{
|
||||
return Some(tile);
|
||||
}
|
||||
}
|
||||
|
||||
let texture = self.push_texture(size, texture_kind)?;
|
||||
texture.allocate(size)
|
||||
}
|
||||
|
||||
fn push_texture(
|
||||
&mut self,
|
||||
min_size: Size<DevicePixels>,
|
||||
kind: AtlasTextureKind,
|
||||
) -> Option<&mut DirectXAtlasTexture> {
|
||||
const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size {
|
||||
width: DevicePixels(1024),
|
||||
height: DevicePixels(1024),
|
||||
};
|
||||
// Max texture size for DirectX. See:
|
||||
// https://learn.microsoft.com/en-us/windows/win32/direct3d11/overviews-direct3d-11-resources-limits
|
||||
const MAX_ATLAS_SIZE: Size<DevicePixels> = Size {
|
||||
width: DevicePixels(16384),
|
||||
height: DevicePixels(16384),
|
||||
};
|
||||
let size = min_size.min(&MAX_ATLAS_SIZE).max(&DEFAULT_ATLAS_SIZE);
|
||||
let pixel_format;
|
||||
let bind_flag;
|
||||
let bytes_per_pixel;
|
||||
match kind {
|
||||
AtlasTextureKind::Monochrome => {
|
||||
pixel_format = DXGI_FORMAT_R8_UNORM;
|
||||
bind_flag = D3D11_BIND_SHADER_RESOURCE;
|
||||
bytes_per_pixel = 1;
|
||||
}
|
||||
AtlasTextureKind::Polychrome => {
|
||||
pixel_format = DXGI_FORMAT_B8G8R8A8_UNORM;
|
||||
bind_flag = D3D11_BIND_SHADER_RESOURCE;
|
||||
bytes_per_pixel = 4;
|
||||
}
|
||||
}
|
||||
let texture_desc = D3D11_TEXTURE2D_DESC {
|
||||
Width: size.width.0 as u32,
|
||||
Height: size.height.0 as u32,
|
||||
MipLevels: 1,
|
||||
ArraySize: 1,
|
||||
Format: pixel_format,
|
||||
SampleDesc: DXGI_SAMPLE_DESC {
|
||||
Count: 1,
|
||||
Quality: 0,
|
||||
},
|
||||
Usage: D3D11_USAGE_DEFAULT,
|
||||
BindFlags: bind_flag.0 as u32,
|
||||
CPUAccessFlags: D3D11_CPU_ACCESS_WRITE.0 as u32,
|
||||
MiscFlags: 0,
|
||||
};
|
||||
let mut texture: Option<ID3D11Texture2D> = None;
|
||||
unsafe {
|
||||
// This only returns None if the device is lost, which we will recreate later.
|
||||
// So it's ok to return None here.
|
||||
self.device
|
||||
.CreateTexture2D(&texture_desc, None, Some(&mut texture))
|
||||
.ok()?;
|
||||
}
|
||||
let texture = texture.unwrap();
|
||||
|
||||
let texture_list = match kind {
|
||||
AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
|
||||
AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
|
||||
};
|
||||
let index = texture_list.free_list.pop();
|
||||
let view = unsafe {
|
||||
let mut view = None;
|
||||
self.device
|
||||
.CreateShaderResourceView(&texture, None, Some(&mut view))
|
||||
.ok()?;
|
||||
[view]
|
||||
};
|
||||
let atlas_texture = DirectXAtlasTexture {
|
||||
id: AtlasTextureId {
|
||||
index: index.unwrap_or(texture_list.textures.len()) as u32,
|
||||
kind,
|
||||
},
|
||||
bytes_per_pixel,
|
||||
allocator: etagere::BucketedAtlasAllocator::new(size.into()),
|
||||
texture,
|
||||
view,
|
||||
live_atlas_keys: 0,
|
||||
};
|
||||
if let Some(ix) = index {
|
||||
texture_list.textures[ix] = Some(atlas_texture);
|
||||
texture_list.textures.get_mut(ix).unwrap().as_mut()
|
||||
} else {
|
||||
texture_list.textures.push(Some(atlas_texture));
|
||||
texture_list.textures.last_mut().unwrap().as_mut()
|
||||
}
|
||||
}
|
||||
|
||||
fn texture(&self, id: AtlasTextureId) -> &DirectXAtlasTexture {
|
||||
let textures = match id.kind {
|
||||
crate::AtlasTextureKind::Monochrome => &self.monochrome_textures,
|
||||
crate::AtlasTextureKind::Polychrome => &self.polychrome_textures,
|
||||
};
|
||||
textures[id.index as usize].as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl DirectXAtlasTexture {
|
||||
fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> {
|
||||
let allocation = self.allocator.allocate(size.into())?;
|
||||
let tile = AtlasTile {
|
||||
texture_id: self.id,
|
||||
tile_id: allocation.id.into(),
|
||||
bounds: Bounds {
|
||||
origin: allocation.rectangle.min.into(),
|
||||
size,
|
||||
},
|
||||
padding: 0,
|
||||
};
|
||||
self.live_atlas_keys += 1;
|
||||
Some(tile)
|
||||
}
|
||||
|
||||
fn upload(
|
||||
&self,
|
||||
device_context: &ID3D11DeviceContext,
|
||||
bounds: Bounds<DevicePixels>,
|
||||
bytes: &[u8],
|
||||
) {
|
||||
unsafe {
|
||||
device_context.UpdateSubresource(
|
||||
&self.texture,
|
||||
0,
|
||||
Some(&D3D11_BOX {
|
||||
left: bounds.left().0 as u32,
|
||||
top: bounds.top().0 as u32,
|
||||
front: 0,
|
||||
right: bounds.right().0 as u32,
|
||||
bottom: bounds.bottom().0 as u32,
|
||||
back: 1,
|
||||
}),
|
||||
bytes.as_ptr() as _,
|
||||
bounds.size.width.to_bytes(self.bytes_per_pixel as u8),
|
||||
0,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn decrement_ref_count(&mut self) {
|
||||
self.live_atlas_keys -= 1;
|
||||
}
|
||||
|
||||
fn is_unreferenced(&mut self) -> bool {
|
||||
self.live_atlas_keys == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Size<DevicePixels>> for etagere::Size {
|
||||
fn from(size: Size<DevicePixels>) -> Self {
|
||||
etagere::Size::new(size.width.into(), size.height.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<etagere::Point> for Point<DevicePixels> {
|
||||
fn from(value: etagere::Point) -> Self {
|
||||
Point {
|
||||
x: DevicePixels::from(value.x),
|
||||
y: DevicePixels::from(value.y),
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,6 @@ pub(crate) const WM_GPUI_CURSOR_STYLE_CHANGED: u32 = WM_USER + 1;
|
||||
pub(crate) const WM_GPUI_CLOSE_ONE_WINDOW: u32 = WM_USER + 2;
|
||||
pub(crate) const WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD: u32 = WM_USER + 3;
|
||||
pub(crate) const WM_GPUI_DOCK_MENU_ACTION: u32 = WM_USER + 4;
|
||||
pub(crate) const WM_GPUI_FORCE_UPDATE_WINDOW: u32 = WM_USER + 5;
|
||||
|
||||
const SIZE_MOVE_LOOP_TIMER_ID: usize = 1;
|
||||
const AUTO_HIDE_TASKBAR_THICKNESS_PX: i32 = 1;
|
||||
@@ -38,7 +37,6 @@ pub(crate) fn handle_msg(
|
||||
let handled = match msg {
|
||||
WM_ACTIVATE => handle_activate_msg(wparam, state_ptr),
|
||||
WM_CREATE => handle_create_msg(handle, state_ptr),
|
||||
WM_DEVICECHANGE => handle_device_change_msg(handle, wparam, state_ptr),
|
||||
WM_MOVE => handle_move_msg(handle, lparam, state_ptr),
|
||||
WM_SIZE => handle_size_msg(wparam, lparam, state_ptr),
|
||||
WM_GETMINMAXINFO => handle_get_min_max_info_msg(lparam, state_ptr),
|
||||
@@ -50,7 +48,7 @@ pub(crate) fn handle_msg(
|
||||
WM_DISPLAYCHANGE => handle_display_change_msg(handle, state_ptr),
|
||||
WM_NCHITTEST => handle_hit_test_msg(handle, msg, wparam, lparam, state_ptr),
|
||||
WM_PAINT => handle_paint_msg(handle, state_ptr),
|
||||
WM_CLOSE => handle_close_msg(state_ptr),
|
||||
WM_CLOSE => handle_close_msg(handle, state_ptr),
|
||||
WM_DESTROY => handle_destroy_msg(handle, state_ptr),
|
||||
WM_MOUSEMOVE => handle_mouse_move_msg(handle, lparam, wparam, state_ptr),
|
||||
WM_MOUSELEAVE | WM_NCMOUSELEAVE => handle_mouse_leave_msg(state_ptr),
|
||||
@@ -98,7 +96,6 @@ pub(crate) fn handle_msg(
|
||||
WM_SETTINGCHANGE => handle_system_settings_changed(handle, wparam, lparam, state_ptr),
|
||||
WM_INPUTLANGCHANGE => handle_input_language_changed(lparam, state_ptr),
|
||||
WM_GPUI_CURSOR_STYLE_CHANGED => handle_cursor_changed(lparam, state_ptr),
|
||||
WM_GPUI_FORCE_UPDATE_WINDOW => draw_window(handle, true, state_ptr),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(n) = handled {
|
||||
@@ -184,9 +181,11 @@ fn handle_size_msg(
|
||||
let new_size = size(DevicePixels(width), DevicePixels(height));
|
||||
let scale_factor = lock.scale_factor;
|
||||
if lock.restore_from_minimized.is_some() {
|
||||
lock.renderer
|
||||
.update_drawable_size_even_if_unchanged(new_size);
|
||||
lock.callbacks.request_frame = lock.restore_from_minimized.take();
|
||||
} else {
|
||||
lock.renderer.resize(new_size).log_err();
|
||||
lock.renderer.update_drawable_size(new_size);
|
||||
}
|
||||
let new_size = new_size.to_pixels(scale_factor);
|
||||
lock.logical_size = new_size;
|
||||
@@ -239,14 +238,40 @@ fn handle_timer_msg(
|
||||
}
|
||||
|
||||
fn handle_paint_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
|
||||
draw_window(handle, false, state_ptr)
|
||||
let mut lock = state_ptr.state.borrow_mut();
|
||||
if let Some(mut request_frame) = lock.callbacks.request_frame.take() {
|
||||
drop(lock);
|
||||
request_frame(Default::default());
|
||||
state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame);
|
||||
}
|
||||
unsafe { ValidateRect(Some(handle), None).ok().log_err() };
|
||||
Some(0)
|
||||
}
|
||||
|
||||
fn handle_close_msg(state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
|
||||
let mut callback = state_ptr.state.borrow_mut().callbacks.should_close.take()?;
|
||||
let should_close = callback();
|
||||
state_ptr.state.borrow_mut().callbacks.should_close = Some(callback);
|
||||
if should_close { None } else { Some(0) }
|
||||
fn handle_close_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
|
||||
let mut lock = state_ptr.state.borrow_mut();
|
||||
let output = if let Some(mut callback) = lock.callbacks.should_close.take() {
|
||||
drop(lock);
|
||||
let should_close = callback();
|
||||
state_ptr.state.borrow_mut().callbacks.should_close = Some(callback);
|
||||
if should_close { None } else { Some(0) }
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Workaround as window close animation is not played with `WS_EX_LAYERED` enabled.
|
||||
if output.is_none() {
|
||||
unsafe {
|
||||
let current_style = get_window_long(handle, GWL_EXSTYLE);
|
||||
set_window_long(
|
||||
handle,
|
||||
GWL_EXSTYLE,
|
||||
current_style & !WS_EX_LAYERED.0 as isize,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
fn handle_destroy_msg(handle: HWND, state_ptr: Rc<WindowsWindowStatePtr>) -> Option<isize> {
|
||||
@@ -1198,53 +1223,6 @@ fn handle_input_language_changed(
|
||||
Some(0)
|
||||
}
|
||||
|
||||
fn handle_device_change_msg(
|
||||
handle: HWND,
|
||||
wparam: WPARAM,
|
||||
state_ptr: Rc<WindowsWindowStatePtr>,
|
||||
) -> Option<isize> {
|
||||
if wparam.0 == DBT_DEVNODES_CHANGED as usize {
|
||||
// The reason for sending this message is to actually trigger a redraw of the window.
|
||||
unsafe {
|
||||
PostMessageW(
|
||||
Some(handle),
|
||||
WM_GPUI_FORCE_UPDATE_WINDOW,
|
||||
WPARAM(0),
|
||||
LPARAM(0),
|
||||
)
|
||||
.log_err();
|
||||
}
|
||||
// If the GPU device is lost, this redraw will take care of recreating the device context.
|
||||
// The WM_GPUI_FORCE_UPDATE_WINDOW message will take care of redrawing the window, after
|
||||
// the device context has been recreated.
|
||||
draw_window(handle, true, state_ptr)
|
||||
} else {
|
||||
// Other device change messages are not handled.
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn draw_window(
|
||||
handle: HWND,
|
||||
force_render: bool,
|
||||
state_ptr: Rc<WindowsWindowStatePtr>,
|
||||
) -> Option<isize> {
|
||||
let mut request_frame = state_ptr
|
||||
.state
|
||||
.borrow_mut()
|
||||
.callbacks
|
||||
.request_frame
|
||||
.take()?;
|
||||
request_frame(RequestFrameOptions {
|
||||
require_presentation: false,
|
||||
force_render,
|
||||
});
|
||||
state_ptr.state.borrow_mut().callbacks.request_frame = Some(request_frame);
|
||||
unsafe { ValidateRect(Some(handle), None).ok().log_err() };
|
||||
Some(0)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn parse_char_message(wparam: WPARAM, state_ptr: &Rc<WindowsWindowStatePtr>) -> Option<String> {
|
||||
let code_point = wparam.loword();
|
||||
|
||||
@@ -28,12 +28,13 @@ use windows::{
|
||||
core::*,
|
||||
};
|
||||
|
||||
use crate::*;
|
||||
use crate::{platform::blade::BladeContext, *};
|
||||
|
||||
pub(crate) struct WindowsPlatform {
|
||||
state: RefCell<WindowsPlatformState>,
|
||||
raw_window_handles: RwLock<SmallVec<[HWND; 4]>>,
|
||||
// The below members will never change throughout the entire lifecycle of the app.
|
||||
gpu_context: BladeContext,
|
||||
icon: HICON,
|
||||
main_receiver: flume::Receiver<Runnable>,
|
||||
background_executor: BackgroundExecutor,
|
||||
@@ -44,7 +45,6 @@ pub(crate) struct WindowsPlatform {
|
||||
drop_target_helper: IDropTargetHelper,
|
||||
validation_number: usize,
|
||||
main_thread_id_win32: u32,
|
||||
disable_direct_composition: bool,
|
||||
}
|
||||
|
||||
pub(crate) struct WindowsPlatformState {
|
||||
@@ -94,18 +94,14 @@ impl WindowsPlatform {
|
||||
main_thread_id_win32,
|
||||
validation_number,
|
||||
));
|
||||
let disable_direct_composition = std::env::var(DISABLE_DIRECT_COMPOSITION)
|
||||
.is_ok_and(|value| value == "true" || value == "1");
|
||||
let background_executor = BackgroundExecutor::new(dispatcher.clone());
|
||||
let foreground_executor = ForegroundExecutor::new(dispatcher);
|
||||
let directx_devices = DirectXDevices::new(disable_direct_composition)
|
||||
.context("Unable to init directx devices.")?;
|
||||
let bitmap_factory = ManuallyDrop::new(unsafe {
|
||||
CoCreateInstance(&CLSID_WICImagingFactory, None, CLSCTX_INPROC_SERVER)
|
||||
.context("Error creating bitmap factory.")?
|
||||
});
|
||||
let text_system = Arc::new(
|
||||
DirectWriteTextSystem::new(&directx_devices, &bitmap_factory)
|
||||
DirectWriteTextSystem::new(&bitmap_factory)
|
||||
.context("Error creating DirectWriteTextSystem")?,
|
||||
);
|
||||
let drop_target_helper: IDropTargetHelper = unsafe {
|
||||
@@ -115,17 +111,18 @@ impl WindowsPlatform {
|
||||
let icon = load_icon().unwrap_or_default();
|
||||
let state = RefCell::new(WindowsPlatformState::new());
|
||||
let raw_window_handles = RwLock::new(SmallVec::new());
|
||||
let gpu_context = BladeContext::new().context("Unable to init GPU context")?;
|
||||
let windows_version = WindowsVersion::new().context("Error retrieve windows version")?;
|
||||
|
||||
Ok(Self {
|
||||
state,
|
||||
raw_window_handles,
|
||||
gpu_context,
|
||||
icon,
|
||||
main_receiver,
|
||||
background_executor,
|
||||
foreground_executor,
|
||||
text_system,
|
||||
disable_direct_composition,
|
||||
windows_version,
|
||||
bitmap_factory,
|
||||
drop_target_helper,
|
||||
@@ -190,7 +187,6 @@ impl WindowsPlatform {
|
||||
validation_number: self.validation_number,
|
||||
main_receiver: self.main_receiver.clone(),
|
||||
main_thread_id_win32: self.main_thread_id_win32,
|
||||
disable_direct_composition: self.disable_direct_composition,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,11 +343,27 @@ impl Platform for WindowsPlatform {
|
||||
|
||||
fn run(&self, on_finish_launching: Box<dyn 'static + FnOnce()>) {
|
||||
on_finish_launching();
|
||||
loop {
|
||||
if self.handle_events() {
|
||||
break;
|
||||
let vsync_event = unsafe { Owned::new(CreateEventW(None, false, false, None).unwrap()) };
|
||||
begin_vsync(*vsync_event);
|
||||
'a: loop {
|
||||
let wait_result = unsafe {
|
||||
MsgWaitForMultipleObjects(Some(&[*vsync_event]), false, INFINITE, QS_ALLINPUT)
|
||||
};
|
||||
|
||||
match wait_result {
|
||||
// compositor clock ticked so we should draw a frame
|
||||
WAIT_EVENT(0) => self.redraw_all(),
|
||||
// Windows thread messages are posted
|
||||
WAIT_EVENT(1) => {
|
||||
if self.handle_events() {
|
||||
break 'a;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
log::error!("Something went wrong while waiting {:?}", wait_result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.redraw_all();
|
||||
}
|
||||
|
||||
if let Some(ref mut callback) = self.state.borrow_mut().callbacks.quit {
|
||||
@@ -443,7 +455,12 @@ impl Platform for WindowsPlatform {
|
||||
handle: AnyWindowHandle,
|
||||
options: WindowParams,
|
||||
) -> Result<Box<dyn PlatformWindow>> {
|
||||
let window = WindowsWindow::new(handle, options, self.generate_creation_info())?;
|
||||
let window = WindowsWindow::new(
|
||||
handle,
|
||||
options,
|
||||
self.generate_creation_info(),
|
||||
&self.gpu_context,
|
||||
)?;
|
||||
let handle = window.get_raw_handle();
|
||||
self.raw_window_handles.write().push(handle);
|
||||
|
||||
@@ -722,7 +739,6 @@ pub(crate) struct WindowCreationInfo {
|
||||
pub(crate) validation_number: usize,
|
||||
pub(crate) main_receiver: flume::Receiver<Runnable>,
|
||||
pub(crate) main_thread_id_win32: u32,
|
||||
pub(crate) disable_direct_composition: bool,
|
||||
}
|
||||
|
||||
fn open_target(target: &str) {
|
||||
@@ -830,6 +846,16 @@ fn file_save_dialog(directory: PathBuf, window: Option<HWND>) -> Result<Option<P
|
||||
Ok(Some(PathBuf::from(file_path_string)))
|
||||
}
|
||||
|
||||
fn begin_vsync(vsync_event: HANDLE) {
|
||||
let event: SafeHandle = vsync_event.into();
|
||||
std::thread::spawn(move || unsafe {
|
||||
loop {
|
||||
windows::Win32::Graphics::Dwm::DwmFlush().log_err();
|
||||
SetEvent(*event).log_err();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn load_icon() -> Result<HICON> {
|
||||
let module = unsafe { GetModuleHandleW(None).context("unable to get module handle")? };
|
||||
let handle = unsafe {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,6 +26,7 @@ use windows::{
|
||||
core::*,
|
||||
};
|
||||
|
||||
use crate::platform::blade::{BladeContext, BladeRenderer};
|
||||
use crate::*;
|
||||
|
||||
pub(crate) struct WindowsWindow(pub Rc<WindowsWindowStatePtr>);
|
||||
@@ -48,7 +49,7 @@ pub struct WindowsWindowState {
|
||||
pub system_key_handled: bool,
|
||||
pub hovered: bool,
|
||||
|
||||
pub renderer: DirectXRenderer,
|
||||
pub renderer: BladeRenderer,
|
||||
|
||||
pub click_state: ClickState,
|
||||
pub system_settings: WindowsSystemSettings,
|
||||
@@ -79,12 +80,13 @@ pub(crate) struct WindowsWindowStatePtr {
|
||||
impl WindowsWindowState {
|
||||
fn new(
|
||||
hwnd: HWND,
|
||||
transparent: bool,
|
||||
cs: &CREATESTRUCTW,
|
||||
current_cursor: Option<HCURSOR>,
|
||||
display: WindowsDisplay,
|
||||
gpu_context: &BladeContext,
|
||||
min_size: Option<Size<Pixels>>,
|
||||
appearance: WindowAppearance,
|
||||
disable_direct_composition: bool,
|
||||
) -> Result<Self> {
|
||||
let scale_factor = {
|
||||
let monitor_dpi = unsafe { GetDpiForWindow(hwnd) } as f32;
|
||||
@@ -101,8 +103,7 @@ impl WindowsWindowState {
|
||||
};
|
||||
let border_offset = WindowBorderOffset::default();
|
||||
let restore_from_minimized = None;
|
||||
let renderer = DirectXRenderer::new(hwnd, disable_direct_composition)
|
||||
.context("Creating DirectX renderer")?;
|
||||
let renderer = windows_renderer::init(gpu_context, hwnd, transparent)?;
|
||||
let callbacks = Callbacks::default();
|
||||
let input_handler = None;
|
||||
let pending_surrogate = None;
|
||||
@@ -205,12 +206,13 @@ impl WindowsWindowStatePtr {
|
||||
fn new(context: &WindowCreateContext, hwnd: HWND, cs: &CREATESTRUCTW) -> Result<Rc<Self>> {
|
||||
let state = RefCell::new(WindowsWindowState::new(
|
||||
hwnd,
|
||||
context.transparent,
|
||||
cs,
|
||||
context.current_cursor,
|
||||
context.display,
|
||||
context.gpu_context,
|
||||
context.min_size,
|
||||
context.appearance,
|
||||
context.disable_direct_composition,
|
||||
)?);
|
||||
|
||||
Ok(Rc::new_cyclic(|this| Self {
|
||||
@@ -327,11 +329,12 @@ pub(crate) struct Callbacks {
|
||||
pub(crate) appearance_changed: Option<Box<dyn FnMut()>>,
|
||||
}
|
||||
|
||||
struct WindowCreateContext {
|
||||
struct WindowCreateContext<'a> {
|
||||
inner: Option<Result<Rc<WindowsWindowStatePtr>>>,
|
||||
handle: AnyWindowHandle,
|
||||
hide_title_bar: bool,
|
||||
display: WindowsDisplay,
|
||||
transparent: bool,
|
||||
is_movable: bool,
|
||||
min_size: Option<Size<Pixels>>,
|
||||
executor: ForegroundExecutor,
|
||||
@@ -340,9 +343,9 @@ struct WindowCreateContext {
|
||||
drop_target_helper: IDropTargetHelper,
|
||||
validation_number: usize,
|
||||
main_receiver: flume::Receiver<Runnable>,
|
||||
gpu_context: &'a BladeContext,
|
||||
main_thread_id_win32: u32,
|
||||
appearance: WindowAppearance,
|
||||
disable_direct_composition: bool,
|
||||
}
|
||||
|
||||
impl WindowsWindow {
|
||||
@@ -350,6 +353,7 @@ impl WindowsWindow {
|
||||
handle: AnyWindowHandle,
|
||||
params: WindowParams,
|
||||
creation_info: WindowCreationInfo,
|
||||
gpu_context: &BladeContext,
|
||||
) -> Result<Self> {
|
||||
let WindowCreationInfo {
|
||||
icon,
|
||||
@@ -360,7 +364,6 @@ impl WindowsWindow {
|
||||
validation_number,
|
||||
main_receiver,
|
||||
main_thread_id_win32,
|
||||
disable_direct_composition,
|
||||
} = creation_info;
|
||||
let classname = register_wnd_class(icon);
|
||||
let hide_title_bar = params
|
||||
@@ -376,18 +379,14 @@ impl WindowsWindow {
|
||||
.map(|title| title.as_ref())
|
||||
.unwrap_or(""),
|
||||
);
|
||||
|
||||
let (mut dwexstyle, dwstyle) = if params.kind == WindowKind::PopUp {
|
||||
(WS_EX_TOOLWINDOW, WINDOW_STYLE(0x0))
|
||||
let (dwexstyle, mut dwstyle) = if params.kind == WindowKind::PopUp {
|
||||
(WS_EX_TOOLWINDOW | WS_EX_LAYERED, WINDOW_STYLE(0x0))
|
||||
} else {
|
||||
(
|
||||
WS_EX_APPWINDOW,
|
||||
WS_EX_APPWINDOW | WS_EX_LAYERED,
|
||||
WS_THICKFRAME | WS_SYSMENU | WS_MAXIMIZEBOX | WS_MINIMIZEBOX,
|
||||
)
|
||||
};
|
||||
if !disable_direct_composition {
|
||||
dwexstyle |= WS_EX_NOREDIRECTIONBITMAP;
|
||||
}
|
||||
|
||||
let hinstance = get_module_handle();
|
||||
let display = if let Some(display_id) = params.display_id {
|
||||
@@ -402,6 +401,7 @@ impl WindowsWindow {
|
||||
handle,
|
||||
hide_title_bar,
|
||||
display,
|
||||
transparent: true,
|
||||
is_movable: params.is_movable,
|
||||
min_size: params.window_min_size,
|
||||
executor,
|
||||
@@ -410,9 +410,9 @@ impl WindowsWindow {
|
||||
drop_target_helper,
|
||||
validation_number,
|
||||
main_receiver,
|
||||
gpu_context,
|
||||
main_thread_id_win32,
|
||||
appearance,
|
||||
disable_direct_composition,
|
||||
};
|
||||
let lpparam = Some(&context as *const _ as *const _);
|
||||
let creation_result = unsafe {
|
||||
@@ -453,6 +453,14 @@ impl WindowsWindow {
|
||||
state: WindowOpenState::Windowed,
|
||||
});
|
||||
}
|
||||
// The render pipeline will perform compositing on the GPU when the
|
||||
// swapchain is configured correctly (see downstream of
|
||||
// update_transparency).
|
||||
// The following configuration is a one-time setup to ensure that the
|
||||
// window is going to be composited with per-pixel alpha, but the render
|
||||
// pipeline is responsible for effectively calling UpdateLayeredWindow
|
||||
// at the appropriate time.
|
||||
unsafe { SetLayeredWindowAttributes(hwnd, COLORREF(0), 255, LWA_ALPHA)? };
|
||||
|
||||
Ok(Self(state_ptr))
|
||||
}
|
||||
@@ -477,6 +485,7 @@ impl rwh::HasDisplayHandle for WindowsWindow {
|
||||
|
||||
impl Drop for WindowsWindow {
|
||||
fn drop(&mut self) {
|
||||
self.0.state.borrow_mut().renderer.destroy();
|
||||
// clone this `Rc` to prevent early release of the pointer
|
||||
let this = self.0.clone();
|
||||
self.0
|
||||
@@ -696,21 +705,24 @@ impl PlatformWindow for WindowsWindow {
|
||||
}
|
||||
|
||||
fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) {
|
||||
let hwnd = self.0.hwnd;
|
||||
let mut window_state = self.0.state.borrow_mut();
|
||||
window_state
|
||||
.renderer
|
||||
.update_transparency(background_appearance != WindowBackgroundAppearance::Opaque);
|
||||
|
||||
match background_appearance {
|
||||
WindowBackgroundAppearance::Opaque => {
|
||||
// ACCENT_DISABLED
|
||||
set_window_composition_attribute(hwnd, None, 0);
|
||||
set_window_composition_attribute(window_state.hwnd, None, 0);
|
||||
}
|
||||
WindowBackgroundAppearance::Transparent => {
|
||||
// Use ACCENT_ENABLE_TRANSPARENTGRADIENT for transparent background
|
||||
set_window_composition_attribute(hwnd, None, 2);
|
||||
set_window_composition_attribute(window_state.hwnd, None, 2);
|
||||
}
|
||||
WindowBackgroundAppearance::Blurred => {
|
||||
// Enable acrylic blur
|
||||
// ACCENT_ENABLE_ACRYLICBLURBEHIND
|
||||
set_window_composition_attribute(hwnd, Some((0, 0, 0, 0)), 4);
|
||||
set_window_composition_attribute(window_state.hwnd, Some((0, 0, 0, 0)), 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -782,11 +794,11 @@ impl PlatformWindow for WindowsWindow {
|
||||
}
|
||||
|
||||
fn draw(&self, scene: &Scene) {
|
||||
self.0.state.borrow_mut().renderer.draw(scene).log_err();
|
||||
self.0.state.borrow_mut().renderer.draw(scene)
|
||||
}
|
||||
|
||||
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> {
|
||||
self.0.state.borrow().renderer.sprite_atlas()
|
||||
self.0.state.borrow().renderer.sprite_atlas().clone()
|
||||
}
|
||||
|
||||
fn get_raw_handle(&self) -> HWND {
|
||||
@@ -794,11 +806,11 @@ impl PlatformWindow for WindowsWindow {
|
||||
}
|
||||
|
||||
fn gpu_specs(&self) -> Option<GpuSpecs> {
|
||||
self.0.state.borrow().renderer.gpu_specs().log_err()
|
||||
Some(self.0.state.borrow().renderer.gpu_specs())
|
||||
}
|
||||
|
||||
fn update_ime_position(&self, _bounds: Bounds<ScaledPixels>) {
|
||||
// There is no such thing on Windows.
|
||||
// todo(windows)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1294,6 +1306,52 @@ fn set_window_composition_attribute(hwnd: HWND, color: Option<Color>, state: u32
|
||||
}
|
||||
}
|
||||
|
||||
mod windows_renderer {
|
||||
use crate::platform::blade::{BladeContext, BladeRenderer, BladeSurfaceConfig};
|
||||
use raw_window_handle as rwh;
|
||||
use std::num::NonZeroIsize;
|
||||
use windows::Win32::{Foundation::HWND, UI::WindowsAndMessaging::GWLP_HINSTANCE};
|
||||
|
||||
use crate::{get_window_long, show_error};
|
||||
|
||||
pub(super) fn init(
|
||||
context: &BladeContext,
|
||||
hwnd: HWND,
|
||||
transparent: bool,
|
||||
) -> anyhow::Result<BladeRenderer> {
|
||||
let raw = RawWindow { hwnd };
|
||||
let config = BladeSurfaceConfig {
|
||||
size: Default::default(),
|
||||
transparent,
|
||||
};
|
||||
BladeRenderer::new(context, &raw, config)
|
||||
.inspect_err(|err| show_error("Failed to initialize BladeRenderer", err.to_string()))
|
||||
}
|
||||
|
||||
struct RawWindow {
|
||||
hwnd: HWND,
|
||||
}
|
||||
|
||||
impl rwh::HasWindowHandle for RawWindow {
|
||||
fn window_handle(&self) -> Result<rwh::WindowHandle<'_>, rwh::HandleError> {
|
||||
Ok(unsafe {
|
||||
let hwnd = NonZeroIsize::new_unchecked(self.hwnd.0 as isize);
|
||||
let mut handle = rwh::Win32WindowHandle::new(hwnd);
|
||||
let hinstance = get_window_long(self.hwnd, GWLP_HINSTANCE);
|
||||
handle.hinstance = NonZeroIsize::new(hinstance);
|
||||
rwh::WindowHandle::borrow_raw(handle.into())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl rwh::HasDisplayHandle for RawWindow {
|
||||
fn display_handle(&self) -> Result<rwh::DisplayHandle<'_>, rwh::HandleError> {
|
||||
let handle = rwh::WindowsDisplayHandle::new();
|
||||
Ok(unsafe { rwh::DisplayHandle::borrow_raw(handle.into()) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ClickState;
|
||||
|
||||
@@ -1020,7 +1020,7 @@ impl Window {
|
||||
|| (active.get()
|
||||
&& last_input_timestamp.get().elapsed() < Duration::from_secs(1));
|
||||
|
||||
if invalidator.is_dirty() || request_frame_options.force_render {
|
||||
if invalidator.is_dirty() {
|
||||
measure("frame duration", || {
|
||||
handle
|
||||
.update(&mut cx, |_, window, cx| {
|
||||
|
||||
@@ -236,22 +236,6 @@ impl HttpClientWithUrl {
|
||||
)?)
|
||||
}
|
||||
|
||||
/// Builds a Zed Cloud URL using the given path.
|
||||
pub fn build_zed_cloud_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
|
||||
let base_url = self.base_url();
|
||||
let base_api_url = match base_url.as_ref() {
|
||||
"https://zed.dev" => "https://cloud.zed.dev",
|
||||
"https://staging.zed.dev" => "https://cloud.zed.dev",
|
||||
"http://localhost:3000" => "http://localhost:8787",
|
||||
other => other,
|
||||
};
|
||||
|
||||
Ok(Url::parse_with_params(
|
||||
&format!("{}{}", base_api_url, path),
|
||||
query,
|
||||
)?)
|
||||
}
|
||||
|
||||
/// Builds a Zed LLM URL using the given path.
|
||||
pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
|
||||
let base_url = self.base_url();
|
||||
|
||||
@@ -15,7 +15,6 @@ doctest = false
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
copilot.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
@@ -33,6 +32,7 @@ ui.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
zeta.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use client::{DisableAiSettings, UserStore, zed_urls};
|
||||
use cloud_llm_client::UsageLimit;
|
||||
use copilot::{Copilot, Status};
|
||||
use editor::{
|
||||
Editor, SelectionEffects,
|
||||
@@ -35,6 +34,7 @@ use workspace::{
|
||||
notifications::NotificationId,
|
||||
};
|
||||
use zed_actions::OpenBrowser;
|
||||
use zed_llm_client::UsageLimit;
|
||||
use zeta::RateCompletions;
|
||||
|
||||
actions!(
|
||||
|
||||
@@ -20,7 +20,6 @@ anthropic = { workspace = true, features = ["schemars"] }
|
||||
anyhow.workspace = true
|
||||
base64.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
futures.workspace = true
|
||||
gpui.workspace = true
|
||||
@@ -38,6 +37,7 @@ telemetry_events.workspace = true
|
||||
thiserror.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_llm_client.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -11,7 +11,6 @@ pub mod fake_provider;
|
||||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||
use anyhow::{Result, anyhow};
|
||||
use client::Client;
|
||||
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
|
||||
use futures::FutureExt;
|
||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||
@@ -27,6 +26,7 @@ use std::time::Duration;
|
||||
use std::{fmt, io};
|
||||
use thiserror::Error;
|
||||
use util::serde::is_default;
|
||||
use zed_llm_client::{CompletionMode, CompletionRequestStatus};
|
||||
|
||||
pub use crate::model::*;
|
||||
pub use crate::rate_limiter::*;
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use std::io::{Cursor, Write};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::role::Role;
|
||||
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
|
||||
use anyhow::Result;
|
||||
use base64::write::EncoderWriter;
|
||||
use cloud_llm_client::{CompletionIntent, CompletionMode};
|
||||
use gpui::{
|
||||
App, AppContext as _, DevicePixels, Image, ImageFormat, ObjectFit, SharedString, Size, Task,
|
||||
point, px, size,
|
||||
@@ -11,9 +12,7 @@ use gpui::{
|
||||
use image::codecs::png::PngEncoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::role::Role;
|
||||
use crate::{LanguageModelToolUse, LanguageModelToolUseId};
|
||||
use zed_llm_client::{CompletionIntent, CompletionMode};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
pub struct LanguageModelImage {
|
||||
|
||||
@@ -15,18 +15,20 @@ path = "src/language_models.rs"
|
||||
ai_onboarding.workspace = true
|
||||
anthropic = { workspace = true, features = ["schemars"] }
|
||||
anyhow.workspace = true
|
||||
|
||||
aws-config = { workspace = true, features = ["behavior-version-latest"] }
|
||||
aws-credential-types = { workspace = true, features = ["hardcoded-credentials"] }
|
||||
aws-credential-types = { workspace = true, features = [
|
||||
"hardcoded-credentials",
|
||||
] }
|
||||
aws_http_client.workspace = true
|
||||
bedrock.workspace = true
|
||||
chrono.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
collections.workspace = true
|
||||
component.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
convert_case.workspace = true
|
||||
copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
deepseek = { workspace = true, features = ["schemars"] }
|
||||
editor.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -34,7 +36,6 @@ google_ai = { workspace = true, features = ["schemars"] }
|
||||
gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
http_client.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
lmstudio = { workspace = true, features = ["schemars"] }
|
||||
log.workspace = true
|
||||
@@ -43,6 +44,8 @@ mistral = { workspace = true, features = ["schemars"] }
|
||||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
open_router = { workspace = true, features = ["schemars"] }
|
||||
vercel = { workspace = true, features = ["schemars"] }
|
||||
x_ai = { workspace = true, features = ["schemars"] }
|
||||
partial-json-fixer.workspace = true
|
||||
proto.workspace = true
|
||||
release_channel.workspace = true
|
||||
@@ -59,9 +62,10 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
util.workspace = true
|
||||
vercel = { workspace = true, features = ["schemars"] }
|
||||
workspace-hack.workspace = true
|
||||
x_ai = { workspace = true, features = ["schemars"] }
|
||||
zed_llm_client.workspace = true
|
||||
language.workspace = true
|
||||
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs", tag = "v0.6.0", features = [] }
|
||||
|
||||
[dev-dependencies]
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -17,6 +17,7 @@ use crate::provider::cloud::CloudLanguageModelProvider;
|
||||
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
|
||||
use crate::provider::google::GoogleLanguageModelProvider;
|
||||
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
|
||||
use crate::provider::local::LocalLanguageModelProvider;
|
||||
use crate::provider::mistral::MistralLanguageModelProvider;
|
||||
use crate::provider::ollama::OllamaLanguageModelProvider;
|
||||
use crate::provider::open_ai::OpenAiLanguageModelProvider;
|
||||
@@ -150,4 +151,8 @@ fn register_language_model_providers(
|
||||
);
|
||||
registry.register_provider(XAiLanguageModelProvider::new(client.http_client(), cx), cx);
|
||||
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
|
||||
registry.register_provider(
|
||||
LocalLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ pub mod copilot_chat;
|
||||
pub mod deepseek;
|
||||
pub mod google;
|
||||
pub mod lmstudio;
|
||||
pub mod local;
|
||||
pub mod mistral;
|
||||
pub mod ollama;
|
||||
pub mod open_ai;
|
||||
|
||||
@@ -3,13 +3,6 @@ use anthropic::AnthropicModelMode;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
|
||||
use cloud_llm_client::{
|
||||
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
|
||||
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
|
||||
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
|
||||
};
|
||||
use futures::{
|
||||
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
|
||||
};
|
||||
@@ -40,6 +33,13 @@ use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use ui::{TintColor, prelude::*};
|
||||
use util::{ResultExt as _, maybe};
|
||||
use zed_llm_client::{
|
||||
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
|
||||
CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
|
||||
EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
|
||||
};
|
||||
|
||||
use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
|
||||
use crate::provider::google::{GoogleEventMapper, into_google};
|
||||
@@ -120,10 +120,10 @@ pub struct State {
|
||||
user_store: Entity<UserStore>,
|
||||
status: client::Status,
|
||||
accept_terms_of_service_task: Option<Task<Result<()>>>,
|
||||
models: Vec<Arc<cloud_llm_client::LanguageModel>>,
|
||||
default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
|
||||
default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
|
||||
recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
|
||||
models: Vec<Arc<zed_llm_client::LanguageModel>>,
|
||||
default_model: Option<Arc<zed_llm_client::LanguageModel>>,
|
||||
default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
|
||||
recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
|
||||
_fetch_models_task: Task<()>,
|
||||
_settings_subscription: Subscription,
|
||||
_llm_token_subscription: Subscription,
|
||||
@@ -238,8 +238,8 @@ impl State {
|
||||
// Right now we represent thinking variants of models as separate models on the client,
|
||||
// so we need to insert variants for any model that supports thinking.
|
||||
if model.supports_thinking {
|
||||
models.push(Arc::new(cloud_llm_client::LanguageModel {
|
||||
id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
|
||||
models.push(Arc::new(zed_llm_client::LanguageModel {
|
||||
id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
|
||||
display_name: format!("{} Thinking", model.display_name),
|
||||
..model
|
||||
}));
|
||||
@@ -328,7 +328,7 @@ impl CloudLanguageModelProvider {
|
||||
|
||||
fn create_language_model(
|
||||
&self,
|
||||
model: Arc<cloud_llm_client::LanguageModel>,
|
||||
model: Arc<zed_llm_client::LanguageModel>,
|
||||
llm_api_token: LlmApiToken,
|
||||
) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(CloudLanguageModel {
|
||||
@@ -518,7 +518,7 @@ fn render_accept_terms(
|
||||
|
||||
pub struct CloudLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: Arc<cloud_llm_client::LanguageModel>,
|
||||
model: Arc<zed_llm_client::LanguageModel>,
|
||||
llm_api_token: LlmApiToken,
|
||||
client: Arc<Client>,
|
||||
request_limiter: RateLimiter,
|
||||
@@ -611,12 +611,12 @@ impl CloudLanguageModel {
|
||||
.headers()
|
||||
.get(CURRENT_PLAN_HEADER_NAME)
|
||||
.and_then(|plan| plan.to_str().ok())
|
||||
.and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok())
|
||||
.and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
|
||||
{
|
||||
let plan = match plan {
|
||||
cloud_llm_client::Plan::ZedFree => Plan::Free,
|
||||
cloud_llm_client::Plan::ZedPro => Plan::ZedPro,
|
||||
cloud_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
|
||||
zed_llm_client::Plan::ZedFree => Plan::Free,
|
||||
zed_llm_client::Plan::ZedPro => Plan::ZedPro,
|
||||
zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
|
||||
};
|
||||
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
|
||||
}
|
||||
@@ -729,7 +729,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
}
|
||||
|
||||
fn upstream_provider_id(&self) -> LanguageModelProviderId {
|
||||
use cloud_llm_client::LanguageModelProvider::*;
|
||||
use zed_llm_client::LanguageModelProvider::*;
|
||||
match self.model.provider {
|
||||
Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
|
||||
OpenAi => language_model::OPEN_AI_PROVIDER_ID,
|
||||
@@ -738,7 +738,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
}
|
||||
|
||||
fn upstream_provider_name(&self) -> LanguageModelProviderName {
|
||||
use cloud_llm_client::LanguageModelProvider::*;
|
||||
use zed_llm_client::LanguageModelProvider::*;
|
||||
match self.model.provider {
|
||||
Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
|
||||
OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
|
||||
@@ -772,11 +772,11 @@ impl LanguageModel for CloudLanguageModel {
|
||||
|
||||
fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
|
||||
match self.model.provider {
|
||||
cloud_llm_client::LanguageModelProvider::Anthropic
|
||||
| cloud_llm_client::LanguageModelProvider::OpenAi => {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic
|
||||
| zed_llm_client::LanguageModelProvider::OpenAi => {
|
||||
LanguageModelToolSchemaFormat::JsonSchema
|
||||
}
|
||||
cloud_llm_client::LanguageModelProvider::Google => {
|
||||
zed_llm_client::LanguageModelProvider::Google => {
|
||||
LanguageModelToolSchemaFormat::JsonSchemaSubset
|
||||
}
|
||||
}
|
||||
@@ -795,15 +795,15 @@ impl LanguageModel for CloudLanguageModel {
|
||||
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
match &self.model.provider {
|
||||
cloud_llm_client::LanguageModelProvider::Anthropic => {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic => {
|
||||
Some(LanguageModelCacheConfiguration {
|
||||
min_total_token: 2_048,
|
||||
should_speculate: true,
|
||||
max_cache_anchors: 4,
|
||||
})
|
||||
}
|
||||
cloud_llm_client::LanguageModelProvider::OpenAi
|
||||
| cloud_llm_client::LanguageModelProvider::Google => None,
|
||||
zed_llm_client::LanguageModelProvider::OpenAi
|
||||
| zed_llm_client::LanguageModelProvider::Google => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -813,17 +813,15 @@ impl LanguageModel for CloudLanguageModel {
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
match self.model.provider {
|
||||
cloud_llm_client::LanguageModelProvider::Anthropic => {
|
||||
count_anthropic_tokens(request, cx)
|
||||
}
|
||||
cloud_llm_client::LanguageModelProvider::OpenAi => {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
|
||||
zed_llm_client::LanguageModelProvider::OpenAi => {
|
||||
let model = match open_ai::Model::from_id(&self.model.id.0) {
|
||||
Ok(model) => model,
|
||||
Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
|
||||
};
|
||||
count_open_ai_tokens(request, model, cx)
|
||||
}
|
||||
cloud_llm_client::LanguageModelProvider::Google => {
|
||||
zed_llm_client::LanguageModelProvider::Google => {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let model_id = self.model.id.to_string();
|
||||
@@ -834,7 +832,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
||||
let request_body = CountTokensBody {
|
||||
provider: cloud_llm_client::LanguageModelProvider::Google,
|
||||
provider: zed_llm_client::LanguageModelProvider::Google,
|
||||
model: model_id,
|
||||
provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
|
||||
generate_content_request,
|
||||
@@ -895,7 +893,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
|
||||
let thinking_allowed = request.thinking_allowed;
|
||||
match self.model.provider {
|
||||
cloud_llm_client::LanguageModelProvider::Anthropic => {
|
||||
zed_llm_client::LanguageModelProvider::Anthropic => {
|
||||
let request = into_anthropic(
|
||||
request,
|
||||
self.model.id.to_string(),
|
||||
@@ -926,7 +924,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
prompt_id,
|
||||
intent,
|
||||
mode,
|
||||
provider: cloud_llm_client::LanguageModelProvider::Anthropic,
|
||||
provider: zed_llm_client::LanguageModelProvider::Anthropic,
|
||||
model: request.model.clone(),
|
||||
provider_request: serde_json::to_value(&request)
|
||||
.map_err(|e| anyhow!(e))?,
|
||||
@@ -950,7 +948,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
cloud_llm_client::LanguageModelProvider::OpenAi => {
|
||||
zed_llm_client::LanguageModelProvider::OpenAi => {
|
||||
let client = self.client.clone();
|
||||
let model = match open_ai::Model::from_id(&self.model.id.0) {
|
||||
Ok(model) => model,
|
||||
@@ -978,7 +976,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
prompt_id,
|
||||
intent,
|
||||
mode,
|
||||
provider: cloud_llm_client::LanguageModelProvider::OpenAi,
|
||||
provider: zed_llm_client::LanguageModelProvider::OpenAi,
|
||||
model: request.model.clone(),
|
||||
provider_request: serde_json::to_value(&request)
|
||||
.map_err(|e| anyhow!(e))?,
|
||||
@@ -998,7 +996,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
});
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
cloud_llm_client::LanguageModelProvider::Google => {
|
||||
zed_llm_client::LanguageModelProvider::Google => {
|
||||
let client = self.client.clone();
|
||||
let request =
|
||||
into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
|
||||
@@ -1018,7 +1016,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||
prompt_id,
|
||||
intent,
|
||||
mode,
|
||||
provider: cloud_llm_client::LanguageModelProvider::Google,
|
||||
provider: zed_llm_client::LanguageModelProvider::Google,
|
||||
model: request.model.model_id.clone(),
|
||||
provider_request: serde_json::to_value(&request)
|
||||
.map_err(|e| anyhow!(e))?,
|
||||
|
||||
@@ -3,7 +3,6 @@ use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashMap;
|
||||
use copilot::copilot_chat::{
|
||||
ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
|
||||
@@ -31,6 +30,7 @@ use settings::SettingsStore;
|
||||
use std::time::Duration;
|
||||
use ui::prelude::*;
|
||||
use util::debug_panic;
|
||||
use zed_llm_client::CompletionIntent;
|
||||
|
||||
use super::anthropic::count_anthropic_tokens;
|
||||
use super::google::count_google_tokens;
|
||||
|
||||
474
crates/language_models/src/provider/local.rs
Normal file
474
crates/language_models/src/provider/local.rs
Normal file
@@ -0,0 +1,474 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures::{FutureExt, SinkExt, StreamExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, MessageContent, RateLimiter, Role, StopReason,
|
||||
};
|
||||
use mistralrs::{
|
||||
IsqType, Model as MistralModel, Response as MistralResponse, TextMessageRole, TextMessages,
|
||||
TextModelBuilder,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use ui::{ButtonLike, IconName, Indicator, prelude::*};
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
|
||||
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
|
||||
const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
|
||||
|
||||
#[derive(Default, Debug, Clone, PartialEq)]
|
||||
pub struct LocalSettings {
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub display_name: Option<String>,
|
||||
pub max_tokens: u64,
|
||||
}
|
||||
|
||||
pub struct LocalLanguageModelProvider {
|
||||
state: Entity<State>,
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
model: Option<Arc<MistralModel>>,
|
||||
status: ModelStatus,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
enum ModelStatus {
|
||||
NotLoaded,
|
||||
Loading,
|
||||
Loaded,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn new(_cx: &mut Context<Self>) -> Self {
|
||||
Self {
|
||||
model: None,
|
||||
status: ModelStatus::NotLoaded,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_authenticated(&self) -> bool {
|
||||
// Local models don't require authentication
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
// Skip if already loaded or currently loading
|
||||
if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
self.status = ModelStatus::Loading;
|
||||
cx.notify();
|
||||
|
||||
let background_executor = cx.background_executor().clone();
|
||||
cx.spawn(async move |this, cx| {
|
||||
eprintln!("Local model: Starting to load model");
|
||||
|
||||
// Move the model loading to a background thread
|
||||
let model_result = background_executor
|
||||
.spawn(async move { load_mistral_model().await })
|
||||
.await;
|
||||
|
||||
match model_result {
|
||||
Ok(model) => {
|
||||
eprintln!("Local model: Model loaded successfully");
|
||||
this.update(cx, |state, cx| {
|
||||
state.model = Some(model);
|
||||
state.status = ModelStatus::Loaded;
|
||||
cx.notify();
|
||||
eprintln!("Local model: Status updated to Loaded");
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
let error_msg = e.to_string();
|
||||
eprintln!("Local model: Failed to load model - {}", error_msg);
|
||||
this.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Error(error_msg.clone());
|
||||
cx.notify();
|
||||
eprintln!("Local model: Status updated to Failed");
|
||||
})?;
|
||||
Err(AuthenticateError::Other(anyhow!(
|
||||
"Failed to load model: {}",
|
||||
error_msg
|
||||
)))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
|
||||
println!("\n\n\n\nLoading mistral model...\n\n\n");
|
||||
eprintln!("Starting to load model: {}", DEFAULT_MODEL);
|
||||
|
||||
// Configure the model builder to use background threads for downloads
|
||||
eprintln!("Creating TextModelBuilder...");
|
||||
let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
|
||||
|
||||
eprintln!("Building model (this should be quick for a 0.5B model)...");
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
match builder.build().await {
|
||||
Ok(model) => {
|
||||
let elapsed = start_time.elapsed();
|
||||
eprintln!("Model loaded successfully in {:?}", elapsed);
|
||||
Ok(Arc::new(model))
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to load model: {:?}", e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalLanguageModelProvider {
|
||||
pub fn new(_http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(State::new);
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for LocalLanguageModelProvider {
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for LocalLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::Ai
|
||||
}
|
||||
|
||||
fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
vec![Arc::new(LocalLanguageModel {
|
||||
state: self.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})]
|
||||
}
|
||||
|
||||
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.provided_models(cx).into_iter().next()
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
self.default_model(cx)
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, _cx: &App) -> bool {
|
||||
// Local models don't require authentication
|
||||
true
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
||||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, _window: &mut gpui::Window, cx: &mut App) -> AnyView {
|
||||
cx.new(|_cx| ConfigurationView {
|
||||
state: self.state.clone(),
|
||||
})
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| {
|
||||
state.model = None;
|
||||
state.status = ModelStatus::NotLoaded;
|
||||
cx.notify();
|
||||
});
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LocalLanguageModel {
|
||||
state: Entity<State>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl LocalLanguageModel {
|
||||
fn to_mistral_messages(&self, request: &LanguageModelRequest) -> TextMessages {
|
||||
let mut messages = TextMessages::new();
|
||||
|
||||
for message in &request.messages {
|
||||
let mut text_content = String::new();
|
||||
|
||||
for content in &message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
text_content.push_str(text);
|
||||
}
|
||||
MessageContent::Image { .. } => {
|
||||
// For now, skip image content
|
||||
continue;
|
||||
}
|
||||
MessageContent::ToolResult { .. } => {
|
||||
// Skip tool results for now
|
||||
continue;
|
||||
}
|
||||
MessageContent::Thinking { .. } => {
|
||||
// Skip thinking content
|
||||
continue;
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {
|
||||
// Skip redacted thinking
|
||||
continue;
|
||||
}
|
||||
MessageContent::ToolUse(_) => {
|
||||
// Skip tool use
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if text_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let role = match message.role {
|
||||
Role::User => TextMessageRole::User,
|
||||
Role::Assistant => TextMessageRole::Assistant,
|
||||
Role::System => TextMessageRole::System,
|
||||
};
|
||||
|
||||
messages = messages.add_message(role, text_content);
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for LocalLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
LanguageModelId(DEFAULT_MODEL.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName(DEFAULT_MODEL.into())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
PROVIDER_ID
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
PROVIDER_NAME
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("local/{}", DEFAULT_MODEL)
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> u64 {
|
||||
128000 // Qwen2.5 supports 128k context
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
_cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
// Rough estimation: 1 token ≈ 4 characters
|
||||
let mut total_chars = 0;
|
||||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) => total_chars += text.len(),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
let tokens = (total_chars / 4) as u64;
|
||||
futures::future::ready(Ok(tokens)).boxed()
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let messages = self.to_mistral_messages(&request);
|
||||
let state = self.state.clone();
|
||||
let limiter = self.request_limiter.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let result: Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
> = limiter
|
||||
.run(async move {
|
||||
let model = cx
|
||||
.read_entity(&state, |state, _| {
|
||||
eprintln!(
|
||||
"Local model: Checking if model is loaded: {:?}",
|
||||
state.status
|
||||
);
|
||||
state.model.clone()
|
||||
})
|
||||
.map_err(|_| {
|
||||
LanguageModelCompletionError::Other(anyhow!("App state dropped"))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
eprintln!("Local model: Model is not loaded!");
|
||||
LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
|
||||
})?;
|
||||
|
||||
let (mut tx, rx) = mpsc::channel(32);
|
||||
|
||||
// Spawn a task to handle the stream
|
||||
let _ = smol::spawn(async move {
|
||||
let mut stream = match model.stream_chat_request(messages).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
let _ = tx
|
||||
.send(Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
"Failed to start stream: {}",
|
||||
e
|
||||
))))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
let event = match response {
|
||||
MistralResponse::Chunk(chunk) => {
|
||||
if let Some(choice) = chunk.choices.first() {
|
||||
if let Some(content) = &choice.delta.content {
|
||||
Some(Ok(LanguageModelCompletionEvent::Text(
|
||||
content.clone(),
|
||||
)))
|
||||
} else if let Some(finish_reason) = &choice.finish_reason {
|
||||
let stop_reason = match finish_reason.as_str() {
|
||||
"stop" => StopReason::EndTurn,
|
||||
"length" => StopReason::MaxTokens,
|
||||
_ => StopReason::EndTurn,
|
||||
};
|
||||
Some(Ok(LanguageModelCompletionEvent::Stop(
|
||||
stop_reason,
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
MistralResponse::Done(_response) => {
|
||||
// For now, we don't emit usage events since the format doesn't match
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(event) = event {
|
||||
if tx.send(event).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
Ok(rx.boxed())
|
||||
})
|
||||
.await;
|
||||
|
||||
result
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
state: Entity<State>,
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _window: &mut gpui::Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let status = self.state.read(cx).status.clone();
|
||||
|
||||
div().size_full().child(
|
||||
div()
|
||||
.p_4()
|
||||
.child(
|
||||
div()
|
||||
.flex()
|
||||
.gap_2()
|
||||
.items_center()
|
||||
.child(match &status {
|
||||
ModelStatus::NotLoaded => Label::new("Model not loaded"),
|
||||
ModelStatus::Loading => Label::new("Loading model..."),
|
||||
ModelStatus::Loaded => Label::new("Model loaded"),
|
||||
ModelStatus::Error(e) => Label::new(format!("Error: {}", e)),
|
||||
})
|
||||
.child(match &status {
|
||||
ModelStatus::NotLoaded => Indicator::dot().color(Color::Disabled),
|
||||
ModelStatus::Loading => Indicator::dot().color(Color::Modified),
|
||||
ModelStatus::Loaded => Indicator::dot().color(Color::Success),
|
||||
ModelStatus::Error(_) => Indicator::dot().color(Color::Error),
|
||||
}),
|
||||
)
|
||||
.when(!matches!(status, ModelStatus::Loading), |this| {
|
||||
this.child(
|
||||
ButtonLike::new("load_model")
|
||||
.child(Label::new(if matches!(status, ModelStatus::Loaded) {
|
||||
"Reload Model"
|
||||
} else {
|
||||
"Load Model"
|
||||
}))
|
||||
.on_click(cx.listener(|this, _, _window, cx| {
|
||||
this.state.update(cx, |state, cx| {
|
||||
state.authenticate(cx).detach();
|
||||
});
|
||||
})),
|
||||
)
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
259
crates/language_models/src/provider/local/tests.rs
Normal file
259
crates/language_models/src/provider/local/tests.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use http_client::FakeHttpClient;
|
||||
use language_model::{LanguageModelRequest, MessageContent, Role};
|
||||
|
||||
#[gpui::test]
|
||||
fn test_local_provider_creation(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
cx.read(|cx| {
|
||||
assert_eq!(provider.id(), PROVIDER_ID);
|
||||
assert_eq!(provider.name(), PROVIDER_NAME);
|
||||
assert!(!provider.is_authenticated(cx));
|
||||
assert_eq!(provider.provided_models(cx).len(), 1);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_state_initialization(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let state = cx.new(State::new);
|
||||
|
||||
assert!(!state.read(cx).is_authenticated());
|
||||
assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
|
||||
assert!(state.read(cx).model.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_model_properties(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
// Create a model directly for testing (bypassing authentication)
|
||||
let model = LocalLanguageModel {
|
||||
state: provider.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
};
|
||||
|
||||
assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.provider_id(), PROVIDER_ID);
|
||||
assert_eq!(model.provider_name(), PROVIDER_NAME);
|
||||
assert_eq!(model.max_token_count(), 128000);
|
||||
assert!(!model.supports_tools());
|
||||
assert!(!model.supports_images());
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_token_counting(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
let model = LocalLanguageModel {
|
||||
state: provider.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
};
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
messages: vec![language_model::LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("Hello, world!".to_string())],
|
||||
cache: false,
|
||||
}],
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
thinking_allowed: false,
|
||||
};
|
||||
|
||||
let count = cx
|
||||
.update(|cx| model.count_tokens(request, cx))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// "Hello, world!" is 13 characters, so ~3 tokens
|
||||
assert!(count > 0);
|
||||
assert!(count < 10);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_message_conversion(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
let model = LocalLanguageModel {
|
||||
state: provider.state.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
};
|
||||
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
messages: vec![
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: Role::System,
|
||||
content: vec![MessageContent::Text(
|
||||
"You are a helpful assistant.".to_string(),
|
||||
)],
|
||||
cache: false,
|
||||
},
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("Hello!".to_string())],
|
||||
cache: false,
|
||||
},
|
||||
language_model::LanguageModelRequestMessage {
|
||||
role: Role::Assistant,
|
||||
content: vec![MessageContent::Text("Hi there!".to_string())],
|
||||
cache: false,
|
||||
},
|
||||
],
|
||||
tools: Vec::new(),
|
||||
tool_choice: None,
|
||||
stop: Vec::new(),
|
||||
temperature: None,
|
||||
thinking_allowed: false,
|
||||
};
|
||||
|
||||
let _messages = model.to_mistral_messages(&request);
|
||||
// We can't directly inspect TextMessages, but we can verify it doesn't panic
|
||||
assert!(true); // Placeholder assertion
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_reset_credentials(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
// Simulate loading a model by just setting the status
|
||||
cx.update(|cx| {
|
||||
provider.state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Loaded;
|
||||
// We don't actually set a model since we can't mock it safely
|
||||
cx.notify();
|
||||
});
|
||||
});
|
||||
|
||||
cx.read(|cx| {
|
||||
// Since is_authenticated checks for model presence, we need to check status directly
|
||||
assert_eq!(provider.state.read(cx).status, ModelStatus::Loaded);
|
||||
});
|
||||
|
||||
// Reset credentials
|
||||
let task = cx.update(|cx| provider.reset_credentials(cx));
|
||||
task.await.unwrap();
|
||||
|
||||
cx.read(|cx| {
|
||||
assert!(!provider.is_authenticated(cx));
|
||||
assert_eq!(provider.state.read(cx).status, ModelStatus::NotLoaded);
|
||||
assert!(provider.state.read(cx).model.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: Fix this test - need to handle window creation in tests
|
||||
// #[gpui::test]
|
||||
// async fn test_configuration_view_rendering(cx: &mut TestAppContext) {
|
||||
// let http_client = FakeHttpClient::with_200_response();
|
||||
// let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
// let view = cx.update(|cx| provider.configuration_view(cx.window(), cx));
|
||||
|
||||
// // Basic test to ensure the view can be created without panicking
|
||||
// assert!(view.entity_type() == std::any::TypeId::of::<ConfigurationView>());
|
||||
// }
|
||||
|
||||
#[gpui::test]
|
||||
fn test_status_transitions(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let state = cx.new(State::new);
|
||||
|
||||
// Initial state
|
||||
assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
|
||||
|
||||
// Transition to loading
|
||||
state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Loading;
|
||||
cx.notify();
|
||||
});
|
||||
assert_eq!(state.read(cx).status, ModelStatus::Loading);
|
||||
|
||||
// Transition to loaded
|
||||
state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Loaded;
|
||||
cx.notify();
|
||||
});
|
||||
assert_eq!(state.read(cx).status, ModelStatus::Loaded);
|
||||
|
||||
// Transition to error
|
||||
state.update(cx, |state, cx| {
|
||||
state.status = ModelStatus::Error("Test error".to_string());
|
||||
cx.notify();
|
||||
});
|
||||
match &state.read(cx).status {
|
||||
ModelStatus::Error(msg) => assert_eq!(msg, "Test error"),
|
||||
_ => panic!("Expected error status"),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_provider_shows_models_without_authentication(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
cx.read(|cx| {
|
||||
// Provider should show models even when not authenticated
|
||||
let models = provider.provided_models(cx);
|
||||
assert_eq!(models.len(), 1);
|
||||
|
||||
let model = &models[0];
|
||||
assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
|
||||
assert_eq!(model.provider_id(), PROVIDER_ID);
|
||||
assert_eq!(model.provider_name(), PROVIDER_NAME);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_provider_has_icon(cx: &mut TestAppContext) {
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
|
||||
|
||||
assert_eq!(provider.icon(), IconName::Ai);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
fn test_provider_appears_in_registry(cx: &mut TestAppContext) {
|
||||
use language_model::LanguageModelRegistry;
|
||||
|
||||
cx.update(|cx| {
|
||||
let registry = cx.new(|_| LanguageModelRegistry::default());
|
||||
let http_client = FakeHttpClient::with_200_response();
|
||||
|
||||
// Register the local provider
|
||||
registry.update(cx, |registry, cx| {
|
||||
let provider = LocalLanguageModelProvider::new(Arc::new(http_client), cx);
|
||||
registry.register_provider(provider, cx);
|
||||
});
|
||||
|
||||
// Verify the provider is registered
|
||||
let provider = registry.read(cx).provider(&PROVIDER_ID).unwrap();
|
||||
assert_eq!(provider.name(), PROVIDER_NAME);
|
||||
assert_eq!(provider.icon(), IconName::Ai);
|
||||
|
||||
// Verify it provides models even without authentication
|
||||
let models = provider.provided_models(cx);
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].id(), LanguageModelId(DEFAULT_MODEL.into()));
|
||||
});
|
||||
}
|
||||
@@ -421,14 +421,14 @@ impl LanguageServer {
|
||||
.map(|stderr| {
|
||||
let io_handlers = io_handlers.clone();
|
||||
let stderr_captures = stderr_capture.clone();
|
||||
cx.background_spawn(async move {
|
||||
cx.spawn(async move |_| {
|
||||
Self::handle_stderr(stderr, io_handlers, stderr_captures)
|
||||
.log_err()
|
||||
.await
|
||||
})
|
||||
})
|
||||
.unwrap_or_else(|| Task::ready(None));
|
||||
let input_task = cx.background_spawn(async move {
|
||||
let input_task = cx.spawn(async move |_| {
|
||||
let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
|
||||
stdout.or(stderr)
|
||||
});
|
||||
@@ -846,7 +846,7 @@ impl LanguageServer {
|
||||
configuration: Arc<DidChangeConfigurationParams>,
|
||||
cx: &App,
|
||||
) -> Task<Result<Arc<Self>>> {
|
||||
cx.background_spawn(async move {
|
||||
cx.spawn(async move |_| {
|
||||
let response = self
|
||||
.request::<request::Initialize>(params)
|
||||
.await
|
||||
|
||||
@@ -18,19 +18,12 @@ default = []
|
||||
anyhow.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
db.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
fs.workspace = true
|
||||
gpui.workspace = true
|
||||
language.workspace = true
|
||||
project.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
settings.workspace = true
|
||||
theme.workspace = true
|
||||
ui.workspace = true
|
||||
util.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
workspace.workspace = true
|
||||
workspace-hack.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zlog.workspace = true
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user