Compare commits
93 Commits
v0.217.0-p
...
migrate-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cc517e0dd | ||
|
|
d1390a5b78 | ||
|
|
ee4faede38 | ||
|
|
8d96a699b3 | ||
|
|
8cfb7471db | ||
|
|
def9c87837 | ||
|
|
7ed5d42696 | ||
|
|
0313ab6d41 | ||
|
|
c5329fdff2 | ||
|
|
a676a6895b | ||
|
|
3b5d7d7d89 | ||
|
|
91f01131b1 | ||
|
|
5fa5226286 | ||
|
|
ae94007227 | ||
|
|
25d74480aa | ||
|
|
37077a8ebb | ||
|
|
7c4a85f5f1 | ||
|
|
d21628c349 | ||
|
|
9e628505f3 | ||
|
|
3a84ec38ac | ||
|
|
8f425a1bd5 | ||
|
|
743c414e7b | ||
|
|
0fe335efc5 | ||
|
|
a61bf33fb0 | ||
|
|
36b95aac4b | ||
|
|
b2df70ab58 | ||
|
|
d83201256d | ||
|
|
8ee85eab3c | ||
|
|
5b309ef986 | ||
|
|
326ebb5230 | ||
|
|
f5babf96e1 | ||
|
|
f48aa252f8 | ||
|
|
4106c8a188 | ||
|
|
36293d7dd9 | ||
|
|
3ae3e1fce8 | ||
|
|
e5f1fc7478 | ||
|
|
a4f6076da7 | ||
|
|
43726b2620 | ||
|
|
94980ffb49 | ||
|
|
22cc731450 | ||
|
|
d9396373e3 | ||
|
|
48002be135 | ||
|
|
58db83f8f5 | ||
|
|
0243d5b542 | ||
|
|
06230327fa | ||
|
|
ca5c8992f9 | ||
|
|
1038e1c2ef | ||
|
|
e1fe0b3287 | ||
|
|
a0e10a91bf | ||
|
|
272b1aa4bc | ||
|
|
9ef0537b44 | ||
|
|
77f1de742b | ||
|
|
e054cabd41 | ||
|
|
3b95cb5682 | ||
|
|
c89653bd07 | ||
|
|
b90ac2dc07 | ||
|
|
c9998541f0 | ||
|
|
e2b49b3cd3 | ||
|
|
d1e77397c6 | ||
|
|
cc5f5e35e4 | ||
|
|
7183b8a1cd | ||
|
|
b1934fb712 | ||
|
|
a198b6c0d1 | ||
|
|
8b5b2712c8 | ||
|
|
4464392e8e | ||
|
|
a0d3bc31e9 | ||
|
|
ccd6672d1a | ||
|
|
21de6d35dd | ||
|
|
2031ca17e5 | ||
|
|
8b1ce75a57 | ||
|
|
5559726fd7 | ||
|
|
e1a9269921 | ||
|
|
3b6b3ff504 | ||
|
|
aabed94970 | ||
|
|
2d3a3521ba | ||
|
|
a48bd10da0 | ||
|
|
fec9525be4 | ||
|
|
bf2b8e999e | ||
|
|
63c35d2b00 | ||
|
|
1396c68010 | ||
|
|
fcb3d3dec6 | ||
|
|
f54e7f8c9d | ||
|
|
2a89529d7f | ||
|
|
58207325e2 | ||
|
|
e08ab99e8d | ||
|
|
a95f3f33a4 | ||
|
|
b0767c1b1f | ||
|
|
b200e10bc4 | ||
|
|
948905d916 | ||
|
|
04de456373 | ||
|
|
e5ce32e936 | ||
|
|
d7caae30de | ||
|
|
c7e77674a1 |
2
.github/workflows/run_tests.yml
vendored
2
.github/workflows/run_tests.yml
vendored
@@ -497,6 +497,8 @@ jobs:
|
||||
env:
|
||||
GIT_AUTHOR_NAME: Protobuf Action
|
||||
GIT_AUTHOR_EMAIL: ci@zed.dev
|
||||
GIT_COMMITTER_NAME: Protobuf Action
|
||||
GIT_COMMITTER_EMAIL: ci@zed.dev
|
||||
steps:
|
||||
- name: steps::checkout_repo
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
|
||||
|
||||
46
Cargo.lock
generated
46
Cargo.lock
generated
@@ -3111,16 +3111,6 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cloud_zeta2_prompt"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"cloud_llm_client",
|
||||
"indoc",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.54"
|
||||
@@ -3595,6 +3585,7 @@ dependencies = [
|
||||
"settings",
|
||||
"smol",
|
||||
"tempfile",
|
||||
"terminal",
|
||||
"url",
|
||||
"util",
|
||||
]
|
||||
@@ -5118,7 +5109,6 @@ dependencies = [
|
||||
"clock",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"cloud_zeta2_prompt",
|
||||
"collections",
|
||||
"copilot",
|
||||
"credentials_provider",
|
||||
@@ -5149,8 +5139,6 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"strsim",
|
||||
"strum 0.27.2",
|
||||
"telemetry",
|
||||
"telemetry_events",
|
||||
@@ -5161,6 +5149,7 @@ dependencies = [
|
||||
"workspace",
|
||||
"worktree",
|
||||
"zed_actions",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
@@ -5174,11 +5163,10 @@ dependencies = [
|
||||
"clap",
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"cloud_zeta2_prompt",
|
||||
"collections",
|
||||
"debug_adapter_extension",
|
||||
"dirs 4.0.0",
|
||||
"edit_prediction",
|
||||
"edit_prediction_context",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
@@ -5208,9 +5196,10 @@ dependencies = [
|
||||
"sqlez",
|
||||
"sqlez_macros",
|
||||
"terminal_view",
|
||||
"toml 0.8.23",
|
||||
"util",
|
||||
"wasmtime",
|
||||
"watch",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
@@ -5238,6 +5227,7 @@ dependencies = [
|
||||
"text",
|
||||
"tree-sitter",
|
||||
"util",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
@@ -5259,7 +5249,6 @@ dependencies = [
|
||||
"buffer_diff",
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"cloud_zeta2_prompt",
|
||||
"codestral",
|
||||
"command_palette_hooks",
|
||||
"copilot",
|
||||
@@ -5290,6 +5279,7 @@ dependencies = [
|
||||
"util",
|
||||
"workspace",
|
||||
"zed_actions",
|
||||
"zeta_prompt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5853,9 +5843,12 @@ dependencies = [
|
||||
"async-trait",
|
||||
"client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
"criterion",
|
||||
"ctor",
|
||||
"dap",
|
||||
"dirs 4.0.0",
|
||||
"editor",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
@@ -5864,8 +5857,11 @@ dependencies = [
|
||||
"http_client",
|
||||
"language",
|
||||
"language_extension",
|
||||
"language_model",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown",
|
||||
"menu",
|
||||
"moka",
|
||||
"node_runtime",
|
||||
"parking_lot",
|
||||
@@ -5880,12 +5876,14 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smol",
|
||||
"task",
|
||||
"telemetry",
|
||||
"tempfile",
|
||||
"theme",
|
||||
"theme_extension",
|
||||
"toml 0.8.23",
|
||||
"ui",
|
||||
"url",
|
||||
"util",
|
||||
"wasmparser 0.221.3",
|
||||
@@ -8852,6 +8850,8 @@ dependencies = [
|
||||
"credentials_provider",
|
||||
"deepseek",
|
||||
"editor",
|
||||
"extension",
|
||||
"extension_host",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
"google_ai",
|
||||
@@ -13156,6 +13156,7 @@ dependencies = [
|
||||
"askpass",
|
||||
"auto_update",
|
||||
"dap",
|
||||
"db",
|
||||
"editor",
|
||||
"extension_host",
|
||||
"file_finder",
|
||||
@@ -13167,6 +13168,7 @@ dependencies = [
|
||||
"log",
|
||||
"markdown",
|
||||
"menu",
|
||||
"node_runtime",
|
||||
"ordered-float 2.10.1",
|
||||
"paths",
|
||||
"picker",
|
||||
@@ -13185,6 +13187,7 @@ dependencies = [
|
||||
"util",
|
||||
"windows-registry 0.6.1",
|
||||
"workspace",
|
||||
"worktree",
|
||||
"zed_actions",
|
||||
]
|
||||
|
||||
@@ -20469,7 +20472,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zed"
|
||||
version = "0.217.0"
|
||||
version = "0.218.0"
|
||||
dependencies = [
|
||||
"acp_tools",
|
||||
"activity_indicator",
|
||||
@@ -20929,6 +20932,13 @@ dependencies = [
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeta_prompt"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.6"
|
||||
|
||||
@@ -32,7 +32,6 @@ members = [
|
||||
"crates/cloud_api_client",
|
||||
"crates/cloud_api_types",
|
||||
"crates/cloud_llm_client",
|
||||
"crates/cloud_zeta2_prompt",
|
||||
"crates/collab",
|
||||
"crates/collab_ui",
|
||||
"crates/collections",
|
||||
@@ -202,6 +201,7 @@ members = [
|
||||
"crates/zed_actions",
|
||||
"crates/zed_env_vars",
|
||||
"crates/edit_prediction_cli",
|
||||
"crates/zeta_prompt",
|
||||
"crates/zlog",
|
||||
"crates/zlog_settings",
|
||||
"crates/ztracing",
|
||||
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
|
||||
cloud_api_client = { path = "crates/cloud_api_client" }
|
||||
cloud_api_types = { path = "crates/cloud_api_types" }
|
||||
cloud_llm_client = { path = "crates/cloud_llm_client" }
|
||||
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
|
||||
collab_ui = { path = "crates/collab_ui" }
|
||||
collections = { path = "crates/collections", version = "0.1.0" }
|
||||
command_palette = { path = "crates/command_palette" }
|
||||
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
|
||||
zed_actions = { path = "crates/zed_actions" }
|
||||
zed_env_vars = { path = "crates/zed_env_vars" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
zeta_prompt = { path = "crates/zeta_prompt" }
|
||||
zlog = { path = "crates/zlog" }
|
||||
zlog_settings = { path = "crates/zlog_settings" }
|
||||
ztracing = { path = "crates/ztracing" }
|
||||
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
|
||||
tiny_http = "0.8"
|
||||
tokio = { version = "1" }
|
||||
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
|
||||
toml = "0.8"
|
||||
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
|
||||
tower-http = "0.4.4"
|
||||
|
||||
5
assets/icons/box.svg
Normal file
5
assets/icons/box.svg
Normal file
@@ -0,0 +1,5 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M13.3996 5.59852C13.3994 5.3881 13.3439 5.18144 13.2386 4.99926C13.1333 4.81709 12.9819 4.66581 12.7997 4.56059L8.59996 2.16076C8.41755 2.05544 8.21063 2 8 2C7.78937 2 7.58246 2.05544 7.40004 2.16076L3.20033 4.56059C3.0181 4.66581 2.86674 4.81709 2.76144 4.99926C2.65613 5.18144 2.60059 5.3881 2.60037 5.59852V10.3982C2.60059 10.6086 2.65613 10.8153 2.76144 10.9975C2.86674 11.1796 3.0181 11.3309 3.20033 11.4361L7.40004 13.836C7.58246 13.9413 7.78937 13.9967 8 13.9967C8.21063 13.9967 8.41755 13.9413 8.59996 13.836L12.7997 11.4361C12.9819 11.3309 13.1333 11.1796 13.2386 10.9975C13.3439 10.8153 13.3994 10.6086 13.3996 10.3982V5.59852Z" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2.78033 4.99857L7.99998 7.99836L13.2196 4.99857" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M8 13.9979V7.99829" stroke="white" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
@@ -180,7 +180,6 @@
|
||||
"ctrl-w g shift-d": "editor::GoToTypeDefinitionSplit",
|
||||
"ctrl-w space": "editor::OpenExcerptsSplit",
|
||||
"ctrl-w g space": "editor::OpenExcerptsSplit",
|
||||
"ctrl-6": "pane::AlternateFile",
|
||||
"ctrl-^": "pane::AlternateFile",
|
||||
".": "vim::Repeat"
|
||||
}
|
||||
|
||||
@@ -870,6 +870,10 @@
|
||||
//
|
||||
// Default: false
|
||||
"collapse_untracked_diff": false,
|
||||
/// Whether to show entries with tree or flat view in the panel
|
||||
///
|
||||
/// Default: false
|
||||
"tree_view": false,
|
||||
"scrollbar": {
|
||||
// When to show the scrollbar in the git panel.
|
||||
//
|
||||
@@ -1721,7 +1725,12 @@
|
||||
// If you don't want any of these extensions, add this field to your settings
|
||||
// and change the value to `false`.
|
||||
"auto_install_extensions": {
|
||||
"html": true
|
||||
"html": true,
|
||||
"copilot-chat": true,
|
||||
"anthropic": true,
|
||||
"google-ai": true,
|
||||
"openai": true,
|
||||
"openrouter": true,
|
||||
},
|
||||
// The capabilities granted to extensions.
|
||||
//
|
||||
|
||||
@@ -1372,7 +1372,7 @@ impl AcpThread {
|
||||
let path_style = self.project.read(cx).path_style(cx);
|
||||
let id = update.tool_call_id.clone();
|
||||
|
||||
let agent = self.connection().telemetry_id();
|
||||
let agent_telemetry_id = self.connection().telemetry_id();
|
||||
let session = self.session_id();
|
||||
if let ToolCallStatus::Completed | ToolCallStatus::Failed = status {
|
||||
let status = if matches!(status, ToolCallStatus::Completed) {
|
||||
@@ -1380,7 +1380,12 @@ impl AcpThread {
|
||||
} else {
|
||||
"failed"
|
||||
};
|
||||
telemetry::event!("Agent Tool Call Completed", agent, session, status);
|
||||
telemetry::event!(
|
||||
"Agent Tool Call Completed",
|
||||
agent_telemetry_id,
|
||||
session,
|
||||
status
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(ix) = self.index_for_tool_call(&id) {
|
||||
@@ -3556,8 +3561,8 @@ mod tests {
|
||||
}
|
||||
|
||||
impl AgentConnection for FakeAgentConnection {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"fake"
|
||||
fn telemetry_id(&self) -> SharedString {
|
||||
"fake".into()
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
|
||||
@@ -20,7 +20,7 @@ impl UserMessageId {
|
||||
}
|
||||
|
||||
pub trait AgentConnection {
|
||||
fn telemetry_id(&self) -> &'static str;
|
||||
fn telemetry_id(&self) -> SharedString;
|
||||
|
||||
fn new_thread(
|
||||
self: Rc<Self>,
|
||||
@@ -204,12 +204,21 @@ pub trait AgentModelSelector: 'static {
|
||||
}
|
||||
}
|
||||
|
||||
/// Icon for a model in the model selector.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AgentModelIcon {
|
||||
/// A built-in icon from Zed's icon set.
|
||||
Named(IconName),
|
||||
/// Path to a custom SVG icon file.
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AgentModelInfo {
|
||||
pub id: acp::ModelId,
|
||||
pub name: SharedString,
|
||||
pub description: Option<SharedString>,
|
||||
pub icon: Option<IconName>,
|
||||
pub icon: Option<AgentModelIcon>,
|
||||
}
|
||||
|
||||
impl From<acp::ModelInfo> for AgentModelInfo {
|
||||
@@ -322,8 +331,8 @@ mod test_support {
|
||||
}
|
||||
|
||||
impl AgentConnection for StubAgentConnection {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"stub"
|
||||
fn telemetry_id(&self) -> SharedString {
|
||||
"stub".into()
|
||||
}
|
||||
|
||||
fn auth_methods(&self) -> &[acp::AuthMethod] {
|
||||
|
||||
@@ -777,7 +777,7 @@ impl ActionLog {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ActionLogTelemetry {
|
||||
pub agent_telemetry_id: &'static str,
|
||||
pub agent_telemetry_id: SharedString,
|
||||
pub session_id: Arc<str>,
|
||||
}
|
||||
|
||||
|
||||
@@ -739,7 +739,7 @@ impl ActivityIndicator {
|
||||
extension_store.outstanding_operations().iter().next()
|
||||
{
|
||||
let (message, icon, rotate) = match operation {
|
||||
ExtensionOperation::Install => (
|
||||
ExtensionOperation::Install | ExtensionOperation::AutoInstall => (
|
||||
format!("Installing {extension_id} extension…"),
|
||||
IconName::LoadCircle,
|
||||
true,
|
||||
|
||||
@@ -18,7 +18,7 @@ pub use templates::*;
|
||||
pub use thread::*;
|
||||
pub use tools::*;
|
||||
|
||||
use acp_thread::{AcpThread, AgentModelSelector};
|
||||
use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector};
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use chrono::{DateTime, Utc};
|
||||
@@ -105,7 +105,7 @@ impl LanguageModels {
|
||||
fn refresh_list(&mut self, cx: &App) {
|
||||
let providers = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.into_iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.collect::<Vec<_>>();
|
||||
@@ -161,11 +161,16 @@ impl LanguageModels {
|
||||
model: &Arc<dyn LanguageModel>,
|
||||
provider: &Arc<dyn LanguageModelProvider>,
|
||||
) -> acp_thread::AgentModelInfo {
|
||||
let icon = if let Some(path) = provider.icon_path() {
|
||||
Some(AgentModelIcon::Path(path))
|
||||
} else {
|
||||
Some(AgentModelIcon::Named(provider.icon()))
|
||||
};
|
||||
acp_thread::AgentModelInfo {
|
||||
id: Self::model_id(model),
|
||||
name: model.name().0,
|
||||
description: None,
|
||||
icon: Some(provider.icon()),
|
||||
icon,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -947,8 +952,8 @@ impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
|
||||
}
|
||||
|
||||
impl acp_thread::AgentConnection for NativeAgentConnection {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"zed"
|
||||
fn telemetry_id(&self) -> SharedString {
|
||||
"zed".into()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
@@ -1356,7 +1361,7 @@ mod internal_tests {
|
||||
id: acp::ModelId::new("fake/fake"),
|
||||
name: "Fake".into(),
|
||||
description: None,
|
||||
icon: Some(ui::IconName::ZedAssistant),
|
||||
icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
|
||||
}]
|
||||
)])
|
||||
);
|
||||
|
||||
@@ -21,10 +21,6 @@ impl NativeAgentServer {
|
||||
}
|
||||
|
||||
impl AgentServer for NativeAgentServer {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"zed"
|
||||
}
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"Zed Agent".into()
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@ use futures::io::BufReader;
|
||||
use project::Project;
|
||||
use project::agent_server_store::AgentServerCommand;
|
||||
use serde::Deserialize;
|
||||
use settings::Settings as _;
|
||||
use task::ShellBuilder;
|
||||
#[cfg(windows)]
|
||||
use task::ShellKind;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use std::path::PathBuf;
|
||||
@@ -21,7 +25,7 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntit
|
||||
|
||||
use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent};
|
||||
use terminal::TerminalBuilder;
|
||||
use terminal::terminal_settings::{AlternateScroll, CursorShape};
|
||||
use terminal::terminal_settings::{AlternateScroll, CursorShape, TerminalSettings};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Unsupported version")]
|
||||
@@ -29,7 +33,7 @@ pub struct UnsupportedVersion;
|
||||
|
||||
pub struct AcpConnection {
|
||||
server_name: SharedString,
|
||||
telemetry_id: &'static str,
|
||||
telemetry_id: SharedString,
|
||||
connection: Rc<acp::ClientSideConnection>,
|
||||
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
|
||||
auth_methods: Vec<acp::AuthMethod>,
|
||||
@@ -54,7 +58,6 @@ pub struct AcpSession {
|
||||
|
||||
pub async fn connect(
|
||||
server_name: SharedString,
|
||||
telemetry_id: &'static str,
|
||||
command: AgentServerCommand,
|
||||
root_dir: &Path,
|
||||
default_mode: Option<acp::SessionModeId>,
|
||||
@@ -64,7 +67,6 @@ pub async fn connect(
|
||||
) -> Result<Rc<dyn AgentConnection>> {
|
||||
let conn = AcpConnection::stdio(
|
||||
server_name,
|
||||
telemetry_id,
|
||||
command.clone(),
|
||||
root_dir,
|
||||
default_mode,
|
||||
@@ -81,7 +83,6 @@ const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::ProtocolVersion::V1
|
||||
impl AcpConnection {
|
||||
pub async fn stdio(
|
||||
server_name: SharedString,
|
||||
telemetry_id: &'static str,
|
||||
command: AgentServerCommand,
|
||||
root_dir: &Path,
|
||||
default_mode: Option<acp::SessionModeId>,
|
||||
@@ -89,9 +90,26 @@ impl AcpConnection {
|
||||
is_remote: bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut child = util::command::new_smol_command(&command.path);
|
||||
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
|
||||
let builder = ShellBuilder::new(&shell, cfg!(windows));
|
||||
#[cfg(windows)]
|
||||
let kind = builder.kind();
|
||||
let (cmd, args) = builder.build(Some(command.path.display().to_string()), &command.args);
|
||||
|
||||
let mut child = util::command::new_smol_command(cmd);
|
||||
#[cfg(windows)]
|
||||
if kind == ShellKind::Cmd {
|
||||
use smol::process::windows::CommandExt;
|
||||
for arg in args {
|
||||
child.raw_arg(arg);
|
||||
}
|
||||
} else {
|
||||
child.args(args);
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
child.args(args);
|
||||
|
||||
child
|
||||
.args(command.args.iter().map(|arg| arg.as_str()))
|
||||
.envs(command.env.iter().flatten())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
@@ -199,6 +217,13 @@ impl AcpConnection {
|
||||
return Err(UnsupportedVersion.into());
|
||||
}
|
||||
|
||||
let telemetry_id = response
|
||||
.agent_info
|
||||
// Use the one the agent provides if we have one
|
||||
.map(|info| info.name.into())
|
||||
// Otherwise, just use the name
|
||||
.unwrap_or_else(|| server_name.clone());
|
||||
|
||||
Ok(Self {
|
||||
auth_methods: response.auth_methods,
|
||||
root_dir: root_dir.to_owned(),
|
||||
@@ -233,8 +258,8 @@ impl Drop for AcpConnection {
|
||||
}
|
||||
|
||||
impl AgentConnection for AcpConnection {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
self.telemetry_id
|
||||
fn telemetry_id(&self) -> SharedString {
|
||||
self.telemetry_id.clone()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
|
||||
@@ -56,7 +56,6 @@ impl AgentServerDelegate {
|
||||
pub trait AgentServer: Send {
|
||||
fn logo(&self) -> ui::IconName;
|
||||
fn name(&self) -> SharedString;
|
||||
fn telemetry_id(&self) -> &'static str;
|
||||
fn default_mode(&self, _cx: &mut App) -> Option<agent_client_protocol::SessionModeId> {
|
||||
None
|
||||
}
|
||||
|
||||
@@ -22,10 +22,6 @@ pub struct AgentServerLoginCommand {
|
||||
}
|
||||
|
||||
impl AgentServer for ClaudeCode {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"claude-code"
|
||||
}
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"Claude Code".into()
|
||||
}
|
||||
@@ -83,7 +79,6 @@ impl AgentServer for ClaudeCode {
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
|
||||
let name = self.name();
|
||||
let telemetry_id = self.telemetry_id();
|
||||
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
|
||||
let is_remote = delegate.project.read(cx).is_via_remote_server();
|
||||
let store = delegate.store.downgrade();
|
||||
@@ -108,7 +103,6 @@ impl AgentServer for ClaudeCode {
|
||||
.await?;
|
||||
let connection = crate::acp::connect(
|
||||
name,
|
||||
telemetry_id,
|
||||
command,
|
||||
root_dir.as_ref(),
|
||||
default_mode,
|
||||
|
||||
@@ -23,10 +23,6 @@ pub(crate) mod tests {
|
||||
}
|
||||
|
||||
impl AgentServer for Codex {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"codex"
|
||||
}
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"Codex".into()
|
||||
}
|
||||
@@ -84,7 +80,6 @@ impl AgentServer for Codex {
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
|
||||
let name = self.name();
|
||||
let telemetry_id = self.telemetry_id();
|
||||
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
|
||||
let is_remote = delegate.project.read(cx).is_via_remote_server();
|
||||
let store = delegate.store.downgrade();
|
||||
@@ -110,7 +105,6 @@ impl AgentServer for Codex {
|
||||
|
||||
let connection = crate::acp::connect(
|
||||
name,
|
||||
telemetry_id,
|
||||
command,
|
||||
root_dir.as_ref(),
|
||||
default_mode,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{AgentServerDelegate, load_proxy_env};
|
||||
use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Context as _, Result};
|
||||
@@ -20,11 +20,7 @@ impl CustomAgentServer {
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::AgentServer for CustomAgentServer {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"custom"
|
||||
}
|
||||
|
||||
impl AgentServer for CustomAgentServer {
|
||||
fn name(&self) -> SharedString {
|
||||
self.name.clone()
|
||||
}
|
||||
@@ -112,14 +108,12 @@ impl crate::AgentServer for CustomAgentServer {
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
|
||||
let name = self.name();
|
||||
let telemetry_id = self.telemetry_id();
|
||||
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
|
||||
let is_remote = delegate.project.read(cx).is_via_remote_server();
|
||||
let default_mode = self.default_mode(cx);
|
||||
let default_model = self.default_model(cx);
|
||||
let store = delegate.store.downgrade();
|
||||
let extra_env = load_proxy_env(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (command, root_dir, login) = store
|
||||
.update(cx, |store, cx| {
|
||||
@@ -139,7 +133,6 @@ impl crate::AgentServer for CustomAgentServer {
|
||||
.await?;
|
||||
let connection = crate::acp::connect(
|
||||
name,
|
||||
telemetry_id,
|
||||
command,
|
||||
root_dir.as_ref(),
|
||||
default_mode,
|
||||
|
||||
@@ -5,17 +5,13 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, SharedString, Task};
|
||||
use language_models::provider::google::GoogleLanguageModelProvider;
|
||||
use language_models::api_key_for_gemini_cli;
|
||||
use project::agent_server_store::GEMINI_NAME;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Gemini;
|
||||
|
||||
impl AgentServer for Gemini {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"gemini-cli"
|
||||
}
|
||||
|
||||
fn name(&self) -> SharedString {
|
||||
"Gemini CLI".into()
|
||||
}
|
||||
@@ -31,7 +27,6 @@ impl AgentServer for Gemini {
|
||||
cx: &mut App,
|
||||
) -> Task<Result<(Rc<dyn AgentConnection>, Option<task::SpawnInTerminal>)>> {
|
||||
let name = self.name();
|
||||
let telemetry_id = self.telemetry_id();
|
||||
let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().into_owned());
|
||||
let is_remote = delegate.project.read(cx).is_via_remote_server();
|
||||
let store = delegate.store.downgrade();
|
||||
@@ -42,11 +37,7 @@ impl AgentServer for Gemini {
|
||||
cx.spawn(async move |cx| {
|
||||
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
|
||||
|
||||
if let Some(api_key) = cx
|
||||
.update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
|
||||
.await
|
||||
.ok()
|
||||
{
|
||||
if let Some(api_key) = cx.update(api_key_for_gemini_cli)?.await.ok() {
|
||||
extra_env.insert("GEMINI_API_KEY".into(), api_key);
|
||||
}
|
||||
let (command, root_dir, login) = store
|
||||
@@ -66,7 +57,6 @@ impl AgentServer for Gemini {
|
||||
|
||||
let connection = crate::acp::connect(
|
||||
name,
|
||||
telemetry_id,
|
||||
command,
|
||||
root_dir.as_ref(),
|
||||
default_mode,
|
||||
|
||||
@@ -565,8 +565,26 @@ impl MessageEditor {
|
||||
if let Some((workspace, selections)) =
|
||||
self.workspace.upgrade().zip(editor_clipboard_selections)
|
||||
{
|
||||
cx.stop_propagation();
|
||||
let Some(first_selection) = selections.first() else {
|
||||
return;
|
||||
};
|
||||
if let Some(file_path) = &first_selection.file_path {
|
||||
// In case someone pastes selections from another window
|
||||
// with a different project, we don't want to insert the
|
||||
// crease (containing the absolute path) since the agent
|
||||
// cannot access files outside the project.
|
||||
let is_in_project = workspace
|
||||
.read(cx)
|
||||
.project()
|
||||
.read(cx)
|
||||
.project_path_for_absolute_path(file_path, cx)
|
||||
.is_some();
|
||||
if !is_in_project {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
cx.stop_propagation();
|
||||
let insertion_target = self
|
||||
.editor
|
||||
.read(cx)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
@@ -292,12 +292,18 @@ impl PickerDelegate for AcpModelPickerDelegate {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.when_some(model_info.icon, |this, icon| {
|
||||
this.child(
|
||||
Icon::new(icon)
|
||||
.map(|this| match &model_info.icon {
|
||||
Some(AgentModelIcon::Path(path)) => this.child(
|
||||
Icon::from_external_svg(path.clone())
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
)
|
||||
.size(IconSize::Small),
|
||||
),
|
||||
Some(AgentModelIcon::Named(icon)) => this.child(
|
||||
Icon::new(*icon)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
),
|
||||
None => this,
|
||||
})
|
||||
.child(Label::new(model_info.name.clone()).truncate()),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use acp_thread::{AgentModelInfo, AgentModelSelector};
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use fs::Fs;
|
||||
use gpui::{Entity, FocusHandle};
|
||||
@@ -64,7 +64,7 @@ impl Render for AcpModelSelectorPopover {
|
||||
.map(|model| model.name.clone())
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon);
|
||||
let model_icon = model.as_ref().and_then(|model| model.icon.clone());
|
||||
|
||||
let focus_handle = self.focus_handle.clone();
|
||||
|
||||
@@ -78,8 +78,15 @@ impl Render for AcpModelSelectorPopover {
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.when_some(model_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
.when_some(model_icon, |this, icon| match icon {
|
||||
AgentModelIcon::Path(path) => this.child(
|
||||
Icon::from_external_svg(path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
AgentModelIcon::Named(icon_name) => {
|
||||
this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall))
|
||||
}
|
||||
})
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
|
||||
@@ -170,7 +170,7 @@ impl ThreadFeedbackState {
|
||||
}
|
||||
}
|
||||
let session_id = thread.read(cx).session_id().clone();
|
||||
let agent = thread.read(cx).connection().telemetry_id();
|
||||
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
|
||||
let task = telemetry.thread_data(&session_id, cx);
|
||||
let rating = match feedback {
|
||||
ThreadFeedback::Positive => "positive",
|
||||
@@ -180,7 +180,7 @@ impl ThreadFeedbackState {
|
||||
let thread = task.await?;
|
||||
telemetry::event!(
|
||||
"Agent Thread Rated",
|
||||
agent = agent,
|
||||
agent = agent_telemetry_id,
|
||||
session_id = session_id,
|
||||
rating = rating,
|
||||
thread = thread
|
||||
@@ -207,13 +207,13 @@ impl ThreadFeedbackState {
|
||||
self.comments_editor.take();
|
||||
|
||||
let session_id = thread.read(cx).session_id().clone();
|
||||
let agent = thread.read(cx).connection().telemetry_id();
|
||||
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
|
||||
let task = telemetry.thread_data(&session_id, cx);
|
||||
cx.background_spawn(async move {
|
||||
let thread = task.await?;
|
||||
telemetry::event!(
|
||||
"Agent Thread Feedback Comments",
|
||||
agent = agent,
|
||||
agent = agent_telemetry_id,
|
||||
session_id = session_id,
|
||||
comments = comments,
|
||||
thread = thread
|
||||
@@ -333,6 +333,7 @@ impl AcpThreadView {
|
||||
project: Entity<Project>,
|
||||
history_store: Entity<HistoryStore>,
|
||||
prompt_store: Option<Entity<PromptStore>>,
|
||||
track_load_event: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
@@ -391,8 +392,9 @@ impl AcpThreadView {
|
||||
),
|
||||
];
|
||||
|
||||
let show_codex_windows_warning = crate::ExternalAgent::parse_built_in(agent.as_ref())
|
||||
== Some(crate::ExternalAgent::Codex);
|
||||
let show_codex_windows_warning = cfg!(windows)
|
||||
&& project.read(cx).is_local()
|
||||
&& agent.clone().downcast::<agent_servers::Codex>().is_some();
|
||||
|
||||
Self {
|
||||
agent: agent.clone(),
|
||||
@@ -404,6 +406,7 @@ impl AcpThreadView {
|
||||
resume_thread.clone(),
|
||||
workspace.clone(),
|
||||
project.clone(),
|
||||
track_load_event,
|
||||
window,
|
||||
cx,
|
||||
),
|
||||
@@ -448,6 +451,7 @@ impl AcpThreadView {
|
||||
self.resume_thread_metadata.clone(),
|
||||
self.workspace.clone(),
|
||||
self.project.clone(),
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
@@ -461,6 +465,7 @@ impl AcpThreadView {
|
||||
resume_thread: Option<DbThreadMetadata>,
|
||||
workspace: WeakEntity<Workspace>,
|
||||
project: Entity<Project>,
|
||||
track_load_event: bool,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> ThreadState {
|
||||
@@ -519,6 +524,10 @@ impl AcpThreadView {
|
||||
}
|
||||
};
|
||||
|
||||
if track_load_event {
|
||||
telemetry::event!("Agent Thread Started", agent = connection.telemetry_id());
|
||||
}
|
||||
|
||||
let result = if let Some(native_agent) = connection
|
||||
.clone()
|
||||
.downcast::<agent::NativeAgentConnection>()
|
||||
@@ -1133,8 +1142,8 @@ impl AcpThreadView {
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
let agent_telemetry_id = self.agent.telemetry_id();
|
||||
let session_id = thread.read(cx).session_id().clone();
|
||||
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
|
||||
let thread = thread.downgrade();
|
||||
if self.should_be_following {
|
||||
self.workspace
|
||||
@@ -1512,6 +1521,7 @@ impl AcpThreadView {
|
||||
else {
|
||||
return;
|
||||
};
|
||||
let agent_telemetry_id = connection.telemetry_id();
|
||||
|
||||
// Check for the experimental "terminal-auth" _meta field
|
||||
let auth_method = connection.auth_methods().iter().find(|m| m.id == method);
|
||||
@@ -1579,19 +1589,18 @@ impl AcpThreadView {
|
||||
);
|
||||
cx.notify();
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
let agent = self.agent.clone();
|
||||
async move |this, cx| {
|
||||
let result = authenticate.await;
|
||||
|
||||
match &result {
|
||||
Ok(_) => telemetry::event!(
|
||||
"Authenticate Agent Succeeded",
|
||||
agent = agent.telemetry_id()
|
||||
agent = agent_telemetry_id
|
||||
),
|
||||
Err(_) => {
|
||||
telemetry::event!(
|
||||
"Authenticate Agent Failed",
|
||||
agent = agent.telemetry_id(),
|
||||
agent = agent_telemetry_id,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1675,6 +1684,7 @@ impl AcpThreadView {
|
||||
None,
|
||||
this.workspace.clone(),
|
||||
this.project.clone(),
|
||||
true,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -1730,43 +1740,38 @@ impl AcpThreadView {
|
||||
connection.authenticate(method, cx)
|
||||
};
|
||||
cx.notify();
|
||||
self.auth_task =
|
||||
Some(cx.spawn_in(window, {
|
||||
let agent = self.agent.clone();
|
||||
async move |this, cx| {
|
||||
let result = authenticate.await;
|
||||
self.auth_task = Some(cx.spawn_in(window, {
|
||||
async move |this, cx| {
|
||||
let result = authenticate.await;
|
||||
|
||||
match &result {
|
||||
Ok(_) => telemetry::event!(
|
||||
"Authenticate Agent Succeeded",
|
||||
agent = agent.telemetry_id()
|
||||
),
|
||||
Err(_) => {
|
||||
telemetry::event!(
|
||||
"Authenticate Agent Failed",
|
||||
agent = agent.telemetry_id(),
|
||||
)
|
||||
}
|
||||
match &result {
|
||||
Ok(_) => telemetry::event!(
|
||||
"Authenticate Agent Succeeded",
|
||||
agent = agent_telemetry_id
|
||||
),
|
||||
Err(_) => {
|
||||
telemetry::event!("Authenticate Agent Failed", agent = agent_telemetry_id,)
|
||||
}
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
if let Err(err) = result {
|
||||
if let ThreadState::Unauthenticated {
|
||||
pending_auth_method,
|
||||
..
|
||||
} = &mut this.thread_state
|
||||
{
|
||||
pending_auth_method.take();
|
||||
}
|
||||
this.handle_thread_error(err, cx);
|
||||
} else {
|
||||
this.reset(window, cx);
|
||||
}
|
||||
this.auth_task.take()
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}));
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
if let Err(err) = result {
|
||||
if let ThreadState::Unauthenticated {
|
||||
pending_auth_method,
|
||||
..
|
||||
} = &mut this.thread_state
|
||||
{
|
||||
pending_auth_method.take();
|
||||
}
|
||||
this.handle_thread_error(err, cx);
|
||||
} else {
|
||||
this.reset(window, cx);
|
||||
}
|
||||
this.auth_task.take()
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
fn spawn_external_agent_login(
|
||||
@@ -1896,10 +1901,11 @@ impl AcpThreadView {
|
||||
let Some(thread) = self.thread() else {
|
||||
return;
|
||||
};
|
||||
let agent_telemetry_id = thread.read(cx).connection().telemetry_id();
|
||||
|
||||
telemetry::event!(
|
||||
"Agent Tool Call Authorized",
|
||||
agent = self.agent.telemetry_id(),
|
||||
agent = agent_telemetry_id,
|
||||
session = thread.read(cx).session_id(),
|
||||
option = option_kind
|
||||
);
|
||||
@@ -3509,6 +3515,8 @@ impl AcpThreadView {
|
||||
(method.id.0.clone(), method.name.clone())
|
||||
};
|
||||
|
||||
let agent_telemetry_id = connection.telemetry_id();
|
||||
|
||||
Button::new(method_id.clone(), name)
|
||||
.label_size(LabelSize::Small)
|
||||
.map(|this| {
|
||||
@@ -3528,7 +3536,7 @@ impl AcpThreadView {
|
||||
cx.listener(move |this, _, window, cx| {
|
||||
telemetry::event!(
|
||||
"Authenticate Agent Started",
|
||||
agent = this.agent.telemetry_id(),
|
||||
agent = agent_telemetry_id,
|
||||
method = method_id
|
||||
);
|
||||
|
||||
@@ -5376,47 +5384,39 @@ impl AcpThreadView {
|
||||
)
|
||||
}
|
||||
|
||||
fn render_codex_windows_warning(&self, cx: &mut Context<Self>) -> Option<Callout> {
|
||||
if self.show_codex_windows_warning {
|
||||
Some(
|
||||
Callout::new()
|
||||
.icon(IconName::Warning)
|
||||
.severity(Severity::Warning)
|
||||
.title("Codex on Windows")
|
||||
.description(
|
||||
"For best performance, run Codex in Windows Subsystem for Linux (WSL2)",
|
||||
)
|
||||
.actions_slot(
|
||||
Button::new("open-wsl-modal", "Open in WSL")
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(cx.listener({
|
||||
move |_, _, _window, cx| {
|
||||
#[cfg(windows)]
|
||||
_window.dispatch_action(
|
||||
zed_actions::wsl_actions::OpenWsl::default().boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
)
|
||||
.dismiss_action(
|
||||
IconButton::new("dismiss", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(Tooltip::text("Dismiss Warning"))
|
||||
.on_click(cx.listener({
|
||||
move |this, _, _, cx| {
|
||||
this.show_codex_windows_warning = false;
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
),
|
||||
fn render_codex_windows_warning(&self, cx: &mut Context<Self>) -> Callout {
|
||||
Callout::new()
|
||||
.icon(IconName::Warning)
|
||||
.severity(Severity::Warning)
|
||||
.title("Codex on Windows")
|
||||
.description("For best performance, run Codex in Windows Subsystem for Linux (WSL2)")
|
||||
.actions_slot(
|
||||
Button::new("open-wsl-modal", "Open in WSL")
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.on_click(cx.listener({
|
||||
move |_, _, _window, cx| {
|
||||
#[cfg(windows)]
|
||||
_window.dispatch_action(
|
||||
zed_actions::wsl_actions::OpenWsl::default().boxed_clone(),
|
||||
cx,
|
||||
);
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
)
|
||||
.dismiss_action(
|
||||
IconButton::new("dismiss", IconName::Close)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_color(Color::Muted)
|
||||
.tooltip(Tooltip::text("Dismiss Warning"))
|
||||
.on_click(cx.listener({
|
||||
move |this, _, _, cx| {
|
||||
this.show_codex_windows_warning = false;
|
||||
cx.notify();
|
||||
}
|
||||
})),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn render_thread_error(&mut self, window: &mut Window, cx: &mut Context<Self>) -> Option<Div> {
|
||||
@@ -5936,12 +5936,8 @@ impl Render for AcpThreadView {
|
||||
_ => this,
|
||||
})
|
||||
.children(self.render_thread_retry_status_callout(window, cx))
|
||||
.children({
|
||||
if cfg!(windows) && self.project.read(cx).is_local() {
|
||||
self.render_codex_windows_warning(cx)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
.when(self.show_codex_windows_warning, |this| {
|
||||
this.child(self.render_codex_windows_warning(cx))
|
||||
})
|
||||
.children(self.render_thread_error(window, cx))
|
||||
.when_some(
|
||||
@@ -6398,6 +6394,7 @@ pub(crate) mod tests {
|
||||
project,
|
||||
history_store,
|
||||
None,
|
||||
false,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -6475,10 +6472,6 @@ pub(crate) mod tests {
|
||||
where
|
||||
C: 'static + AgentConnection + Send + Clone,
|
||||
{
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"test"
|
||||
}
|
||||
|
||||
fn logo(&self) -> ui::IconName {
|
||||
ui::IconName::Ai
|
||||
}
|
||||
@@ -6505,8 +6498,8 @@ pub(crate) mod tests {
|
||||
struct SaboteurAgentConnection;
|
||||
|
||||
impl AgentConnection for SaboteurAgentConnection {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"saboteur"
|
||||
fn telemetry_id(&self) -> SharedString {
|
||||
"saboteur".into()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
@@ -6569,8 +6562,8 @@ pub(crate) mod tests {
|
||||
struct RefusalAgentConnection;
|
||||
|
||||
impl AgentConnection for RefusalAgentConnection {
|
||||
fn telemetry_id(&self) -> &'static str {
|
||||
"refusal"
|
||||
fn telemetry_id(&self) -> SharedString {
|
||||
"refusal".into()
|
||||
}
|
||||
|
||||
fn new_thread(
|
||||
@@ -6671,6 +6664,7 @@ pub(crate) mod tests {
|
||||
project.clone(),
|
||||
history_store.clone(),
|
||||
None,
|
||||
false,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file};
|
||||
use ui::{
|
||||
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
|
||||
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
|
||||
PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*,
|
||||
PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*,
|
||||
};
|
||||
use util::ResultExt as _;
|
||||
use workspace::{Workspace, create_and_open_local_file};
|
||||
@@ -117,7 +117,7 @@ impl AgentConfiguration {
|
||||
}
|
||||
|
||||
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
|
||||
for provider in providers {
|
||||
self.add_provider_configuration_view(&provider, window, cx);
|
||||
}
|
||||
@@ -260,11 +260,15 @@ impl AgentConfiguration {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
.child(if let Some(icon_path) = provider.icon_path() {
|
||||
Icon::from_external_svg(icon_path)
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted)
|
||||
} else {
|
||||
Icon::new(provider.icon())
|
||||
.size(IconSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.color(Color::Muted)
|
||||
})
|
||||
.child(
|
||||
h_flex()
|
||||
.w_full()
|
||||
@@ -416,7 +420,7 @@ impl AgentConfiguration {
|
||||
&mut self,
|
||||
cx: &mut Context<Self>,
|
||||
) -> impl IntoElement {
|
||||
let providers = LanguageModelRegistry::read_global(cx).providers();
|
||||
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
|
||||
|
||||
let popover_menu = PopoverMenu::new("add-provider-popover")
|
||||
.trigger(
|
||||
@@ -879,6 +883,7 @@ impl AgentConfiguration {
|
||||
.child(context_server_configuration_menu)
|
||||
.child(
|
||||
Switch::new("context-server-switch", is_running.into())
|
||||
.color(SwitchColor::Accent)
|
||||
.on_click({
|
||||
let context_server_manager = self.context_server_store.clone();
|
||||
let fs = self.fs.clone();
|
||||
|
||||
@@ -77,7 +77,8 @@ impl Render for AgentModelSelector {
|
||||
.map(|model| model.model.name().0)
|
||||
.unwrap_or_else(|| SharedString::from("Select a Model"));
|
||||
|
||||
let provider_icon = model.as_ref().map(|model| model.provider.icon());
|
||||
let provider_icon_path = model.as_ref().and_then(|model| model.provider.icon_path());
|
||||
let provider_icon_name = model.as_ref().map(|model| model.provider.icon());
|
||||
let color = if self.menu_handle.is_deployed() {
|
||||
Color::Accent
|
||||
} else {
|
||||
@@ -89,8 +90,17 @@ impl Render for AgentModelSelector {
|
||||
PickerPopoverMenu::new(
|
||||
self.selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
.when_some(provider_icon, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
.when_some(provider_icon_path.clone(), |this, icon_path| {
|
||||
this.child(
|
||||
Icon::from_external_svg(icon_path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall),
|
||||
)
|
||||
})
|
||||
.when(provider_icon_path.is_none(), |this| {
|
||||
this.when_some(provider_icon_name, |this, icon| {
|
||||
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
|
||||
})
|
||||
})
|
||||
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
|
||||
.child(
|
||||
@@ -102,7 +112,7 @@ impl Render for AgentModelSelector {
|
||||
.child(
|
||||
Icon::new(IconName::ChevronDown)
|
||||
.color(color)
|
||||
.size(IconSize::Small),
|
||||
.size(IconSize::XSmall),
|
||||
),
|
||||
move |_window, cx| {
|
||||
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)
|
||||
|
||||
@@ -305,6 +305,7 @@ impl ActiveView {
|
||||
project,
|
||||
history_store,
|
||||
prompt_store,
|
||||
false,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -885,10 +886,6 @@ impl AgentPanel {
|
||||
|
||||
let server = ext_agent.server(fs, history);
|
||||
|
||||
if !loading {
|
||||
telemetry::event!("Agent Thread Started", agent = server.telemetry_id());
|
||||
}
|
||||
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
let selected_agent = ext_agent.into();
|
||||
if this.selected_agent != selected_agent {
|
||||
@@ -905,6 +902,7 @@ impl AgentPanel {
|
||||
project,
|
||||
this.history_store.clone(),
|
||||
this.prompt_store.clone(),
|
||||
!loading,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
@@ -2294,7 +2292,7 @@ impl AgentPanel {
|
||||
let history_is_empty = self.history_store.read(cx).is_empty(cx);
|
||||
|
||||
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.iter()
|
||||
.any(|provider| {
|
||||
provider.is_authenticated(cx)
|
||||
|
||||
@@ -160,16 +160,6 @@ pub enum ExternalAgent {
|
||||
}
|
||||
|
||||
impl ExternalAgent {
|
||||
pub fn parse_built_in(server: &dyn agent_servers::AgentServer) -> Option<Self> {
|
||||
match server.telemetry_id() {
|
||||
"gemini-cli" => Some(Self::Gemini),
|
||||
"claude-code" => Some(Self::ClaudeCode),
|
||||
"codex" => Some(Self::Codex),
|
||||
"zed" => Some(Self::NativeAgent),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server(
|
||||
&self,
|
||||
fs: Arc<dyn fs::Fs>,
|
||||
@@ -348,7 +338,8 @@ fn init_language_model_settings(cx: &mut App) {
|
||||
|_, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
| language_model::Event::RemovedProvider(_)
|
||||
| language_model::Event::ProvidersChanged => {
|
||||
update_active_language_model_from_settings(cx);
|
||||
}
|
||||
_ => {}
|
||||
@@ -367,26 +358,49 @@ fn update_active_language_model_from_settings(cx: &mut App) {
|
||||
}
|
||||
}
|
||||
|
||||
let default = settings.default_model.as_ref().map(to_selected_model);
|
||||
// Filter out models from providers that are not authenticated
|
||||
fn is_provider_authenticated(
|
||||
selection: &LanguageModelSelection,
|
||||
registry: &LanguageModelRegistry,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
let provider_id = LanguageModelProviderId::from(selection.provider.0.clone());
|
||||
registry
|
||||
.provider(&provider_id)
|
||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||
}
|
||||
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
let registry_ref = registry.read(cx);
|
||||
|
||||
let default = settings
|
||||
.default_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let inline_assistant = settings
|
||||
.inline_assistant_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let commit_message = settings
|
||||
.commit_message_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let thread_summary = settings
|
||||
.thread_summary_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let inline_alternatives = settings
|
||||
.inline_alternatives
|
||||
.iter()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.select_default_model(default.as_ref(), cx);
|
||||
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
|
||||
registry.select_commit_message_model(commit_message.as_ref(), cx);
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
use std::{cmp::Reverse, sync::Arc};
|
||||
|
||||
use collections::IndexMap;
|
||||
use futures::{StreamExt, channel::mpsc};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||
use gpui::{
|
||||
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
|
||||
};
|
||||
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
|
||||
use language_model::{
|
||||
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
|
||||
LanguageModelRegistry,
|
||||
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
|
||||
LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
@@ -47,7 +46,9 @@ pub fn language_model_selector(
|
||||
}
|
||||
|
||||
fn all_models(cx: &App) -> GroupedModels {
|
||||
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
|
||||
let providers = LanguageModelRegistry::global(cx)
|
||||
.read(cx)
|
||||
.visible_providers();
|
||||
|
||||
let recommended = providers
|
||||
.iter()
|
||||
@@ -57,12 +58,12 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
.into_iter()
|
||||
.map(|model| ModelInfo {
|
||||
model,
|
||||
icon: provider.icon(),
|
||||
icon: ProviderIcon::from_provider(provider.as_ref()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let all = providers
|
||||
let all: Vec<ModelInfo> = providers
|
||||
.iter()
|
||||
.flat_map(|provider| {
|
||||
provider
|
||||
@@ -70,7 +71,7 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
.into_iter()
|
||||
.map(|model| ModelInfo {
|
||||
model,
|
||||
icon: provider.icon(),
|
||||
icon: ProviderIcon::from_provider(provider.as_ref()),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
@@ -78,10 +79,26 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
GroupedModels::new(all, recommended)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum ProviderIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
impl ProviderIcon {
|
||||
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
|
||||
if let Some(path) = provider.icon_path() {
|
||||
Self::Path(path)
|
||||
} else {
|
||||
Self::Name(provider.icon())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ModelInfo {
|
||||
model: Arc<dyn LanguageModel>,
|
||||
icon: IconName,
|
||||
icon: ProviderIcon,
|
||||
}
|
||||
|
||||
pub struct LanguageModelPickerDelegate {
|
||||
@@ -91,7 +108,7 @@ pub struct LanguageModelPickerDelegate {
|
||||
filtered_entries: Vec<LanguageModelPickerEntry>,
|
||||
selected_index: usize,
|
||||
_authenticate_all_providers_task: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
_refresh_models_task: Task<()>,
|
||||
popover_styles: bool,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
@@ -116,24 +133,43 @@ impl LanguageModelPickerDelegate {
|
||||
filtered_entries: entries,
|
||||
get_active_model: Arc::new(get_active_model),
|
||||
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
|
||||
_subscriptions: vec![cx.subscribe_in(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
window,
|
||||
|picker, _, event, window, cx| {
|
||||
match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
let query = picker.query(cx);
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
// Update matches will automatically drop the previous task
|
||||
// if we get a provider event again
|
||||
picker.update_matches(query, window, cx)
|
||||
}
|
||||
_ => {}
|
||||
_refresh_models_task: {
|
||||
// Create a channel to signal when models need refreshing
|
||||
let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
|
||||
|
||||
// Subscribe to registry events and send refresh signals through the channel
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
},
|
||||
)],
|
||||
language_model::Event::AddedProvider(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
language_model::Event::RemovedProvider(_) => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
language_model::Event::ProvidersChanged => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
_ => {}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// Spawn a task that listens for refresh signals and updates the picker
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
while let Some(()) = refresh_rx.next().await {
|
||||
let result = this.update_in(cx, |picker, window, cx| {
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
picker.refresh(window, cx);
|
||||
});
|
||||
if result.is_err() {
|
||||
// Picker was dropped, exit the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
popover_styles,
|
||||
focus_handle,
|
||||
}
|
||||
@@ -392,7 +428,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
|
||||
let configured_providers = language_model_registry
|
||||
.read(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.into_iter()
|
||||
.filter(|provider| provider.is_authenticated(cx))
|
||||
.collect::<Vec<_>>();
|
||||
@@ -504,11 +540,16 @@ impl PickerDelegate for LanguageModelPickerDelegate {
|
||||
h_flex()
|
||||
.w_full()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
Icon::new(model_info.icon)
|
||||
.child(match &model_info.icon {
|
||||
ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small),
|
||||
)
|
||||
ProviderIcon::Path(icon_path) => {
|
||||
Icon::from_external_svg(icon_path.clone())
|
||||
.color(model_icon_color)
|
||||
.size(IconSize::Small)
|
||||
}
|
||||
})
|
||||
.child(Label::new(model_info.model.name().0).truncate()),
|
||||
)
|
||||
.end_slot(div().pr_3().when(is_selected, |this| {
|
||||
@@ -657,7 +698,7 @@ mod tests {
|
||||
.into_iter()
|
||||
.map(|(provider, name)| ModelInfo {
|
||||
model: Arc::new(TestLanguageModel::new(name, provider)),
|
||||
icon: IconName::Ai,
|
||||
icon: ProviderIcon::Name(IconName::Ai),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -1682,98 +1682,6 @@ impl TextThreadEditor {
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let editor_clipboard_selections = cx
|
||||
.read_from_clipboard()
|
||||
.and_then(|item| item.entries().first().cloned())
|
||||
.and_then(|entry| match entry {
|
||||
ClipboardEntry::String(text) => {
|
||||
text.metadata_json::<Vec<editor::ClipboardSelection>>()
|
||||
}
|
||||
_ => None,
|
||||
});
|
||||
|
||||
let has_file_context = editor_clipboard_selections
|
||||
.as_ref()
|
||||
.is_some_and(|selections| {
|
||||
selections
|
||||
.iter()
|
||||
.any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
|
||||
});
|
||||
|
||||
if has_file_context {
|
||||
if let Some(clipboard_item) = cx.read_from_clipboard() {
|
||||
if let Some(ClipboardEntry::String(clipboard_text)) =
|
||||
clipboard_item.entries().first()
|
||||
{
|
||||
if let Some(selections) = editor_clipboard_selections {
|
||||
cx.stop_propagation();
|
||||
|
||||
let text = clipboard_text.text();
|
||||
self.editor.update(cx, |editor, cx| {
|
||||
let mut current_offset = 0;
|
||||
let weak_editor = cx.entity().downgrade();
|
||||
|
||||
for selection in selections {
|
||||
if let (Some(file_path), Some(line_range)) =
|
||||
(selection.file_path, selection.line_range)
|
||||
{
|
||||
let selected_text =
|
||||
&text[current_offset..current_offset + selection.len];
|
||||
let fence = assistant_slash_commands::codeblock_fence_for_path(
|
||||
file_path.to_str(),
|
||||
Some(line_range.clone()),
|
||||
);
|
||||
let formatted_text = format!("{fence}{selected_text}\n```");
|
||||
|
||||
let insert_point = editor
|
||||
.selections
|
||||
.newest::<Point>(&editor.display_snapshot(cx))
|
||||
.head();
|
||||
let start_row = MultiBufferRow(insert_point.row);
|
||||
|
||||
editor.insert(&formatted_text, window, cx);
|
||||
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
let anchor_before = snapshot.anchor_after(insert_point);
|
||||
let anchor_after = editor
|
||||
.selections
|
||||
.newest_anchor()
|
||||
.head()
|
||||
.bias_left(&snapshot);
|
||||
|
||||
editor.insert("\n", window, cx);
|
||||
|
||||
let crease_text = acp_thread::selection_name(
|
||||
Some(file_path.as_ref()),
|
||||
&line_range,
|
||||
);
|
||||
|
||||
let fold_placeholder = quote_selection_fold_placeholder(
|
||||
crease_text,
|
||||
weak_editor.clone(),
|
||||
);
|
||||
let crease = Crease::inline(
|
||||
anchor_before..anchor_after,
|
||||
fold_placeholder,
|
||||
render_quote_selection_output_toggle,
|
||||
|_, _, _, _| Empty.into_any(),
|
||||
);
|
||||
editor.insert_creases(vec![crease], cx);
|
||||
editor.fold_at(start_row, window, cx);
|
||||
|
||||
current_offset += selection.len;
|
||||
if !selection.is_entire_line && current_offset < text.len() {
|
||||
current_offset += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cx.stop_propagation();
|
||||
|
||||
let mut images = if let Some(item) = cx.read_from_clipboard() {
|
||||
@@ -2189,7 +2097,8 @@ impl TextThreadEditor {
|
||||
.default_model()
|
||||
.map(|default| default.provider);
|
||||
|
||||
let provider_icon = match active_provider {
|
||||
let provider_icon_path = active_provider.as_ref().and_then(|p| p.icon_path());
|
||||
let provider_icon_name = match &active_provider {
|
||||
Some(provider) => provider.icon(),
|
||||
None => IconName::Ai,
|
||||
};
|
||||
@@ -2201,6 +2110,16 @@ impl TextThreadEditor {
|
||||
(Color::Muted, IconName::ChevronDown)
|
||||
};
|
||||
|
||||
let provider_icon_element = if let Some(icon_path) = provider_icon_path {
|
||||
Icon::from_external_svg(icon_path)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall)
|
||||
} else {
|
||||
Icon::new(provider_icon_name)
|
||||
.color(color)
|
||||
.size(IconSize::XSmall)
|
||||
};
|
||||
|
||||
PickerPopoverMenu::new(
|
||||
self.language_model_selector.clone(),
|
||||
ButtonLike::new("active-model")
|
||||
@@ -2208,7 +2127,7 @@ impl TextThreadEditor {
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_0p5()
|
||||
.child(Icon::new(provider_icon).color(color).size(IconSize::XSmall))
|
||||
.child(provider_icon_element)
|
||||
.child(
|
||||
Label::new(model_name)
|
||||
.color(color)
|
||||
|
||||
@@ -1,9 +1,25 @@
|
||||
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
|
||||
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use language_model::{LanguageModelProvider, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
|
||||
use ui::{Divider, List, ListBulletItem, prelude::*};
|
||||
|
||||
#[derive(Clone)]
|
||||
enum ProviderIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
impl ProviderIcon {
|
||||
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
|
||||
if let Some(path) = provider.icon_path() {
|
||||
Self::Path(path)
|
||||
} else {
|
||||
Self::Name(provider.icon())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ApiKeysWithProviders {
|
||||
configured_providers: Vec<(IconName, SharedString)>,
|
||||
configured_providers: Vec<(ProviderIcon, SharedString)>,
|
||||
}
|
||||
|
||||
impl ApiKeysWithProviders {
|
||||
@@ -13,7 +29,8 @@ impl ApiKeysWithProviders {
|
||||
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
| language_model::Event::RemovedProvider(_)
|
||||
| language_model::Event::ProvidersChanged => {
|
||||
this.configured_providers = Self::compute_configured_providers(cx)
|
||||
}
|
||||
_ => {}
|
||||
@@ -26,14 +43,19 @@ impl ApiKeysWithProviders {
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> {
|
||||
fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map(|provider| (provider.icon(), provider.name().0))
|
||||
.map(|provider| {
|
||||
(
|
||||
ProviderIcon::from_provider(provider.as_ref()),
|
||||
provider.name().0,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@@ -47,7 +69,14 @@ impl Render for ApiKeysWithProviders {
|
||||
.map(|(icon, name)| {
|
||||
h_flex()
|
||||
.gap_1p5()
|
||||
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
|
||||
.child(match icon {
|
||||
ProviderIcon::Name(icon_name) => Icon::new(icon_name)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path)
|
||||
.size(IconSize::XSmall)
|
||||
.color(Color::Muted),
|
||||
})
|
||||
.child(Label::new(name))
|
||||
});
|
||||
div()
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
|
||||
pub struct AgentPanelOnboarding {
|
||||
user_store: Entity<UserStore>,
|
||||
client: Arc<Client>,
|
||||
configured_providers: Vec<(IconName, SharedString)>,
|
||||
has_configured_providers: bool,
|
||||
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
|
||||
}
|
||||
|
||||
@@ -27,8 +27,9 @@ impl AgentPanelOnboarding {
|
||||
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
this.configured_providers = Self::compute_available_providers(cx)
|
||||
| language_model::Event::RemovedProvider(_)
|
||||
| language_model::Event::ProvidersChanged => {
|
||||
this.has_configured_providers = Self::has_configured_providers(cx)
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
@@ -38,20 +39,16 @@ impl AgentPanelOnboarding {
|
||||
Self {
|
||||
user_store,
|
||||
client,
|
||||
configured_providers: Self::compute_available_providers(cx),
|
||||
has_configured_providers: Self::has_configured_providers(cx),
|
||||
continue_with_zed_ai: Arc::new(continue_with_zed_ai),
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> {
|
||||
fn has_configured_providers(cx: &App) -> bool {
|
||||
LanguageModelRegistry::read_global(cx)
|
||||
.providers()
|
||||
.visible_providers()
|
||||
.iter()
|
||||
.filter(|provider| {
|
||||
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
|
||||
})
|
||||
.map(|provider| (provider.icon(), provider.name().0))
|
||||
.collect()
|
||||
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +78,7 @@ impl Render for AgentPanelOnboarding {
|
||||
}),
|
||||
)
|
||||
.map(|this| {
|
||||
if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() {
|
||||
if enrolled_in_trial || is_pro_user || self.has_configured_providers {
|
||||
this
|
||||
} else {
|
||||
this.child(ApiKeysWithoutProviders::new())
|
||||
|
||||
@@ -8,7 +8,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
|
||||
use http_client::http::{self, HeaderMap, HeaderValue};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
|
||||
pub use settings::ModelMode;
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tiny_http.workspace = true
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
|
||||
tokio-socks.workspace = true
|
||||
tokio.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_zeta2_prompt"
|
||||
version = "0.1.0"
|
||||
publish.workspace = true
|
||||
edition.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_zeta2_prompt.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
indoc.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -1,485 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use std::cmp;
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
|
||||
|
||||
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
|
||||
|
||||
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
|
||||
---
|
||||
|
||||
Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
|
||||
Do not include the cursor marker in your output.
|
||||
If you're editing multiple files, be sure to reflect filename in the hunk's header.
|
||||
"};
|
||||
|
||||
const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
|
||||
# Instructions
|
||||
|
||||
You are an edit prediction agent in a code editor.
|
||||
|
||||
Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
|
||||
Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
|
||||
Always continue along the user's current trajectory, rather than changing course.
|
||||
|
||||
## Output Format
|
||||
|
||||
You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
|
||||
along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
|
||||
|
||||
<edits path="my-project/src/myapp/cli.py">
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
</edits>
|
||||
|
||||
- Specify the file to edit using the `path` attribute.
|
||||
- Use `<old_text>` and `<new_text>` tags to replace content
|
||||
- `<old_text>` must exactly match existing file content, including indentation
|
||||
- `<old_text>` cannot be empty
|
||||
- Do not escape quotes, newlines, or other characters within tags
|
||||
- Always close all tags properly
|
||||
- Don't include the <|user_cursor|> marker in your output.
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
|
||||
---
|
||||
|
||||
Remember that the edits in the edit history have already been applied.
|
||||
"#};
|
||||
|
||||
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
|
||||
let prompt_data = PromptData {
|
||||
events: request.events.clone(),
|
||||
cursor_point: request.cursor_point,
|
||||
cursor_path: request.excerpt_path.clone(),
|
||||
included_files: request.related_files.clone(),
|
||||
};
|
||||
match request.prompt_format {
|
||||
PromptFormat::MinimalQwen => {
|
||||
return Ok(MinimalQwenPrompt.render(&prompt_data));
|
||||
}
|
||||
PromptFormat::SeedCoder1120 => {
|
||||
return Ok(SeedCoder1120Prompt.render(&prompt_data));
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let insertions = match request.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
|
||||
vec![(request.cursor_point, CURSOR_MARKER)]
|
||||
}
|
||||
PromptFormat::OnlySnippets => vec![],
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
};
|
||||
|
||||
let mut prompt = match request.prompt_format {
|
||||
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::OnlySnippets => String::new(),
|
||||
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
};
|
||||
|
||||
if request.events.is_empty() {
|
||||
prompt.push_str("(No edit history)\n\n");
|
||||
} else {
|
||||
let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
|
||||
"The following are the latest edits made by the user, from earlier to later.\n\n"
|
||||
} else {
|
||||
"Here are the latest edits made by the user, from earlier to later.\n\n"
|
||||
};
|
||||
prompt.push_str(edit_preamble);
|
||||
push_events(&mut prompt, &request.events);
|
||||
}
|
||||
|
||||
let excerpts_preamble = match request.prompt_format {
|
||||
PromptFormat::Minimal => indoc! {"
|
||||
## Part of the file under the cursor
|
||||
|
||||
(The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history has been applied.
|
||||
We only show part of the file around the cursor.
|
||||
You can only edit exactly this part of the file.
|
||||
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
|
||||
"},
|
||||
PromptFormat::OldTextNewText => indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
Here is some excerpts of code that you should take into account to predict the next edit.
|
||||
|
||||
The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
|
||||
|
||||
In addition other excerpts are included to better understand what the edit will be, including the declaration
|
||||
or references of symbols around the cursor, or other similar code snippets that may need to be updated
|
||||
following patterns that appear in the edit history.
|
||||
|
||||
Consider each of them carefully in relation to the edit history, and that the user may not have navigated
|
||||
to the next place they want to edit yet.
|
||||
|
||||
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
|
||||
"},
|
||||
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history have been applied.
|
||||
"}
|
||||
}
|
||||
};
|
||||
|
||||
prompt.push_str(excerpts_preamble);
|
||||
prompt.push('\n');
|
||||
|
||||
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
|
||||
for related_file in &request.related_files {
|
||||
if request.prompt_format == PromptFormat::Minimal {
|
||||
write_codeblock_with_filename(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
} else {
|
||||
write_codeblock(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match request.prompt_format {
|
||||
PromptFormat::OldTextNewText => {
|
||||
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
|
||||
}
|
||||
PromptFormat::Minimal => {
|
||||
prompt.push_str(MINIMAL_PROMPT_REMINDER);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
|
||||
match prompt_format {
|
||||
PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
|
||||
_ => GenerationParams::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_codeblock<'a>(
|
||||
path: &Path,
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &'a mut String,
|
||||
) {
|
||||
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
|
||||
|
||||
write_excerpts(
|
||||
excerpts,
|
||||
sorted_insertions,
|
||||
file_line_count,
|
||||
include_line_numbers,
|
||||
output,
|
||||
);
|
||||
write!(output, "`````\n\n").unwrap();
|
||||
}
|
||||
|
||||
fn write_codeblock_with_filename<'a>(
|
||||
path: &Path,
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &'a mut String,
|
||||
) {
|
||||
writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
|
||||
|
||||
write_excerpts(
|
||||
excerpts,
|
||||
sorted_insertions,
|
||||
file_line_count,
|
||||
include_line_numbers,
|
||||
output,
|
||||
);
|
||||
write!(output, "`````\n\n").unwrap();
|
||||
}
|
||||
|
||||
pub fn write_excerpts<'a>(
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &mut String,
|
||||
) {
|
||||
let mut current_row = Line(0);
|
||||
let mut sorted_insertions = sorted_insertions.iter().peekable();
|
||||
|
||||
for excerpt in excerpts {
|
||||
if excerpt.start_line > current_row {
|
||||
writeln!(output, "…").unwrap();
|
||||
}
|
||||
if excerpt.text.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
current_row = excerpt.start_line;
|
||||
|
||||
for mut line in excerpt.text.lines() {
|
||||
if include_line_numbers {
|
||||
write!(output, "{}|", current_row.0 + 1).unwrap();
|
||||
}
|
||||
|
||||
while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
|
||||
match current_row.cmp(&insertion_location.line) {
|
||||
cmp::Ordering::Equal => {
|
||||
let (prefix, suffix) = line.split_at(insertion_location.column as usize);
|
||||
output.push_str(prefix);
|
||||
output.push_str(insertion_marker);
|
||||
line = suffix;
|
||||
sorted_insertions.next();
|
||||
}
|
||||
cmp::Ordering::Less => break,
|
||||
cmp::Ordering::Greater => {
|
||||
sorted_insertions.next();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
output.push_str(line);
|
||||
output.push('\n');
|
||||
current_row.0 += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if current_row < file_line_count {
|
||||
writeln!(output, "…").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
|
||||
if events.is_empty() {
|
||||
return;
|
||||
};
|
||||
|
||||
writeln!(output, "`````diff").unwrap();
|
||||
for event in events {
|
||||
writeln!(output, "{}", event).unwrap();
|
||||
}
|
||||
writeln!(output, "`````\n").unwrap();
|
||||
}
|
||||
|
||||
struct PromptData {
|
||||
events: Vec<Arc<Event>>,
|
||||
cursor_point: Point,
|
||||
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
|
||||
included_files: Vec<RelatedFile>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct GenerationParams {
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
trait PromptFormatter {
|
||||
fn render(&self, data: &PromptData) -> String;
|
||||
|
||||
fn generation_params() -> GenerationParams {
|
||||
return GenerationParams::default();
|
||||
}
|
||||
}
|
||||
|
||||
struct MinimalQwenPrompt;
|
||||
|
||||
impl PromptFormatter for MinimalQwenPrompt {
|
||||
fn render(&self, data: &PromptData) -> String {
|
||||
let edit_history = self.fmt_edit_history(data);
|
||||
let context = self.fmt_context(data);
|
||||
|
||||
format!(
|
||||
"{instructions}\n\n{edit_history}\n\n{context}",
|
||||
instructions = MinimalQwenPrompt::INSTRUCTIONS,
|
||||
edit_history = edit_history,
|
||||
context = context
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl MinimalQwenPrompt {
|
||||
const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
|
||||
|
||||
fn fmt_edit_history(&self, data: &PromptData) -> String {
|
||||
if data.events.is_empty() {
|
||||
"(No edit history)\n\n".to_string()
|
||||
} else {
|
||||
let mut events_str = String::new();
|
||||
push_events(&mut events_str, &data.events);
|
||||
format!(
|
||||
"The following are the latest edits made by the user, from earlier to later.\n\n{}",
|
||||
events_str
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_context(&self, data: &PromptData) -> String {
|
||||
let mut context = String::new();
|
||||
let include_line_numbers = true;
|
||||
|
||||
for related_file in &data.included_files {
|
||||
writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
|
||||
|
||||
if related_file.path == data.cursor_path {
|
||||
write!(context, "<|fim_prefix|>").unwrap();
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[(data.cursor_point, "<|fim_suffix|>")],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
writeln!(context, "<|fim_middle|>").unwrap();
|
||||
} else {
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
}
|
||||
}
|
||||
context
|
||||
}
|
||||
}
|
||||
|
||||
struct SeedCoder1120Prompt;
|
||||
|
||||
impl PromptFormatter for SeedCoder1120Prompt {
|
||||
fn render(&self, data: &PromptData) -> String {
|
||||
let edit_history = self.fmt_edit_history(data);
|
||||
let context = self.fmt_context(data);
|
||||
|
||||
format!(
|
||||
"# Edit History:\n{edit_history}\n\n{context}",
|
||||
edit_history = edit_history,
|
||||
context = context
|
||||
)
|
||||
}
|
||||
|
||||
fn generation_params() -> GenerationParams {
|
||||
GenerationParams {
|
||||
temperature: Some(0.2),
|
||||
top_p: Some(0.9),
|
||||
stop: Some(vec!["<[end_of_sentence]>".into()]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SeedCoder1120Prompt {
|
||||
fn fmt_edit_history(&self, data: &PromptData) -> String {
|
||||
if data.events.is_empty() {
|
||||
"(No edit history)\n\n".to_string()
|
||||
} else {
|
||||
let mut events_str = String::new();
|
||||
push_events(&mut events_str, &data.events);
|
||||
events_str
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_context(&self, data: &PromptData) -> String {
|
||||
let mut context = String::new();
|
||||
let include_line_numbers = true;
|
||||
|
||||
for related_file in &data.included_files {
|
||||
writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
|
||||
|
||||
if related_file.path == data.cursor_path {
|
||||
let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
|
||||
context.push_str(&fim_prompt);
|
||||
} else {
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
}
|
||||
}
|
||||
context
|
||||
}
|
||||
|
||||
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
|
||||
let mut buf = String::new();
|
||||
const FIM_SUFFIX: &str = "<[fim-suffix]>";
|
||||
const FIM_PREFIX: &str = "<[fim-prefix]>";
|
||||
const FIM_MIDDLE: &str = "<[fim-middle]>";
|
||||
write!(buf, "{}", FIM_PREFIX).unwrap();
|
||||
write_excerpts(
|
||||
&file.excerpts,
|
||||
&[(cursor_point, FIM_SUFFIX)],
|
||||
file.max_row,
|
||||
true,
|
||||
&mut buf,
|
||||
);
|
||||
|
||||
// Swap prefix and suffix parts
|
||||
let index = buf.find(FIM_SUFFIX).unwrap();
|
||||
let prefix = &buf[..index];
|
||||
let suffix = &buf[index..];
|
||||
|
||||
format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
|
||||
}
|
||||
}
|
||||
@@ -33,6 +33,7 @@ smol.workspace = true
|
||||
tempfile.workspace = true
|
||||
url = { workspace = true, features = ["serde"] }
|
||||
util.workspace = true
|
||||
terminal.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -8,9 +8,12 @@ use futures::{
|
||||
AsyncBufReadExt as _, AsyncRead, AsyncWrite, AsyncWriteExt as _, Stream, StreamExt as _,
|
||||
};
|
||||
use gpui::AsyncApp;
|
||||
use settings::Settings as _;
|
||||
use smol::channel;
|
||||
use smol::process::Child;
|
||||
use terminal::terminal_settings::TerminalSettings;
|
||||
use util::TryFutureExt as _;
|
||||
use util::shell_builder::ShellBuilder;
|
||||
|
||||
use crate::client::ModelContextServerBinary;
|
||||
use crate::transport::Transport;
|
||||
@@ -28,9 +31,14 @@ impl StdioTransport {
|
||||
working_directory: &Option<PathBuf>,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self> {
|
||||
let mut command = util::command::new_smol_command(&binary.executable);
|
||||
let shell = cx.update(|cx| TerminalSettings::get(None, cx).shell.clone())?;
|
||||
let builder = ShellBuilder::new(&shell, cfg!(windows));
|
||||
let (command, args) =
|
||||
builder.build(Some(binary.executable.display().to_string()), &binary.args);
|
||||
|
||||
let mut command = util::command::new_smol_command(command);
|
||||
command
|
||||
.args(&binary.args)
|
||||
.args(args)
|
||||
.envs(binary.env.unwrap_or_default())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
|
||||
@@ -21,7 +21,6 @@ arrayvec.workspace = true
|
||||
brotli.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
@@ -50,8 +49,6 @@ semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strsim.workspace = true
|
||||
strum.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
@@ -62,6 +59,7 @@ uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
use anyhow::Result;
|
||||
use arrayvec::ArrayVec;
|
||||
use client::{Client, EditPredictionUsage, UserStore};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
|
||||
use cloud_llm_client::{
|
||||
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
|
||||
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
|
||||
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
|
||||
ZED_VERSION_HEADER_NAME,
|
||||
};
|
||||
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
|
||||
use collections::{HashMap, HashSet};
|
||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||
use edit_prediction_context::EditPredictionExcerptOptions;
|
||||
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
|
||||
use futures::{
|
||||
AsyncReadExt as _, FutureExt as _, StreamExt as _,
|
||||
channel::{
|
||||
mpsc::{self, UnboundedReceiver},
|
||||
oneshot,
|
||||
},
|
||||
channel::mpsc::{self, UnboundedReceiver},
|
||||
select_biased,
|
||||
};
|
||||
use gpui::BackgroundExecutor;
|
||||
@@ -58,8 +54,10 @@ mod onboarding_modal;
|
||||
pub mod open_ai_response;
|
||||
mod prediction;
|
||||
pub mod sweep_ai;
|
||||
|
||||
#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
|
||||
pub mod udiff;
|
||||
mod xml_edits;
|
||||
|
||||
mod zed_edit_prediction_delegate;
|
||||
pub mod zeta1;
|
||||
pub mod zeta2;
|
||||
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
|
||||
use crate::onboarding_modal::ZedPredictModal;
|
||||
pub use crate::prediction::EditPrediction;
|
||||
pub use crate::prediction::EditPredictionId;
|
||||
pub use crate::prediction::EditPredictionInputs;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
pub use crate::sweep_ai::SweepAi;
|
||||
pub use telemetry_events::EditPredictionRating;
|
||||
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
|
||||
min_bytes: 128,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
},
|
||||
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
|
||||
prompt_format: PromptFormat::DEFAULT,
|
||||
};
|
||||
|
||||
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
|
||||
use_context: bool,
|
||||
options: ZetaOptions,
|
||||
update_required: bool,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
edit_prediction_model: EditPredictionModel,
|
||||
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
|
||||
Mercury,
|
||||
}
|
||||
|
||||
pub struct EditPredictionModelInput {
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: Anchor,
|
||||
events: Vec<Arc<zeta_prompt::Event>>,
|
||||
related_files: Arc<[RelatedFile]>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
diagnostic_search_range: Range<Point>,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ZetaOptions {
|
||||
pub context: EditPredictionExcerptOptions,
|
||||
pub max_prompt_bytes: usize,
|
||||
pub prompt_format: predict_edits_v3::PromptFormat,
|
||||
}
|
||||
|
||||
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
|
||||
pub enum DebugEvent {
|
||||
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
|
||||
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
|
||||
EditPredictionRequested(EditPredictionRequestedDebugEvent),
|
||||
EditPredictionStarted(EditPredictionStartedDebugEvent),
|
||||
EditPredictionFinished(EditPredictionFinishedDebugEvent),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditPredictionRequestedDebugEvent {
|
||||
pub inputs: EditPredictionInputs,
|
||||
pub retrieval_time: Duration,
|
||||
pub struct EditPredictionStartedDebugEvent {
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub position: Anchor,
|
||||
pub local_prompt: Result<String, String>,
|
||||
pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditPredictionFinishedDebugEvent {
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub position: Anchor,
|
||||
pub model_output: Option<String>,
|
||||
}
|
||||
|
||||
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
|
||||
|
||||
struct ProjectState {
|
||||
events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
|
||||
events: VecDeque<Arc<zeta_prompt::Event>>,
|
||||
last_event: Option<LastEvent>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
next_pending_prediction_id: usize,
|
||||
pending_predictions: ArrayVec<PendingPrediction, 2>,
|
||||
context_updates_tx: smol::channel::Sender<()>,
|
||||
context_updates_rx: smol::channel::Receiver<()>,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
last_prediction_refresh: Option<(EntityId, Instant)>,
|
||||
cancelled_predictions: HashSet<usize>,
|
||||
context: Entity<RelatedExcerptStore>,
|
||||
@@ -241,7 +252,7 @@ struct ProjectState {
|
||||
}
|
||||
|
||||
impl ProjectState {
|
||||
pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
|
||||
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
self.events
|
||||
.iter()
|
||||
.cloned()
|
||||
@@ -376,7 +387,7 @@ impl LastEvent {
|
||||
&self,
|
||||
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
|
||||
cx: &App,
|
||||
) -> Option<Arc<predict_edits_v3::Event>> {
|
||||
) -> Option<Arc<zeta_prompt::Event>> {
|
||||
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
|
||||
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
|
||||
|
||||
@@ -396,7 +407,7 @@ impl LastEvent {
|
||||
if path == old_path && diff.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Arc::new(predict_edits_v3::Event::BufferChange {
|
||||
Some(Arc::new(zeta_prompt::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
@@ -481,7 +492,6 @@ impl EditPredictionStore {
|
||||
},
|
||||
),
|
||||
update_required: false,
|
||||
debug_tx: None,
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: None,
|
||||
edit_prediction_model: EditPredictionModel::Zeta2,
|
||||
@@ -536,12 +546,6 @@ impl EditPredictionStore {
|
||||
self.eval_cache = Some(cache);
|
||||
}
|
||||
|
||||
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
|
||||
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
|
||||
self.debug_tx = Some(debug_watch_tx);
|
||||
debug_watch_rx
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ZetaOptions {
|
||||
&self.options
|
||||
}
|
||||
@@ -560,15 +564,35 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edit_history_for_project(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project_state| project_state.events.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn context_for_project<'a>(
|
||||
&'a self,
|
||||
project: &Entity<Project>,
|
||||
cx: &'a App,
|
||||
) -> &'a [RelatedFile] {
|
||||
) -> Arc<[RelatedFile]> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project| project.context.read(cx).related_files())
|
||||
.unwrap_or(&[])
|
||||
.unwrap_or_else(|| vec![].into())
|
||||
}
|
||||
|
||||
pub fn context_for_project_with_buffers<'a>(
|
||||
&'a self,
|
||||
project: &Entity<Project>,
|
||||
cx: &'a App,
|
||||
) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project| project.context.read(cx).related_files_with_buffers())
|
||||
}
|
||||
|
||||
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
|
||||
@@ -599,85 +623,21 @@ impl EditPredictionStore {
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut ProjectState {
|
||||
let entity_id = project.entity_id();
|
||||
let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
|
||||
self.projects
|
||||
.entry(entity_id)
|
||||
.or_insert_with(|| ProjectState {
|
||||
context: {
|
||||
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
|
||||
cx.subscribe(
|
||||
&related_excerpt_store,
|
||||
move |this, _, event, _| match event {
|
||||
RelatedExcerptStoreEvent::StartedRefresh => {
|
||||
if let Some(debug_tx) = this.debug_tx.clone() {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalStarted(
|
||||
ContextRetrievalStartedDebugEvent {
|
||||
project_entity_id: entity_id,
|
||||
timestamp: Instant::now(),
|
||||
search_prompt: String::new(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
} => {
|
||||
if let Some(debug_tx) = this.debug_tx.clone() {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalFinished(
|
||||
ContextRetrievalFinishedDebugEvent {
|
||||
project_entity_id: entity_id,
|
||||
timestamp: Instant::now(),
|
||||
metadata: vec![
|
||||
(
|
||||
"Cache Hits",
|
||||
format!(
|
||||
"{}/{}",
|
||||
cache_hit_count,
|
||||
cache_hit_count + cache_miss_count
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Max LSP Time",
|
||||
format!(
|
||||
"{} ms",
|
||||
max_definition_latency.as_millis()
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Mean LSP Time",
|
||||
format!(
|
||||
"{} ms",
|
||||
mean_definition_latency.as_millis()
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
if let Some(project_state) = this.projects.get(&entity_id) {
|
||||
project_state.context_updates_tx.send_blocking(()).ok();
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
|
||||
this.handle_excerpt_store_event(entity_id, event);
|
||||
})
|
||||
.detach();
|
||||
related_excerpt_store
|
||||
},
|
||||
events: VecDeque::new(),
|
||||
last_event: None,
|
||||
recent_paths: VecDeque::new(),
|
||||
context_updates_rx,
|
||||
context_updates_tx,
|
||||
debug_tx: None,
|
||||
registered_buffers: HashMap::default(),
|
||||
current_prediction: None,
|
||||
cancelled_predictions: HashSet::default(),
|
||||
@@ -689,12 +649,79 @@ impl EditPredictionStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn project_context_updates(
|
||||
&self,
|
||||
pub fn remove_project(&mut self, project: &Entity<Project>) {
|
||||
self.projects.remove(&project.entity_id());
|
||||
}
|
||||
|
||||
fn handle_excerpt_store_event(
|
||||
&mut self,
|
||||
project_entity_id: EntityId,
|
||||
event: &RelatedExcerptStoreEvent,
|
||||
) {
|
||||
if let Some(project_state) = self.projects.get(&project_entity_id) {
|
||||
if let Some(debug_tx) = project_state.debug_tx.clone() {
|
||||
match event {
|
||||
RelatedExcerptStoreEvent::StartedRefresh => {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalStarted(
|
||||
ContextRetrievalStartedDebugEvent {
|
||||
project_entity_id: project_entity_id,
|
||||
timestamp: Instant::now(),
|
||||
search_prompt: String::new(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
} => {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalFinished(
|
||||
ContextRetrievalFinishedDebugEvent {
|
||||
project_entity_id: project_entity_id,
|
||||
timestamp: Instant::now(),
|
||||
metadata: vec![
|
||||
(
|
||||
"Cache Hits",
|
||||
format!(
|
||||
"{}/{}",
|
||||
cache_hit_count,
|
||||
cache_hit_count + cache_miss_count
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Max LSP Time",
|
||||
format!("{} ms", max_definition_latency.as_millis())
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Mean LSP Time",
|
||||
format!("{} ms", mean_definition_latency.as_millis())
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_info(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
) -> Option<smol::channel::Receiver<()>> {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
Some(project_state.context_updates_rx.clone())
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<DebugEvent> {
|
||||
let project_state = self.get_or_init_project(project, cx);
|
||||
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
|
||||
project_state.debug_tx = Some(debug_watch_tx);
|
||||
debug_watch_rx
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
@@ -1348,6 +1375,7 @@ impl EditPredictionStore {
|
||||
let project_state = self.projects.get(&project.entity_id()).unwrap();
|
||||
let events = project_state.events(cx);
|
||||
let has_events = !events.is_empty();
|
||||
let debug_tx = project_state.debug_tx.clone();
|
||||
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
@@ -1357,55 +1385,29 @@ impl EditPredictionStore {
|
||||
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
|
||||
|
||||
let related_files = if self.use_context {
|
||||
self.context_for_project(&project, cx).to_vec()
|
||||
self.context_for_project(&project, cx)
|
||||
} else {
|
||||
Vec::new()
|
||||
Vec::new().into()
|
||||
};
|
||||
|
||||
let inputs = EditPredictionModelInput {
|
||||
project: project.clone(),
|
||||
buffer: active_buffer.clone(),
|
||||
snapshot: snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
recent_paths: project_state.recent_paths.clone(),
|
||||
trigger,
|
||||
diagnostic_search_range: diagnostic_search_range.clone(),
|
||||
debug_tx,
|
||||
};
|
||||
|
||||
let task = match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
|
||||
self,
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
trigger,
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
|
||||
self,
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
trigger,
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&project_state.recent_paths,
|
||||
related_files,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&project_state.recent_paths,
|
||||
related_files,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
|
||||
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
|
||||
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
@@ -1706,6 +1708,20 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub fn set_context_for_buffer(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.get_or_init_project(project, cx)
|
||||
.context
|
||||
.update(cx, |store, _| {
|
||||
store.set_related_files(related_files);
|
||||
});
|
||||
}
|
||||
|
||||
fn is_file_open_source(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
@@ -1729,14 +1745,14 @@ impl EditPredictionStore {
|
||||
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
|
||||
}
|
||||
|
||||
fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
|
||||
fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
|
||||
if !self.data_collection_choice.is_enabled() {
|
||||
return false;
|
||||
}
|
||||
events.iter().all(|event| {
|
||||
matches!(
|
||||
event.as_ref(),
|
||||
Event::BufferChange {
|
||||
zeta_prompt::Event::BufferChange {
|
||||
in_open_source_repo: true,
|
||||
..
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::*;
|
||||
use crate::zeta1::MAX_EVENT_TOKENS;
|
||||
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
|
||||
use client::{UserStore, test::FakeServer};
|
||||
use clock::{FakeSystemClock, ReplicaId};
|
||||
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
|
||||
@@ -7,7 +7,6 @@ use cloud_llm_client::{
|
||||
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
|
||||
RejectEditPredictionsBody,
|
||||
};
|
||||
use edit_prediction_context::Line;
|
||||
use futures::{
|
||||
AsyncReadExt, StreamExt,
|
||||
channel::{mpsc, oneshot},
|
||||
@@ -28,6 +27,7 @@ use settings::SettingsStore;
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
use util::{path, rel_path::rel_path};
|
||||
use uuid::Uuid;
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
|
||||
|
||||
@@ -65,18 +65,21 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r"
|
||||
--- a/root/1.txt
|
||||
+++ b/root/1.txt
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r"
|
||||
--- a/root/1.txt
|
||||
+++ b/root/1.txt
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
@@ -120,16 +123,20 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
});
|
||||
});
|
||||
|
||||
let (_request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r#"
|
||||
--- a/root/2.txt
|
||||
+++ b/root/2.txt
|
||||
Hola!
|
||||
-Como
|
||||
+Como estas?
|
||||
Adios
|
||||
"#}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r#"
|
||||
--- a/root/2.txt
|
||||
+++ b/root/2.txt
|
||||
@@ ... @@
|
||||
Hola!
|
||||
-Como
|
||||
+Como estas?
|
||||
Adios
|
||||
"#},
|
||||
))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
@@ -186,7 +193,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
// TODO Put back when we have a structured request again
|
||||
// assert_eq!(
|
||||
@@ -202,15 +209,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
// );
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
|
||||
@@ -276,15 +286,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r#"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"#}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r#"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"#},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
|
||||
@@ -324,18 +337,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
const NO_OP_DIFF: &str = indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How
|
||||
Bye
|
||||
"};
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let response = model_response(NO_OP_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let response = model_response(request, "");
|
||||
let id = response.id.clone();
|
||||
respond_tx.send(response).unwrap();
|
||||
|
||||
@@ -389,13 +392,13 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_text("Hello!\nHow are you?\nBye", cx);
|
||||
});
|
||||
|
||||
let response = model_response(SIMPLE_DIFF);
|
||||
let response = model_response(request, SIMPLE_DIFF);
|
||||
let id = response.id.clone();
|
||||
respond_tx.send(response).unwrap();
|
||||
|
||||
@@ -459,8 +462,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(request, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_tx.send(first_response).unwrap();
|
||||
|
||||
@@ -482,8 +485,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let second_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let second_response = model_response(request, SIMPLE_DIFF);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_tx.send(second_response).unwrap();
|
||||
|
||||
@@ -541,8 +544,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(request, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_tx.send(first_response).unwrap();
|
||||
|
||||
@@ -564,17 +567,20 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
// worse than current prediction
|
||||
let second_response = model_response(indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are
|
||||
Bye
|
||||
"});
|
||||
let second_response = model_response(
|
||||
request,
|
||||
indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are
|
||||
Bye
|
||||
"},
|
||||
);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_tx.send(second_response).unwrap();
|
||||
|
||||
@@ -633,19 +639,19 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_first) = requests.predict.next().await.unwrap();
|
||||
let (request1, respond_first) = requests.predict.next().await.unwrap();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_second) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_second) = requests.predict.next().await.unwrap();
|
||||
|
||||
// wait for throttle
|
||||
cx.run_until_parked();
|
||||
|
||||
// second responds first
|
||||
let second_response = model_response(SIMPLE_DIFF);
|
||||
let second_response = model_response(request, SIMPLE_DIFF);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_second.send(second_response).unwrap();
|
||||
|
||||
@@ -663,7 +669,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
|
||||
);
|
||||
});
|
||||
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let first_response = model_response(request1, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_first.send(first_response).unwrap();
|
||||
|
||||
@@ -724,13 +730,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_first) = requests.predict.next().await.unwrap();
|
||||
let (request1, respond_first) = requests.predict.next().await.unwrap();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_second) = requests.predict.next().await.unwrap();
|
||||
let (request2, respond_second) = requests.predict.next().await.unwrap();
|
||||
|
||||
// wait for throttle, so requests are sent
|
||||
cx.run_until_parked();
|
||||
@@ -754,9 +760,9 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
// wait for throttle
|
||||
cx.run_until_parked();
|
||||
|
||||
let (_, respond_third) = requests.predict.next().await.unwrap();
|
||||
let (request3, respond_third) = requests.predict.next().await.unwrap();
|
||||
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let first_response = model_response(request1, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_first.send(first_response).unwrap();
|
||||
|
||||
@@ -774,7 +780,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
);
|
||||
});
|
||||
|
||||
let cancelled_response = model_response(SIMPLE_DIFF);
|
||||
let cancelled_response = model_response(request2, SIMPLE_DIFF);
|
||||
let cancelled_id = cancelled_response.id.clone();
|
||||
respond_second.send(cancelled_response).unwrap();
|
||||
|
||||
@@ -792,7 +798,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
);
|
||||
});
|
||||
|
||||
let third_response = model_response(SIMPLE_DIFF);
|
||||
let third_response = model_response(request3, SIMPLE_DIFF);
|
||||
let third_response_id = third_response.id.clone();
|
||||
respond_third.send(third_response).unwrap();
|
||||
|
||||
@@ -1036,7 +1042,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
|
||||
// );
|
||||
// }
|
||||
|
||||
fn model_response(text: &str) -> open_ai::Response {
|
||||
// Generate a model response that would apply the given diff to the active file.
|
||||
fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
|
||||
let prompt = match &request.messages[0] {
|
||||
open_ai::RequestMessage::User {
|
||||
content: open_ai::MessageContent::Plain(content),
|
||||
} => content,
|
||||
_ => panic!("unexpected request {request:?}"),
|
||||
};
|
||||
|
||||
let open = "<editable_region>\n";
|
||||
let close = "</editable_region>";
|
||||
let cursor = "<|user_cursor|>";
|
||||
|
||||
let start_ix = open.len() + prompt.find(open).unwrap();
|
||||
let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
|
||||
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
|
||||
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
|
||||
|
||||
open_ai::Response {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
object: "response".into(),
|
||||
@@ -1045,7 +1068,7 @@ fn model_response(text: &str) -> open_ai::Response {
|
||||
choices: vec![open_ai::Choice {
|
||||
index: 0,
|
||||
message: open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Plain(text.to_string())),
|
||||
content: Some(open_ai::MessageContent::Plain(new_excerpt)),
|
||||
tool_calls: vec![],
|
||||
},
|
||||
finish_reason: None,
|
||||
@@ -1160,20 +1183,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
|
||||
.await;
|
||||
|
||||
let completion = EditPrediction {
|
||||
let prediction = EditPrediction {
|
||||
edits,
|
||||
edit_preview,
|
||||
buffer: buffer.clone(),
|
||||
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
|
||||
id: EditPredictionId("the-id".into()),
|
||||
inputs: EditPredictionInputs {
|
||||
inputs: ZetaPromptInput {
|
||||
events: Default::default(),
|
||||
included_files: Default::default(),
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
line: Line(0),
|
||||
column: 0,
|
||||
},
|
||||
related_files: Default::default(),
|
||||
cursor_path: Path::new("").into(),
|
||||
cursor_excerpt: "".into(),
|
||||
editable_range_in_excerpt: 0..0,
|
||||
cursor_offset_in_excerpt: 0,
|
||||
},
|
||||
buffer_snapshotted_at: Instant::now(),
|
||||
response_received_at: Instant::now(),
|
||||
@@ -1182,7 +1204,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1192,7 +1214,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1202,7 +1224,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1212,7 +1234,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1222,7 +1244,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1232,7 +1254,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1242,7 +1264,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1252,7 +1274,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1260,7 +1282,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
|
||||
assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
|
||||
assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use edit_prediction_context::RelatedFile;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
App, AppContext as _, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
|
||||
};
|
||||
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
|
||||
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
|
||||
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
|
||||
@@ -38,16 +35,17 @@ impl Mercury {
|
||||
store_api_token_in_keychain(api_token, cx)
|
||||
}
|
||||
|
||||
pub fn request_prediction(
|
||||
pub(crate) fn request_prediction(
|
||||
&self,
|
||||
_project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
_recent_paths: &VecDeque<ProjectPath>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
_diagnostic_search_range: Range<Point>,
|
||||
EditPredictionModelInput {
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
|
||||
@@ -62,6 +60,7 @@ impl Mercury {
|
||||
let http_client = cx.http_client();
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let active_buffer = buffer.clone();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let (editable_range, context_range) =
|
||||
@@ -72,39 +71,39 @@ impl Mercury {
|
||||
MAX_REWRITE_TOKENS,
|
||||
);
|
||||
|
||||
let offset_range = editable_range.to_offset(&snapshot);
|
||||
let prompt = build_prompt(
|
||||
&events,
|
||||
&related_files,
|
||||
&snapshot,
|
||||
full_path.as_ref(),
|
||||
cursor_point,
|
||||
editable_range,
|
||||
context_range.clone(),
|
||||
);
|
||||
let context_offset_range = context_range.to_offset(&snapshot);
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events: events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(
|
||||
context_range.start.row,
|
||||
),
|
||||
text: snapshot
|
||||
.text_for_range(context_range.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
let editable_offset_range = editable_range.to_offset(&snapshot);
|
||||
|
||||
let inputs = zeta_prompt::ZetaPromptInput {
|
||||
events,
|
||||
related_files,
|
||||
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
|
||||
- context_range.start.to_offset(&snapshot),
|
||||
cursor_path: full_path.clone(),
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt: (editable_offset_range.start
|
||||
- context_offset_range.start)
|
||||
..(editable_offset_range.end - context_offset_range.start),
|
||||
};
|
||||
|
||||
let prompt = build_prompt(&inputs);
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: active_buffer.downgrade(),
|
||||
prompt: Some(prompt.clone()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let request_body = open_ai::Request {
|
||||
model: "mercury-coder".into(),
|
||||
messages: vec![open_ai::RequestMessage::User {
|
||||
@@ -160,6 +159,18 @@ impl Mercury {
|
||||
let id = mem::take(&mut response.id);
|
||||
let response_str = text_from_response(response).unwrap_or_default();
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: active_buffer.downgrade(),
|
||||
model_output: Some(response_str.clone()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
|
||||
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
|
||||
|
||||
@@ -168,15 +179,16 @@ impl Mercury {
|
||||
|
||||
if response_str != NO_PREDICTION_OUTPUT {
|
||||
let old_text = snapshot
|
||||
.text_for_range(offset_range.clone())
|
||||
.text_for_range(editable_offset_range.clone())
|
||||
.collect::<String>();
|
||||
edits.extend(
|
||||
language::text_diff(&old_text, &response_str)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(offset_range.start + range.start)
|
||||
..snapshot.anchor_before(offset_range.start + range.end),
|
||||
snapshot.anchor_after(editable_offset_range.start + range.start)
|
||||
..snapshot
|
||||
.anchor_before(editable_offset_range.start + range.end),
|
||||
text,
|
||||
)
|
||||
}),
|
||||
@@ -186,8 +198,6 @@ impl Mercury {
|
||||
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (id, edits, old_snapshot, response_received_at, inputs) =
|
||||
result.await.context("Mercury edit prediction failed")?;
|
||||
@@ -208,15 +218,7 @@ impl Mercury {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_prompt(
|
||||
events: &[Arc<Event>],
|
||||
related_files: &[RelatedFile],
|
||||
cursor_buffer: &BufferSnapshot,
|
||||
cursor_buffer_path: &Path,
|
||||
cursor_point: Point,
|
||||
editable_range: Range<Point>,
|
||||
context_range: Range<Point>,
|
||||
) -> String {
|
||||
fn build_prompt(inputs: &ZetaPromptInput) -> String {
|
||||
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
|
||||
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
|
||||
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
|
||||
@@ -237,14 +239,14 @@ fn build_prompt(
|
||||
&mut prompt,
|
||||
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|
||||
|prompt| {
|
||||
for related_file in related_files {
|
||||
for related_file in inputs.related_files.iter() {
|
||||
for related_excerpt in &related_file.excerpts {
|
||||
push_delimited(
|
||||
prompt,
|
||||
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|
||||
|prompt| {
|
||||
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
|
||||
prompt.push_str(related_file.path.path.as_unix_str());
|
||||
prompt.push_str(related_file.path.to_string_lossy().as_ref());
|
||||
prompt.push('\n');
|
||||
prompt.push_str(&related_excerpt.text.to_string());
|
||||
},
|
||||
@@ -259,21 +261,22 @@ fn build_prompt(
|
||||
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|
||||
|prompt| {
|
||||
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
|
||||
prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
|
||||
prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
|
||||
prompt.push('\n');
|
||||
|
||||
let prefix_range = context_range.start..editable_range.start;
|
||||
let suffix_range = editable_range.end..context_range.end;
|
||||
|
||||
prompt.extend(cursor_buffer.text_for_range(prefix_range));
|
||||
prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
|
||||
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
|
||||
let range_before_cursor = editable_range.start..cursor_point;
|
||||
let range_after_cursor = cursor_point..editable_range.end;
|
||||
prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
|
||||
prompt.push_str(
|
||||
&inputs.cursor_excerpt
|
||||
[inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
|
||||
);
|
||||
prompt.push_str(CURSOR_TAG);
|
||||
prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
|
||||
prompt.push_str(
|
||||
&inputs.cursor_excerpt
|
||||
[inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
|
||||
);
|
||||
});
|
||||
prompt.extend(cursor_buffer.text_for_range(suffix_range));
|
||||
prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
|
||||
},
|
||||
);
|
||||
|
||||
@@ -281,8 +284,8 @@ fn build_prompt(
|
||||
&mut prompt,
|
||||
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|
||||
|prompt| {
|
||||
for event in events {
|
||||
writeln!(prompt, "{event}").unwrap();
|
||||
for event in inputs.events.iter() {
|
||||
zeta_prompt::write_event(prompt, &event);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::{
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
|
||||
use edit_prediction_types::interpolate_edits;
|
||||
use gpui::{AsyncApp, Entity, SharedString};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
|
||||
use serde::Serialize;
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct EditPredictionId(pub SharedString);
|
||||
@@ -40,7 +39,7 @@ impl EditPredictionResult {
|
||||
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
buffer_snapshotted_at: Instant,
|
||||
response_received_at: Instant,
|
||||
inputs: EditPredictionInputs,
|
||||
inputs: ZetaPromptInput,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Self {
|
||||
if edits.is_empty() {
|
||||
@@ -94,15 +93,7 @@ pub struct EditPrediction {
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub buffer_snapshotted_at: Instant,
|
||||
pub response_received_at: Instant,
|
||||
pub inputs: EditPredictionInputs,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct EditPredictionInputs {
|
||||
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
|
||||
pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
|
||||
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub inputs: zeta_prompt::ZetaPromptInput,
|
||||
}
|
||||
|
||||
impl EditPrediction {
|
||||
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use super::*;
|
||||
use gpui::{App, Entity, TestAppContext, prelude::*};
|
||||
use language::{Buffer, ToOffset as _};
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
@@ -154,14 +148,13 @@ mod tests {
|
||||
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
|
||||
buffer: buffer.clone(),
|
||||
edit_preview,
|
||||
inputs: EditPredictionInputs {
|
||||
inputs: ZetaPromptInput {
|
||||
events: vec![],
|
||||
included_files: vec![],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
line: cloud_llm_client::predict_edits_v3::Line(0),
|
||||
column: 0,
|
||||
},
|
||||
related_files: vec![].into(),
|
||||
cursor_path: Path::new("path.txt").into(),
|
||||
cursor_offset_in_excerpt: 0,
|
||||
cursor_excerpt: "".into(),
|
||||
editable_range_in_excerpt: 0..0,
|
||||
},
|
||||
buffer_snapshotted_at: Instant::now(),
|
||||
response_received_at: Instant::now(),
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use edit_prediction_context::RelatedFile;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
App, AppContext as _, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
|
||||
use language::{Point, ToOffset as _};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use project::{Project, ProjectPath};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
fmt::{self, Write as _},
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
|
||||
use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
|
||||
|
||||
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
@@ -44,40 +39,34 @@ impl SweepAi {
|
||||
|
||||
pub fn request_prediction_with_sweep(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
recent_paths: &VecDeque<ProjectPath>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
diagnostic_search_range: Range<Point>,
|
||||
inputs: EditPredictionModelInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let debug_info = self.debug_info.clone();
|
||||
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
let full_path: Arc<Path> = inputs
|
||||
.snapshot
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| "untitled".into())
|
||||
.into();
|
||||
|
||||
let project_file = project::File::from_dyn(snapshot.file());
|
||||
let project_file = project::File::from_dyn(inputs.snapshot.file());
|
||||
let repo_name = project_file
|
||||
.map(|file| file.worktree.read(cx).root_name_str())
|
||||
.unwrap_or("untitled")
|
||||
.into();
|
||||
let offset = position.to_offset(&snapshot);
|
||||
let offset = inputs.position.to_offset(&inputs.snapshot);
|
||||
|
||||
let recent_buffers = recent_paths.iter().cloned();
|
||||
let recent_buffers = inputs.recent_paths.iter().cloned();
|
||||
let http_client = cx.http_client();
|
||||
|
||||
let recent_buffer_snapshots = recent_buffers
|
||||
.filter_map(|project_path| {
|
||||
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if active_buffer == &buffer {
|
||||
let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if inputs.buffer == buffer {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.read(cx).snapshot())
|
||||
@@ -86,14 +75,13 @@ impl SweepAi {
|
||||
.take(3)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let text = snapshot.text();
|
||||
let text = inputs.snapshot.text();
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
for event in &events {
|
||||
for event in &inputs.events {
|
||||
write_event(event.as_ref(), &mut recent_changes).unwrap();
|
||||
}
|
||||
|
||||
@@ -122,20 +110,23 @@ impl SweepAi {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let retrieval_chunks = related_files
|
||||
let retrieval_chunks = inputs
|
||||
.related_files
|
||||
.iter()
|
||||
.flat_map(|related_file| {
|
||||
related_file.excerpts.iter().map(|excerpt| FileChunk {
|
||||
file_path: related_file.path.path.as_unix_str().to_string(),
|
||||
start_line: excerpt.point_range.start.row as usize,
|
||||
end_line: excerpt.point_range.end.row as usize,
|
||||
file_path: related_file.path.to_string_lossy().to_string(),
|
||||
start_line: excerpt.row_range.start as usize,
|
||||
end_line: excerpt.row_range.end as usize,
|
||||
content: excerpt.text.to_string(),
|
||||
timestamp: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
|
||||
let diagnostic_entries = inputs
|
||||
.snapshot
|
||||
.diagnostics_in_range(inputs.diagnostic_search_range, false);
|
||||
let mut diagnostic_content = String::new();
|
||||
let mut diagnostic_count = 0;
|
||||
|
||||
@@ -195,21 +186,14 @@ impl SweepAi {
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(0),
|
||||
text: request_body.file_contents.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
let ep_inputs = zeta_prompt::ZetaPromptInput {
|
||||
events: inputs.events,
|
||||
related_files: inputs.related_files.clone(),
|
||||
cursor_path: full_path.clone(),
|
||||
cursor_excerpt: request_body.file_contents.into(),
|
||||
// we actually don't know
|
||||
editable_range_in_excerpt: 0..inputs.snapshot.len(),
|
||||
cursor_offset_in_excerpt: request_body.cursor_position,
|
||||
};
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
@@ -237,15 +221,20 @@ impl SweepAi {
|
||||
|
||||
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
|
||||
let old_text = snapshot
|
||||
let old_text = inputs
|
||||
.snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = language::text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
inputs
|
||||
.snapshot
|
||||
.anchor_after(response.start_index + range.start)
|
||||
..inputs
|
||||
.snapshot
|
||||
.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
@@ -254,13 +243,13 @@ impl SweepAi {
|
||||
anyhow::Ok((
|
||||
response.autocomplete_id,
|
||||
edits,
|
||||
snapshot,
|
||||
inputs.snapshot,
|
||||
response_received_at,
|
||||
inputs,
|
||||
ep_inputs,
|
||||
))
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
let buffer = inputs.buffer.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
|
||||
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
fn write_event(
|
||||
event: &cloud_llm_client::predict_edits_v3::Event,
|
||||
f: &mut impl fmt::Write,
|
||||
) -> fmt::Result {
|
||||
fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
|
||||
match event {
|
||||
cloud_llm_client::predict_edits_v3::Event::BufferChange {
|
||||
zeta_prompt::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
|
||||
@@ -14,68 +14,18 @@ use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use gpui::AsyncApp;
|
||||
use gpui::Entity;
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
|
||||
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
|
||||
use project::Project;
|
||||
|
||||
pub async fn parse_diff<'a>(
|
||||
diff_str: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
let mut edited_buffer = None;
|
||||
let mut edits = Vec::new();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk {
|
||||
path: file_path,
|
||||
hunk,
|
||||
} => {
|
||||
let (buffer, ranges) = match edited_buffer {
|
||||
None => {
|
||||
edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
|
||||
edited_buffer
|
||||
.as_ref()
|
||||
.context("Model tried to edit a file that wasn't included")?
|
||||
}
|
||||
Some(ref current) => current,
|
||||
};
|
||||
|
||||
edits.extend(
|
||||
resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
|
||||
.with_context(|| format!("Diff:\n{diff_str}"))?,
|
||||
);
|
||||
}
|
||||
DiffEvent::FileEnd { renamed_to } => {
|
||||
let (buffer, _) = edited_buffer
|
||||
.take()
|
||||
.context("Got a FileEnd event before an Hunk event")?;
|
||||
|
||||
if renamed_to.is_some() {
|
||||
anyhow::bail!("edit predictions cannot rename files");
|
||||
}
|
||||
|
||||
if diff.next()?.is_some() {
|
||||
anyhow::bail!("Edited more than one file");
|
||||
}
|
||||
|
||||
return Ok((buffer, edits));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("No EOF"))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
|
||||
|
||||
#[must_use]
|
||||
pub async fn apply_diff<'a>(
|
||||
diff_str: &'a str,
|
||||
pub async fn apply_diff(
|
||||
diff_str: &str,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers<'a>> {
|
||||
) -> Result<OpenedBuffers> {
|
||||
let mut included_files = HashMap::default();
|
||||
|
||||
for line in diff_str.lines() {
|
||||
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
|
||||
})??
|
||||
.await?;
|
||||
|
||||
included_files.insert(path, buffer);
|
||||
included_files.insert(path.to_string(), buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
|
||||
let (buffer, ranges) = match current_file {
|
||||
None => {
|
||||
let buffer = included_files
|
||||
.get_mut(&file_path)
|
||||
.get_mut(file_path.as_ref())
|
||||
.expect("Opened all files in diff");
|
||||
|
||||
current_file = Some((buffer, ranges.as_slice()));
|
||||
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
|
||||
Ok(OpenedBuffers(included_files))
|
||||
}
|
||||
|
||||
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
|
||||
let mut text = text.to_string();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk { hunk, .. } => {
|
||||
let hunk_offset = text
|
||||
.find(&hunk.context)
|
||||
.ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
|
||||
for edit in hunk.edits.iter().rev() {
|
||||
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
|
||||
text.replace_range(range, &edit.text);
|
||||
}
|
||||
}
|
||||
DiffEvent::FileEnd { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
struct PatchFile<'a> {
|
||||
old_path: Cow<'a, str>,
|
||||
new_path: Cow<'a, str>,
|
||||
@@ -492,7 +465,6 @@ mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::Point;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
@@ -817,137 +789,6 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
"# };
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
five
|
||||
"#};
|
||||
|
||||
let final_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
one
|
||||
two
|
||||
3
|
||||
four
|
||||
five
|
||||
"#};
|
||||
|
||||
apply_diff(diff, &project, &mut cx.to_async())
|
||||
.await
|
||||
.expect_err("Non-unique edits should fail");
|
||||
|
||||
let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
|
||||
..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
|
||||
|
||||
let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(edits, None, cx);
|
||||
assert_eq!(buffer.text(), final_text);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one two three four
|
||||
five six seven eight
|
||||
nine ten eleven twelve
|
||||
"# };
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one two three four
|
||||
-five six seven eight
|
||||
+five SIX seven eight!
|
||||
nine ten eleven twelve
|
||||
"#};
|
||||
|
||||
let (buffer, edits) = parse_diff(diff, |_path| {
|
||||
Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edits = edits
|
||||
.into_iter()
|
||||
.map(|(range, text)| (range.to_point(&buffer), text))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
edits,
|
||||
&[
|
||||
(Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
|
||||
(Point::new(1, 20)..Point::new(1, 20), "!".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
@@ -1,637 +0,0 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
|
||||
use std::{cmp, ops::Range, path::Path, sync::Arc};
|
||||
|
||||
const EDITS_TAG_NAME: &'static str = "edits";
|
||||
const OLD_TEXT_TAG_NAME: &'static str = "old_text";
|
||||
const NEW_TEXT_TAG_NAME: &'static str = "new_text";
|
||||
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
|
||||
|
||||
pub async fn parse_xml_edits<'a>(
|
||||
input: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
|
||||
parse_xml_edits_inner(input, get_buffer)
|
||||
.await
|
||||
.with_context(|| format!("Failed to parse XML edits:\n{input}"))
|
||||
}
|
||||
|
||||
async fn parse_xml_edits_inner<'a>(
|
||||
input: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
|
||||
let xml_edits = extract_xml_replacements(input)?;
|
||||
|
||||
let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
|
||||
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
|
||||
|
||||
let mut all_edits = vec![];
|
||||
for (old_text, new_text) in xml_edits.replacements {
|
||||
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
|
||||
let matched_old_text = buffer
|
||||
.text_for_range(match_range.clone())
|
||||
.collect::<String>();
|
||||
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
|
||||
all_edits.extend(
|
||||
edits_within_hunk
|
||||
.into_iter()
|
||||
.map(move |(inner_range, inner_text)| {
|
||||
(
|
||||
buffer.anchor_after(match_range.start + inner_range.start)
|
||||
..buffer.anchor_before(match_range.start + inner_range.end),
|
||||
inner_text,
|
||||
)
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
Ok((buffer, all_edits))
|
||||
}
|
||||
|
||||
fn fuzzy_match_in_ranges(
|
||||
old_text: &str,
|
||||
buffer: &BufferSnapshot,
|
||||
context_ranges: &[Range<Anchor>],
|
||||
) -> Result<Range<usize>> {
|
||||
let mut state = FuzzyMatcher::new(buffer, old_text);
|
||||
let mut best_match = None;
|
||||
let mut tie_match_range = None;
|
||||
|
||||
for range in context_ranges {
|
||||
let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
|
||||
match (best_match_cost, state.match_range(range.to_offset(buffer))) {
|
||||
(Some(lowest_cost), Some((new_cost, new_range))) => {
|
||||
if new_cost == lowest_cost {
|
||||
tie_match_range = Some(new_range);
|
||||
} else if new_cost < lowest_cost {
|
||||
tie_match_range.take();
|
||||
best_match = Some((new_cost, new_range));
|
||||
}
|
||||
}
|
||||
(None, Some(new_match)) => {
|
||||
best_match = Some(new_match);
|
||||
}
|
||||
(None, None) | (Some(_), None) => {}
|
||||
};
|
||||
}
|
||||
|
||||
if let Some((_, best_match_range)) = best_match {
|
||||
if let Some(tie_match_range) = tie_match_range {
|
||||
anyhow::bail!(
|
||||
"Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
|
||||
best_match_range.clone(),
|
||||
buffer.text_for_range(best_match_range).collect::<String>(),
|
||||
tie_match_range.clone(),
|
||||
buffer.text_for_range(tie_match_range).collect::<String>()
|
||||
);
|
||||
}
|
||||
return Ok(best_match_range);
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
|
||||
old_text,
|
||||
context_ranges
|
||||
.iter()
|
||||
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
|
||||
.collect::<Vec<String>>()
|
||||
.join("```\n```")
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct XmlEdits<'a> {
|
||||
file_path: &'a str,
|
||||
/// Vec of (old_text, new_text) pairs
|
||||
replacements: Vec<(&'a str, &'a str)>,
|
||||
}
|
||||
|
||||
fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
|
||||
let mut cursor = 0;
|
||||
|
||||
let (edits_body_start, edits_attrs) =
|
||||
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
|
||||
|
||||
let file_path = edits_attrs
|
||||
.trim_start()
|
||||
.strip_prefix("path")
|
||||
.context("no path attribute on edits tag")?
|
||||
.trim_end()
|
||||
.strip_prefix('=')
|
||||
.context("no value for path attribute")?
|
||||
.trim()
|
||||
.trim_start_matches('"')
|
||||
.trim_end_matches('"');
|
||||
|
||||
cursor = edits_body_start;
|
||||
let mut edits_list = Vec::new();
|
||||
|
||||
while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
|
||||
let old_body_end = find_tag_close(input, &mut cursor)?;
|
||||
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
|
||||
|
||||
let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
|
||||
.context("no new_text tag following old_text")?;
|
||||
let new_body_end = find_tag_close(input, &mut cursor)?;
|
||||
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
|
||||
|
||||
edits_list.push((old_text, new_text));
|
||||
}
|
||||
|
||||
Ok(XmlEdits {
|
||||
file_path,
|
||||
replacements: edits_list,
|
||||
})
|
||||
}
|
||||
|
||||
/// Trims a single leading and trailing newline
|
||||
fn trim_surrounding_newlines(input: &str) -> &str {
|
||||
let start = input.strip_prefix('\n').unwrap_or(input);
|
||||
let end = start.strip_suffix('\n').unwrap_or(start);
|
||||
end
|
||||
}
|
||||
|
||||
fn find_tag_open<'a>(
|
||||
input: &'a str,
|
||||
cursor: &mut usize,
|
||||
expected_tag: &str,
|
||||
) -> Result<Option<(usize, &'a str)>> {
|
||||
let mut search_pos = *cursor;
|
||||
|
||||
while search_pos < input.len() {
|
||||
let Some(tag_start) = input[search_pos..].find("<") else {
|
||||
break;
|
||||
};
|
||||
let tag_start = search_pos + tag_start;
|
||||
if !input[tag_start + 1..].starts_with(expected_tag) {
|
||||
search_pos = search_pos + tag_start + 1;
|
||||
continue;
|
||||
};
|
||||
|
||||
let after_tag_name = tag_start + expected_tag.len() + 1;
|
||||
let close_bracket = input[after_tag_name..]
|
||||
.find('>')
|
||||
.with_context(|| format!("missing > after <{}", expected_tag))?;
|
||||
let attrs_end = after_tag_name + close_bracket;
|
||||
let body_start = attrs_end + 1;
|
||||
|
||||
let attributes = input[after_tag_name..attrs_end].trim();
|
||||
*cursor = body_start;
|
||||
|
||||
return Ok(Some((body_start, attributes)));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
|
||||
let mut depth = 1;
|
||||
let mut search_pos = *cursor;
|
||||
|
||||
while search_pos < input.len() && depth > 0 {
|
||||
let Some(bracket_offset) = input[search_pos..].find('<') else {
|
||||
break;
|
||||
};
|
||||
let bracket_pos = search_pos + bracket_offset;
|
||||
|
||||
if input[bracket_pos..].starts_with("</")
|
||||
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
|
||||
{
|
||||
let close_start = bracket_pos + 2;
|
||||
let tag_name = input[close_start..close_start + close_end].trim();
|
||||
|
||||
if XML_TAGS.contains(&tag_name) {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
*cursor = close_start + close_end + 1;
|
||||
return Ok(bracket_pos);
|
||||
}
|
||||
}
|
||||
search_pos = close_start + close_end + 1;
|
||||
continue;
|
||||
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
|
||||
let close_bracket_pos = bracket_pos + close_bracket_offset;
|
||||
let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
|
||||
if XML_TAGS.contains(&tag_name) {
|
||||
depth += 1;
|
||||
}
|
||||
}
|
||||
|
||||
search_pos = bracket_pos + 1;
|
||||
}
|
||||
|
||||
anyhow::bail!("no closing tag found")
|
||||
}
|
||||
|
||||
const REPLACEMENT_COST: u32 = 1;
|
||||
const INSERTION_COST: u32 = 3;
|
||||
const DELETION_COST: u32 = 10;
|
||||
|
||||
/// A fuzzy matcher that can process text chunks incrementally
|
||||
/// and return the best match found so far at each step.
|
||||
struct FuzzyMatcher<'a> {
|
||||
snapshot: &'a BufferSnapshot,
|
||||
query_lines: Vec<&'a str>,
|
||||
matrix: SearchMatrix,
|
||||
}
|
||||
|
||||
impl<'a> FuzzyMatcher<'a> {
|
||||
fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
|
||||
let query_lines = old_text.lines().collect();
|
||||
Self {
|
||||
snapshot,
|
||||
query_lines,
|
||||
matrix: SearchMatrix::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
|
||||
let point_range = range.to_point(&self.snapshot);
|
||||
let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
|
||||
|
||||
self.matrix
|
||||
.reset(self.query_lines.len() + 1, buffer_line_count + 1);
|
||||
let query_line_count = self.query_lines.len();
|
||||
|
||||
for row in 0..query_line_count {
|
||||
let query_line = self.query_lines[row].trim();
|
||||
let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
|
||||
|
||||
self.matrix.set(
|
||||
row + 1,
|
||||
0,
|
||||
SearchState::new(leading_deletion_cost, SearchDirection::Up),
|
||||
);
|
||||
|
||||
let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
|
||||
|
||||
let mut col = 0;
|
||||
while let Some(buffer_line) = buffer_lines.next() {
|
||||
let buffer_line = buffer_line.trim();
|
||||
let up = SearchState::new(
|
||||
self.matrix
|
||||
.get(row, col + 1)
|
||||
.cost
|
||||
.saturating_add(DELETION_COST),
|
||||
SearchDirection::Up,
|
||||
);
|
||||
let left = SearchState::new(
|
||||
self.matrix
|
||||
.get(row + 1, col)
|
||||
.cost
|
||||
.saturating_add(INSERTION_COST),
|
||||
SearchDirection::Left,
|
||||
);
|
||||
let diagonal = SearchState::new(
|
||||
if query_line == buffer_line {
|
||||
self.matrix.get(row, col).cost
|
||||
} else if fuzzy_eq(query_line, buffer_line) {
|
||||
self.matrix.get(row, col).cost + REPLACEMENT_COST
|
||||
} else {
|
||||
self.matrix
|
||||
.get(row, col)
|
||||
.cost
|
||||
.saturating_add(DELETION_COST + INSERTION_COST)
|
||||
},
|
||||
SearchDirection::Diagonal,
|
||||
);
|
||||
self.matrix
|
||||
.set(row + 1, col + 1, up.min(left).min(diagonal));
|
||||
col += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Find all matches with the best cost
|
||||
let mut best_cost = u32::MAX;
|
||||
let mut matches_with_best_cost = Vec::new();
|
||||
|
||||
for col in 1..=buffer_line_count {
|
||||
let cost = self.matrix.get(query_line_count, col).cost;
|
||||
if cost < best_cost {
|
||||
best_cost = cost;
|
||||
matches_with_best_cost.clear();
|
||||
matches_with_best_cost.push(col as u32);
|
||||
} else if cost == best_cost {
|
||||
matches_with_best_cost.push(col as u32);
|
||||
}
|
||||
}
|
||||
|
||||
// Find ranges for the matches
|
||||
for &match_end_col in &matches_with_best_cost {
|
||||
let mut matched_lines = 0;
|
||||
let mut query_row = query_line_count;
|
||||
let mut match_start_col = match_end_col;
|
||||
while query_row > 0 && match_start_col > 0 {
|
||||
let current = self.matrix.get(query_row, match_start_col as usize);
|
||||
match current.direction {
|
||||
SearchDirection::Diagonal => {
|
||||
query_row -= 1;
|
||||
match_start_col -= 1;
|
||||
matched_lines += 1;
|
||||
}
|
||||
SearchDirection::Up => {
|
||||
query_row -= 1;
|
||||
}
|
||||
SearchDirection::Left => {
|
||||
match_start_col -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let buffer_row_start = match_start_col + point_range.start.row;
|
||||
let buffer_row_end = match_end_col + point_range.start.row;
|
||||
|
||||
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
|
||||
let matched_ratio = matched_lines as f32
|
||||
/ (matched_buffer_row_count as f32).max(query_line_count as f32);
|
||||
if matched_ratio >= 0.8 {
|
||||
let buffer_start_ix = self
|
||||
.snapshot
|
||||
.point_to_offset(Point::new(buffer_row_start, 0));
|
||||
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
|
||||
buffer_row_end - 1,
|
||||
self.snapshot.line_len(buffer_row_end - 1),
|
||||
));
|
||||
return Some((best_cost, buffer_start_ix..buffer_end_ix));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn fuzzy_eq(left: &str, right: &str) -> bool {
|
||||
const THRESHOLD: f64 = 0.8;
|
||||
|
||||
let min_levenshtein = left.len().abs_diff(right.len());
|
||||
let min_normalized_levenshtein =
|
||||
1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
|
||||
if min_normalized_levenshtein < THRESHOLD {
|
||||
return false;
|
||||
}
|
||||
|
||||
strsim::normalized_levenshtein(left, right) >= THRESHOLD
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
enum SearchDirection {
|
||||
Up,
|
||||
Left,
|
||||
Diagonal,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct SearchState {
|
||||
cost: u32,
|
||||
direction: SearchDirection,
|
||||
}
|
||||
|
||||
impl SearchState {
|
||||
fn new(cost: u32, direction: SearchDirection) -> Self {
|
||||
Self { cost, direction }
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchMatrix {
|
||||
cols: usize,
|
||||
rows: usize,
|
||||
data: Vec<SearchState>,
|
||||
}
|
||||
|
||||
impl SearchMatrix {
|
||||
fn new(cols: usize) -> Self {
|
||||
SearchMatrix {
|
||||
cols,
|
||||
rows: 0,
|
||||
data: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self, rows: usize, cols: usize) {
|
||||
self.rows = rows;
|
||||
self.cols = cols;
|
||||
self.data
|
||||
.fill(SearchState::new(0, SearchDirection::Diagonal));
|
||||
self.data.resize(
|
||||
self.rows * self.cols,
|
||||
SearchState::new(0, SearchDirection::Diagonal),
|
||||
);
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> SearchState {
|
||||
debug_assert!(row < self.rows);
|
||||
debug_assert!(col < self.cols);
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, state: SearchState) {
|
||||
debug_assert!(row < self.rows && col < self.cols);
|
||||
self.data[row * self.cols + col] = state;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::Point;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</old_text>
|
||||
<new_text>
|
||||
new content
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_wrong_closing_tags() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</new_text>
|
||||
<new_text>
|
||||
new content
|
||||
</old_text>
|
||||
</ edits >
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_xml_like_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<foo><bar></bar></foo>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<foo><bar><baz></baz></bar></foo>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
|
||||
assert_eq!(
|
||||
result.replacements[0].1,
|
||||
"<foo><bar><baz></baz></bar></foo>"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_conflicting_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<new_text></new_text>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<old_text></old_text>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
|
||||
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_multiple_pairs() {
|
||||
let input = indoc! {r#"
|
||||
Some reasoning before edits. Lots of thinking going on here
|
||||
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</old_text>
|
||||
<new_text>
|
||||
first new
|
||||
</new_text>
|
||||
<old_text>
|
||||
second old
|
||||
</edits>
|
||||
<new_text>
|
||||
second new
|
||||
</old_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 2);
|
||||
assert_eq!(result.replacements[0].0, "first old");
|
||||
assert_eq!(result.replacements[0].1, "first new");
|
||||
assert_eq!(result.replacements[1].0, "second old");
|
||||
assert_eq!(result.replacements[1].1, "second new");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_unexpected_eof() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</
|
||||
"#};
|
||||
|
||||
extract_xml_replacements(input).expect_err("Unexpected end of file");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_parse_xml_edits(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one two three four
|
||||
five six seven eight
|
||||
nine ten eleven twelve
|
||||
thirteen fourteen fifteen
|
||||
sixteen seventeen eighteen
|
||||
"#};
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let edits = indoc! {r#"
|
||||
<edits path="root/file1">
|
||||
<old_text>
|
||||
nine ten eleven twelve
|
||||
</old_text>
|
||||
<new_text>
|
||||
nine TEN eleven twelve!
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
|
||||
let (buffer, edits) = parse_xml_edits(edits, |_path| {
|
||||
Some((&buffer_snapshot, included_ranges.as_slice()))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edits = edits
|
||||
.into_iter()
|
||||
.map(|(range, text)| (range.to_point(&buffer), text))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
edits,
|
||||
&[
|
||||
(Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
|
||||
(Point::new(2, 22)..Point::new(2, 22), "!".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
});
|
||||
|
||||
FakeFs::new(cx.background_executor.clone())
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,23 @@
|
||||
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
|
||||
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
|
||||
EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
|
||||
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
|
||||
prediction::{EditPredictionInputs, EditPredictionResult},
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::{
|
||||
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
|
||||
predict_edits_v3::Event,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
|
||||
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
|
||||
};
|
||||
use project::{Project, ProjectPath};
|
||||
use release_channel::AppVersion;
|
||||
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
|
||||
use zeta_prompt::{Event, ZetaPromptInput};
|
||||
|
||||
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
|
||||
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
|
||||
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
|
||||
|
||||
pub(crate) fn request_prediction_with_zeta1(
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
EditPredictionModelInput {
|
||||
project,
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
events,
|
||||
trigger,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let buffer = buffer.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
|
||||
let can_collect_file = store.can_collect_file(project, file, cx);
|
||||
let can_collect_file = store.can_collect_file(&project, file, cx);
|
||||
let git_info = if can_collect_file {
|
||||
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
|
||||
git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
)
|
||||
.await;
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
let context_start_offset = context_range.start.to_offset(&snapshot);
|
||||
let editable_offset_range = editable_range.to_offset(&snapshot);
|
||||
|
||||
let inputs = ZetaPromptInput {
|
||||
events: included_events.into(),
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
|
||||
text: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
related_files: vec![].into(),
|
||||
cursor_path: full_path,
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt: (editable_range.start - context_start_offset)
|
||||
..(editable_offset_range.end - context_start_offset),
|
||||
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
|
||||
};
|
||||
|
||||
// let response = perform_predict_edits(PerformPredictEditsParams {
|
||||
// client,
|
||||
// llm_token,
|
||||
// app_version,
|
||||
// body,
|
||||
// })
|
||||
// .await;
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
prompt: Some(serde_json::to_string(&inputs).unwrap()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (response, usage) = match response {
|
||||
Ok(response) => response,
|
||||
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
.ok();
|
||||
}
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
model_output: Some(response.output_excerpt.clone()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let edit_prediction = process_completion_response(
|
||||
response,
|
||||
buffer,
|
||||
@@ -226,7 +242,7 @@ fn process_completion_response(
|
||||
buffer: Entity<Buffer>,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_range: Range<usize>,
|
||||
inputs: EditPredictionInputs,
|
||||
inputs: ZetaPromptInput,
|
||||
buffer_snapshotted_at: Instant,
|
||||
received_response_at: Instant,
|
||||
cx: &AsyncApp,
|
||||
|
||||
@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
|
||||
use crate::open_ai_response::text_from_response;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
use crate::{
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
|
||||
EditPredictionRequestedDebugEvent, EditPredictionStore,
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
|
||||
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
|
||||
};
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use edit_prediction_context::{EditPredictionExcerpt, Line};
|
||||
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{Entity, Task, prelude::*};
|
||||
use language::{Anchor, BufferSnapshot};
|
||||
use language::{Buffer, Point, ToOffset as _, ToPoint};
|
||||
use project::{Project, ProjectItem as _};
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::EditPredictionRejectReason;
|
||||
use gpui::{Task, prelude::*};
|
||||
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
|
||||
use release_channel::AppVersion;
|
||||
use std::{
|
||||
env,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
|
||||
const MAX_CONTEXT_TOKENS: usize = 150;
|
||||
const MAX_REWRITE_TOKENS: usize = 350;
|
||||
|
||||
pub fn request_prediction_with_zeta2(
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
active_snapshot: BufferSnapshot,
|
||||
position: Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
mut included_files: Vec<RelatedFile>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
EditPredictionModelInput {
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
related_files,
|
||||
events,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let options = store.options.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let Some((excerpt_path, active_project_path)) = active_snapshot
|
||||
let Some(excerpt_path) = snapshot
|
||||
.file()
|
||||
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
|
||||
.zip(active_buffer.read(cx).project_path(cx))
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("No file path for excerpt")));
|
||||
};
|
||||
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
let debug_tx = store.debug_tx.clone();
|
||||
|
||||
let file = active_buffer.read(cx).file();
|
||||
|
||||
let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
|
||||
|
||||
// TODO data collection
|
||||
let can_collect_data = file
|
||||
.as_ref()
|
||||
.map_or(false, |file| store.can_collect_file(project, file, cx));
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
let eval_cache = store.eval_cache.clone();
|
||||
|
||||
let request_task = cx.background_spawn({
|
||||
let active_buffer = active_buffer.clone();
|
||||
async move {
|
||||
let cursor_offset = position.to_offset(&active_snapshot);
|
||||
let cursor_point = cursor_offset.to_point(&active_snapshot);
|
||||
|
||||
let before_retrieval = Instant::now();
|
||||
|
||||
let excerpt_options = options.context;
|
||||
|
||||
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
|
||||
cursor_point,
|
||||
&active_snapshot,
|
||||
&excerpt_options,
|
||||
) else {
|
||||
return Ok((None, None));
|
||||
};
|
||||
|
||||
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
|
||||
..active_snapshot.anchor_before(excerpt.range.end);
|
||||
let related_excerpt = RelatedExcerpt {
|
||||
anchor_range: excerpt_anchor_range.clone(),
|
||||
point_range: Point::new(excerpt.line_range.start.0, 0)
|
||||
..Point::new(excerpt.line_range.end.0, 0),
|
||||
text: active_snapshot.as_rope().slice(excerpt.range),
|
||||
};
|
||||
|
||||
if let Some(buffer_ix) = included_files
|
||||
.iter()
|
||||
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
|
||||
{
|
||||
let file = &mut included_files[buffer_ix];
|
||||
file.excerpts.push(related_excerpt);
|
||||
file.merge_excerpts();
|
||||
let last_ix = included_files.len() - 1;
|
||||
included_files.swap(buffer_ix, last_ix);
|
||||
} else {
|
||||
let active_file = RelatedFile {
|
||||
path: active_project_path,
|
||||
buffer: active_buffer.downgrade(),
|
||||
excerpts: vec![related_excerpt],
|
||||
max_row: active_snapshot.max_point().row,
|
||||
};
|
||||
included_files.push(active_file);
|
||||
}
|
||||
|
||||
let included_files = included_files
|
||||
.iter()
|
||||
.map(|related_file| predict_edits_v3::RelatedFile {
|
||||
path: Arc::from(related_file.path.path.as_std_path()),
|
||||
max_row: Line(related_file.max_row),
|
||||
excerpts: related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| predict_edits_v3::Excerpt {
|
||||
start_line: Line(excerpt.point_range.start.row),
|
||||
text: excerpt.text.to_string().into(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cloud_request = predict_edits_v3::PredictEditsRequest {
|
||||
excerpt_path,
|
||||
excerpt: String::new(),
|
||||
excerpt_line_range: Line(0)..Line(0),
|
||||
excerpt_range: 0..0,
|
||||
cursor_point: predict_edits_v3::Point {
|
||||
line: predict_edits_v3::Line(cursor_point.row),
|
||||
column: cursor_point.column,
|
||||
},
|
||||
related_files: included_files,
|
||||
let cursor_offset = position.to_offset(&snapshot);
|
||||
let (editable_offset_range, prompt_input) = zeta2_prompt_input(
|
||||
&snapshot,
|
||||
related_files,
|
||||
events,
|
||||
can_collect_data,
|
||||
debug_info: debug_tx.is_some(),
|
||||
prompt_max_bytes: Some(options.max_prompt_bytes),
|
||||
prompt_format: options.prompt_format,
|
||||
excerpt_parent: None,
|
||||
git_info: None,
|
||||
trigger,
|
||||
};
|
||||
excerpt_path,
|
||||
cursor_offset,
|
||||
);
|
||||
|
||||
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
included_files: cloud_request.related_files,
|
||||
events: cloud_request.events,
|
||||
cursor_point: cloud_request.cursor_point,
|
||||
cursor_path: cloud_request.excerpt_path,
|
||||
};
|
||||
|
||||
let retrieval_time = Instant::now() - before_retrieval;
|
||||
|
||||
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let prompt = format_zeta_prompt(&prompt_input);
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionRequested(
|
||||
EditPredictionRequestedDebugEvent {
|
||||
inputs: inputs.clone(),
|
||||
retrieval_time,
|
||||
buffer: active_buffer.downgrade(),
|
||||
local_prompt: match prompt_result.as_ref() {
|
||||
Ok(prompt) => Ok(prompt.clone()),
|
||||
Err(err) => Err(err.to_string()),
|
||||
},
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
prompt: Some(prompt.clone()),
|
||||
position,
|
||||
response_rx,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
Some(response_tx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((Err("Request skipped".to_string()), Duration::ZERO))
|
||||
.ok();
|
||||
}
|
||||
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
|
||||
}
|
||||
|
||||
let prompt = prompt_result?;
|
||||
let generation_params =
|
||||
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
|
||||
let request = open_ai::Request {
|
||||
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
|
||||
messages: vec![open_ai::RequestMessage::User {
|
||||
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
|
||||
}],
|
||||
stream: false,
|
||||
max_completion_tokens: None,
|
||||
stop: generation_params.stop.unwrap_or_default(),
|
||||
temperature: generation_params.temperature.or(Some(0.7)),
|
||||
stop: Default::default(),
|
||||
temperature: Default::default(),
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
tools: vec![],
|
||||
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
|
||||
|
||||
log::trace!("Sending edit prediction request");
|
||||
|
||||
let before_request = Instant::now();
|
||||
let response = EditPredictionStore::send_raw_llm_request(
|
||||
request,
|
||||
client,
|
||||
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
|
||||
)
|
||||
.await;
|
||||
let received_response_at = Instant::now();
|
||||
let request_time = received_response_at - before_request;
|
||||
|
||||
log::trace!("Got edit prediction response");
|
||||
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((
|
||||
response
|
||||
.as_ref()
|
||||
.map_err(|err| err.to_string())
|
||||
.map(|response| response.0.clone()),
|
||||
request_time,
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (res, usage) = response?;
|
||||
let request_id = EditPredictionId(res.id.clone().into());
|
||||
let Some(mut output_text) = text_from_response(res) else {
|
||||
return Ok((Some((request_id, None)), usage));
|
||||
};
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
position,
|
||||
model_output: Some(output_text.clone()),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
if output_text.contains(CURSOR_MARKER) {
|
||||
log::trace!("Stripping out {CURSOR_MARKER} from response");
|
||||
output_text = output_text.replace(CURSOR_MARKER, "");
|
||||
}
|
||||
|
||||
let get_buffer_from_context = |path: &Path| {
|
||||
if Some(path) == active_file_full_path.as_deref() {
|
||||
Some((
|
||||
&active_snapshot,
|
||||
std::slice::from_ref(&excerpt_anchor_range),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let (_, edits) = match options.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
if output_text.contains("--- a/\n+++ b/\nNo edits") {
|
||||
let edits = vec![];
|
||||
(&active_snapshot, edits)
|
||||
} else {
|
||||
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
}
|
||||
PromptFormat::OldTextNewText => {
|
||||
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
_ => {
|
||||
bail!("unsupported prompt format {}", options.prompt_format)
|
||||
}
|
||||
};
|
||||
let old_text = snapshot
|
||||
.text_for_range(editable_offset_range.clone())
|
||||
.collect::<String>();
|
||||
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(editable_offset_range.start + range.start)
|
||||
..snapshot.anchor_before(editable_offset_range.start + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
anyhow::Ok((
|
||||
Some((
|
||||
request_id,
|
||||
Some((
|
||||
inputs,
|
||||
active_buffer,
|
||||
active_snapshot.clone(),
|
||||
prompt_input,
|
||||
buffer,
|
||||
snapshot.clone(),
|
||||
edits,
|
||||
received_response_at,
|
||||
)),
|
||||
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn zeta2_prompt_input(
|
||||
snapshot: &language::BufferSnapshot,
|
||||
related_files: Arc<[zeta_prompt::RelatedFile]>,
|
||||
events: Vec<Arc<zeta_prompt::Event>>,
|
||||
excerpt_path: Arc<Path>,
|
||||
cursor_offset: usize,
|
||||
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
|
||||
let cursor_point = cursor_offset.to_point(snapshot);
|
||||
|
||||
let (editable_range, context_range) =
|
||||
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
|
||||
cursor_point,
|
||||
snapshot,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
MAX_REWRITE_TOKENS,
|
||||
);
|
||||
|
||||
let context_start_offset = context_range.start.to_offset(snapshot);
|
||||
let editable_offset_range = editable_range.to_offset(snapshot);
|
||||
let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
|
||||
let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
|
||||
..(editable_offset_range.end - context_start_offset);
|
||||
|
||||
let prompt_input = zeta_prompt::ZetaPromptInput {
|
||||
cursor_path: excerpt_path,
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt,
|
||||
cursor_offset_in_excerpt,
|
||||
events,
|
||||
related_files,
|
||||
};
|
||||
(editable_offset_range, prompt_input)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "ep_cli"
|
||||
name = "ep"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
@@ -20,10 +20,9 @@ chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace= true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
dirs.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -51,12 +50,21 @@ smol.workspace = true
|
||||
sqlez.workspace = true
|
||||
sqlez_macros.workspace = true
|
||||
terminal_view.workspace = true
|
||||
toml.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
edit_prediction = { workspace = true, features = ["eval-support"] }
|
||||
wasmtime.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
zlog.workspace = true
|
||||
|
||||
# Wasmtime is included as a dependency in order to enable the same
|
||||
# features that are enabled in Zed.
|
||||
#
|
||||
# If we don't enable these features we get crashes when creating
|
||||
# a Tree-sitter WasmStore.
|
||||
[package.metadata.cargo-machete]
|
||||
ignored = ["wasmtime"]
|
||||
|
||||
[dev-dependencies]
|
||||
indoc.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -5,11 +5,13 @@ use anthropic::{
|
||||
use anyhow::Result;
|
||||
use http_client::HttpClient;
|
||||
use indoc::indoc;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use sqlez::bindable::Bind;
|
||||
use sqlez::bindable::StaticColumnCount;
|
||||
use sqlez_macros::sql;
|
||||
use std::hash::Hash;
|
||||
use std::hash::Hasher;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct PlainLlmClient {
|
||||
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
|
||||
}
|
||||
|
||||
impl PlainLlmClient {
|
||||
fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
fn new() -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
Ok(Self {
|
||||
@@ -29,12 +32,12 @@ impl PlainLlmClient {
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<AnthropicResponse> {
|
||||
let request = AnthropicRequest {
|
||||
model,
|
||||
model: model.to_string(),
|
||||
max_tokens,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
@@ -105,11 +108,12 @@ struct SerializableMessage {
|
||||
}
|
||||
|
||||
impl BatchingLlmClient {
|
||||
fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
fn new(cache_path: &Path) -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
|
||||
let connection = sqlez::connection::Connection::open_file(&cache_path);
|
||||
let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
|
||||
let mut statement = sqlez::statement::Statement::prepare(
|
||||
&connection,
|
||||
indoc! {"
|
||||
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
let response = self.lookup(&model, max_tokens, &messages)?;
|
||||
let response = self.lookup(model, max_tokens, &messages)?;
|
||||
if let Some(response) = response {
|
||||
return Ok(Some(response));
|
||||
}
|
||||
|
||||
self.mark_for_batch(&model, max_tokens, &messages)?;
|
||||
self.mark_for_batch(model, max_tokens, &messages)?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
log::info!("Uploaded {} successful requests", success_count);
|
||||
log::info!("Downloaded {} successful requests", success_count);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub enum LlmClient {
|
||||
pub enum AnthropicClient {
|
||||
// No batching
|
||||
Plain(PlainLlmClient),
|
||||
Batch(BatchingLlmClient),
|
||||
Dummy,
|
||||
}
|
||||
|
||||
impl LlmClient {
|
||||
pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Plain(PlainLlmClient::new(http_client)?))
|
||||
impl AnthropicClient {
|
||||
pub fn plain() -> Result<Self> {
|
||||
Ok(Self::Plain(PlainLlmClient::new()?))
|
||||
}
|
||||
|
||||
pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Batch(BatchingLlmClient::new(
|
||||
cache_path,
|
||||
http_client,
|
||||
)?))
|
||||
pub fn batch(cache_path: &Path) -> Result<Self> {
|
||||
Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
@@ -389,29 +390,29 @@ impl LlmClient {
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
match self {
|
||||
LlmClient::Plain(plain_llm_client) => plain_llm_client
|
||||
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
.map(Some),
|
||||
LlmClient::Batch(batching_llm_client) => {
|
||||
AnthropicClient::Batch(batching_llm_client) => {
|
||||
batching_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
}
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sync_batches(&self) -> Result<()> {
|
||||
match self {
|
||||
LlmClient::Plain(_) => Ok(()),
|
||||
LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
AnthropicClient::Plain(_) => Ok(()),
|
||||
AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,641 +0,0 @@
|
||||
use crate::metrics::{self, Scores};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io::{IsTerminal, Write},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use edit_prediction::{EditPredictionStore, udiff::DiffLine};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{
|
||||
EvaluateArguments, PredictionOptions,
|
||||
example::{Example, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
paths::print_run_data_dir,
|
||||
predict::{PredictionDetails, perform_predict, setup_store},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ExecutionData {
|
||||
execution_id: String,
|
||||
diff: String,
|
||||
reasoning: String,
|
||||
}
|
||||
|
||||
pub async fn run_evaluate(
|
||||
args: EvaluateArguments,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
if args.example_paths.is_empty() {
|
||||
eprintln!("No examples provided");
|
||||
return;
|
||||
}
|
||||
|
||||
let all_tasks = args.example_paths.into_iter().map(|path| {
|
||||
let options = args.options.clone();
|
||||
let app_state = app_state.clone();
|
||||
let example = NamedExample::load(&path).expect("Failed to load example");
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let project = example.setup_project(&app_state, cx).await.unwrap();
|
||||
|
||||
let providers = (0..args.repetitions)
|
||||
.map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
|
||||
let tasks = providers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(repetition_ix, store)| {
|
||||
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
|
||||
let example = example.clone();
|
||||
let project = project.clone();
|
||||
let options = options.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let name = example.name.clone();
|
||||
run_evaluate_one(
|
||||
example,
|
||||
repetition_ix,
|
||||
project,
|
||||
store,
|
||||
options,
|
||||
!args.skip_prediction,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| (err, name, repetition_ix))
|
||||
})
|
||||
});
|
||||
futures::future::join_all(tasks).await
|
||||
})
|
||||
});
|
||||
let all_results = futures::future::join_all(all_tasks).await;
|
||||
|
||||
write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
|
||||
if let Some(mut output_file) =
|
||||
std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
|
||||
{
|
||||
write_aggregated_scores(&mut output_file, &all_results).log_err();
|
||||
};
|
||||
|
||||
if args.repetitions > 1 {
|
||||
if let Err(e) = write_bucketed_analysis(&all_results) {
|
||||
eprintln!("Failed to write bucketed analysis: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
|
||||
}
|
||||
|
||||
fn write_aggregated_scores(
|
||||
w: &mut impl std::io::Write,
|
||||
all_results: &Vec<
|
||||
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
|
||||
>,
|
||||
) -> Result<()> {
|
||||
let mut successful = Vec::new();
|
||||
let mut failed_count = 0;
|
||||
|
||||
for result in all_results.iter().flatten() {
|
||||
match result {
|
||||
Ok((eval_result, _execution_data)) => successful.push(eval_result),
|
||||
Err((err, name, repetition_ix)) => {
|
||||
if failed_count == 0 {
|
||||
writeln!(w, "## Errors\n")?;
|
||||
}
|
||||
|
||||
failed_count += 1;
|
||||
writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if successful.len() > 1 {
|
||||
let edit_scores = successful
|
||||
.iter()
|
||||
.filter_map(|r| r.edit_scores.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let has_edit_predictions = edit_scores.len() > 0;
|
||||
let aggregated_result = EvaluationResult {
|
||||
context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
|
||||
edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
|
||||
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
|
||||
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
|
||||
/ successful.len(),
|
||||
};
|
||||
|
||||
writeln!(w, "\n{}", "-".repeat(80))?;
|
||||
writeln!(w, "\n## TOTAL SCORES")?;
|
||||
writeln!(w, "{:#}", aggregated_result)?;
|
||||
}
|
||||
|
||||
if successful.len() + failed_count > 1 {
|
||||
writeln!(
|
||||
w,
|
||||
"\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
|
||||
successful.len(),
|
||||
successful.len() + failed_count,
|
||||
(successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_evaluate_one(
|
||||
example: NamedExample,
|
||||
repetition_ix: Option<u16>,
|
||||
project: Entity<Project>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
prediction_options: PredictionOptions,
|
||||
predict: bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(EvaluationResult, ExecutionData)> {
|
||||
let predict_result = perform_predict(
|
||||
example.clone(),
|
||||
project,
|
||||
store,
|
||||
repetition_ix,
|
||||
prediction_options,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let evaluation_result = evaluate(&example.example, &predict_result, predict);
|
||||
|
||||
if repetition_ix.is_none() {
|
||||
write_eval_result(
|
||||
&example,
|
||||
&predict_result,
|
||||
&evaluation_result,
|
||||
&mut std::io::stdout(),
|
||||
std::io::stdout().is_terminal(),
|
||||
predict,
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(mut results_file) =
|
||||
std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
|
||||
{
|
||||
write_eval_result(
|
||||
&example,
|
||||
&predict_result,
|
||||
&evaluation_result,
|
||||
&mut results_file,
|
||||
false,
|
||||
predict,
|
||||
)
|
||||
.log_err();
|
||||
}
|
||||
|
||||
let execution_data = ExecutionData {
|
||||
execution_id: if let Some(rep_ix) = repetition_ix {
|
||||
format!("{:03}", rep_ix)
|
||||
} else {
|
||||
example.name.clone()
|
||||
},
|
||||
diff: predict_result.diff.clone(),
|
||||
reasoning: std::fs::read_to_string(
|
||||
predict_result
|
||||
.run_example_dir
|
||||
.join("prediction_response.md"),
|
||||
)
|
||||
.unwrap_or_default(),
|
||||
};
|
||||
|
||||
anyhow::Ok((evaluation_result, execution_data))
|
||||
}
|
||||
|
||||
fn write_eval_result(
|
||||
example: &NamedExample,
|
||||
predictions: &PredictionDetails,
|
||||
evaluation_result: &EvaluationResult,
|
||||
out: &mut impl Write,
|
||||
use_color: bool,
|
||||
predict: bool,
|
||||
) -> Result<()> {
|
||||
if predict {
|
||||
writeln!(
|
||||
out,
|
||||
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
|
||||
compare_diffs(
|
||||
&example.example.expected_patch,
|
||||
&predictions.diff,
|
||||
use_color
|
||||
)
|
||||
)?;
|
||||
writeln!(
|
||||
out,
|
||||
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
|
||||
compare_diffs(
|
||||
&predictions.diff,
|
||||
&example.example.expected_patch,
|
||||
use_color
|
||||
)
|
||||
)?;
|
||||
}
|
||||
|
||||
writeln!(out, "{:#}", evaluation_result)?;
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct EditScores {
|
||||
pub line_match: Scores,
|
||||
pub chr_f: f64,
|
||||
}
|
||||
|
||||
impl EditScores {
|
||||
pub fn aggregate(scores: &[EditScores]) -> EditScores {
|
||||
let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
|
||||
let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
|
||||
|
||||
EditScores { line_match, chr_f }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct EvaluationResult {
|
||||
pub edit_scores: Option<EditScores>,
|
||||
pub context_scores: Scores,
|
||||
pub prompt_len: usize,
|
||||
pub generated_len: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EvaluationResult {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if f.alternate() {
|
||||
self.fmt_table(f)
|
||||
} else {
|
||||
self.fmt_markdown(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EvaluationResult {
|
||||
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
### Context Scores
|
||||
{}
|
||||
"#,
|
||||
self.context_scores.to_markdown(),
|
||||
)?;
|
||||
if let Some(scores) = &self.edit_scores {
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
### Edit Prediction Scores
|
||||
{}"#,
|
||||
scores.line_match.to_markdown()
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, "#### Prompt Statistics")?;
|
||||
writeln!(f, "─────────────────────────")?;
|
||||
writeln!(f, "Prompt_len Generated_len")?;
|
||||
writeln!(f, "─────────────────────────")?;
|
||||
writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
|
||||
writeln!(f)?;
|
||||
writeln!(f)?;
|
||||
writeln!(f, "#### Performance Scores")?;
|
||||
writeln!(
|
||||
f,
|
||||
"──────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
" TP FP FN Precision Recall F1"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"──────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
|
||||
self.context_scores.true_positives,
|
||||
self.context_scores.false_positives,
|
||||
self.context_scores.false_negatives,
|
||||
self.context_scores.precision() * 100.0,
|
||||
self.context_scores.recall() * 100.0,
|
||||
self.context_scores.f1_score() * 100.0
|
||||
)?;
|
||||
if let Some(edit_scores) = &self.edit_scores {
|
||||
let line_match = &edit_scores.line_match;
|
||||
writeln!(f, "Edit Prediction")?;
|
||||
writeln!(
|
||||
f,
|
||||
" ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
|
||||
line_match.true_positives,
|
||||
line_match.false_positives,
|
||||
line_match.false_negatives,
|
||||
line_match.precision() * 100.0,
|
||||
line_match.recall() * 100.0,
|
||||
line_match.f1_score() * 100.0
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
" └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
|
||||
"-", "-", "-", "-", "-", edit_scores.chr_f
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
|
||||
let mut eval_result = EvaluationResult {
|
||||
prompt_len: preds.prompt_len,
|
||||
generated_len: preds.generated_len,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if predict {
|
||||
// todo: alternatives for patches
|
||||
let expected_patch = example
|
||||
.expected_patch
|
||||
.lines()
|
||||
.map(DiffLine::parse)
|
||||
.collect::<Vec<_>>();
|
||||
let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
|
||||
|
||||
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
|
||||
let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
|
||||
|
||||
eval_result.edit_scores = Some(EditScores { line_match, chr_f });
|
||||
}
|
||||
|
||||
eval_result
|
||||
}
|
||||
|
||||
/// Return annotated `patch_a` so that:
|
||||
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
|
||||
/// Additions and deletions that are present in `patch_b` will be highlighted in green.
|
||||
pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
|
||||
let green = if use_color { "\x1b[32m✓ " } else { "" };
|
||||
let red = if use_color { "\x1b[31m✗ " } else { "" };
|
||||
let neutral = if use_color { " " } else { "" };
|
||||
let reset = if use_color { "\x1b[0m" } else { "" };
|
||||
let lines_a = patch_a.lines().map(DiffLine::parse);
|
||||
let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
|
||||
|
||||
let annotated = lines_a
|
||||
.map(|line| match line {
|
||||
DiffLine::Addition(_) | DiffLine::Deletion(_) => {
|
||||
if lines_b.contains(&line) {
|
||||
format!("{green}{line}{reset}")
|
||||
} else {
|
||||
format!("{red}{line}{reset}")
|
||||
}
|
||||
}
|
||||
_ => format!("{neutral}{line}{reset}"),
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
annotated.join("\n")
|
||||
}
|
||||
|
||||
fn write_bucketed_analysis(
|
||||
all_results: &Vec<
|
||||
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
|
||||
>,
|
||||
) -> Result<()> {
|
||||
#[derive(Debug)]
|
||||
struct EditBucket {
|
||||
diff: String,
|
||||
is_correct: bool,
|
||||
execution_indices: Vec<String>,
|
||||
reasoning_samples: Vec<String>,
|
||||
}
|
||||
|
||||
let mut total_executions = 0;
|
||||
let mut empty_predictions = Vec::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
let mut buckets: HashMap<String, EditBucket> = HashMap::new();
|
||||
|
||||
for result in all_results.iter().flatten() {
|
||||
total_executions += 1;
|
||||
|
||||
let (evaluation_result, execution_data) = match result {
|
||||
Ok((eval_result, execution_data)) => {
|
||||
if execution_data.diff.is_empty() {
|
||||
empty_predictions.push(execution_data);
|
||||
continue;
|
||||
}
|
||||
(eval_result, execution_data)
|
||||
}
|
||||
Err(err) => {
|
||||
errors.push(err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
buckets
|
||||
.entry(execution_data.diff.clone())
|
||||
.and_modify(|bucket| {
|
||||
bucket
|
||||
.execution_indices
|
||||
.push(execution_data.execution_id.clone());
|
||||
bucket
|
||||
.reasoning_samples
|
||||
.push(execution_data.reasoning.clone());
|
||||
})
|
||||
.or_insert_with(|| EditBucket {
|
||||
diff: execution_data.diff.clone(),
|
||||
is_correct: {
|
||||
evaluation_result
|
||||
.edit_scores
|
||||
.as_ref()
|
||||
.map_or(false, |edit_scores| {
|
||||
edit_scores.line_match.false_positives == 0
|
||||
&& edit_scores.line_match.false_negatives == 0
|
||||
&& edit_scores.line_match.true_positives > 0
|
||||
})
|
||||
},
|
||||
execution_indices: vec![execution_data.execution_id.clone()],
|
||||
reasoning_samples: vec![execution_data.reasoning.clone()],
|
||||
});
|
||||
}
|
||||
|
||||
let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
|
||||
sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
|
||||
(true, false) => std::cmp::Ordering::Less,
|
||||
(false, true) => std::cmp::Ordering::Greater,
|
||||
_ => b.execution_indices.len().cmp(&a.execution_indices.len()),
|
||||
});
|
||||
|
||||
let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
|
||||
let mut output = std::fs::File::create(&output_path)?;
|
||||
|
||||
writeln!(output, "# Bucketed Edit Analysis\n")?;
|
||||
|
||||
writeln!(output, "## Summary\n")?;
|
||||
writeln!(output, "- **Total executions**: {}", total_executions)?;
|
||||
|
||||
let correct_count: usize = sorted_buckets
|
||||
.iter()
|
||||
.filter(|b| b.is_correct)
|
||||
.map(|b| b.execution_indices.len())
|
||||
.sum();
|
||||
|
||||
let incorrect_count: usize = sorted_buckets
|
||||
.iter()
|
||||
.filter(|b| !b.is_correct)
|
||||
.map(|b| b.execution_indices.len())
|
||||
.sum();
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **Correct predictions**: {} ({:.1}%)",
|
||||
correct_count,
|
||||
(correct_count as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **Incorrect predictions**: {} ({:.1}%)",
|
||||
incorrect_count,
|
||||
(incorrect_count as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **No Predictions**: {} ({:.1}%)",
|
||||
empty_predictions.len(),
|
||||
(empty_predictions.len() as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
|
||||
writeln!(
|
||||
output,
|
||||
"- **Unique incorrect edit patterns**: {}\n",
|
||||
unique_incorrect
|
||||
)?;
|
||||
|
||||
writeln!(output, "---\n")?;
|
||||
|
||||
for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
|
||||
if idx == 0 {
|
||||
writeln!(
|
||||
output,
|
||||
"## Correct Predictions ({} occurrences)\n",
|
||||
bucket.execution_indices.len()
|
||||
)?;
|
||||
}
|
||||
|
||||
writeln!(output, "**Predicted Edit:**\n")?;
|
||||
writeln!(output, "```diff")?;
|
||||
writeln!(output, "{}", bucket.diff)?;
|
||||
writeln!(output, "```\n")?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"**Executions:** {}\n",
|
||||
bucket.execution_indices.join(", ")
|
||||
)?;
|
||||
writeln!(output, "---\n")?;
|
||||
}
|
||||
|
||||
for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
|
||||
writeln!(
|
||||
output,
|
||||
"## Incorrect Prediction #{} ({} occurrences)\n",
|
||||
idx + 1,
|
||||
bucket.execution_indices.len()
|
||||
)?;
|
||||
|
||||
writeln!(output, "**Predicted Edit:**\n")?;
|
||||
writeln!(output, "```diff")?;
|
||||
writeln!(output, "{}", bucket.diff)?;
|
||||
writeln!(output, "```\n")?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"**Executions:** {}\n",
|
||||
bucket.execution_indices.join(", ")
|
||||
)?;
|
||||
|
||||
for (exec_id, reasoning) in bucket
|
||||
.execution_indices
|
||||
.iter()
|
||||
.zip(bucket.reasoning_samples.iter())
|
||||
{
|
||||
writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
|
||||
}
|
||||
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
if !empty_predictions.is_empty() {
|
||||
writeln!(
|
||||
output,
|
||||
"## No Predictions ({} occurrences)\n",
|
||||
empty_predictions.len()
|
||||
)?;
|
||||
|
||||
for execution_data in &empty_predictions {
|
||||
writeln!(
|
||||
output,
|
||||
"{}",
|
||||
fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
|
||||
)?;
|
||||
}
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
|
||||
|
||||
for (err, name, repetition_ix) in &errors {
|
||||
writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
|
||||
}
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
|
||||
let exec_content = format!(
|
||||
"\n### Execution {} `{}/{}/prediction_response.md`{}",
|
||||
exec_id,
|
||||
crate::paths::RUN_DIR.display(),
|
||||
exec_id,
|
||||
indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
|
||||
);
|
||||
indent_text(&exec_content, 2)
|
||||
}
|
||||
|
||||
fn indent_text(text: &str, spaces: usize) -> String {
|
||||
let indent = " ".repeat(spaces);
|
||||
text.lines()
|
||||
.collect::<Vec<_>>()
|
||||
.join(&format!("\n{}", indent))
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
|
||||
let err = format!("{err:?}")
|
||||
.replace("<edits", "```xml\n<edits")
|
||||
.replace("</edits>", "</edits>\n```");
|
||||
format!(
|
||||
"### ERROR {name}{}\n\n{err}\n",
|
||||
repetition_ix
|
||||
.map(|ix| format!(" [RUN {ix:03}]"))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
@@ -1,59 +1,103 @@
|
||||
use crate::{
|
||||
PredictionProvider, PromptFormat,
|
||||
metrics::ClassificationMetrics,
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use gpui::Entity;
|
||||
use http_client::Url;
|
||||
use language::{Anchor, Buffer};
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cell::RefCell,
|
||||
fmt::{self, Display},
|
||||
fs,
|
||||
hash::Hash,
|
||||
hash::Hasher,
|
||||
io::Write,
|
||||
io::{Read, Write},
|
||||
mem,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, OnceLock},
|
||||
};
|
||||
use zeta_prompt::RelatedFile;
|
||||
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use clap::ValueEnum;
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use collections::HashMap;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
use gpui::{AsyncApp, Entity, Task, http_client::Url};
|
||||
use language::{Anchor, Buffer};
|
||||
use project::{Project, ProjectPath};
|
||||
use pulldown_cmark::CowStr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
|
||||
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
|
||||
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NamedExample {
|
||||
pub name: String,
|
||||
pub example: Example,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Example {
|
||||
#[serde(default)]
|
||||
pub name: String,
|
||||
pub repository_url: String,
|
||||
pub revision: String,
|
||||
pub uncommitted_diff: String,
|
||||
pub cursor_path: PathBuf,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub cursor_position: String,
|
||||
pub edit_history: String,
|
||||
pub expected_patch: String,
|
||||
|
||||
/// The full content of the file where an edit is being predicted, and the
|
||||
/// actual cursor offset.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub buffer: Option<ExampleBuffer>,
|
||||
|
||||
/// The context retrieved for the prediction. This requires the worktree to
|
||||
/// be loaded and the language server to be started.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<ExampleContext>,
|
||||
|
||||
/// The input and expected output from the edit prediction model.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<ExamplePrompt>,
|
||||
|
||||
/// The actual predictions from the model.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub predictions: Vec<ExamplePrediction>,
|
||||
|
||||
/// The scores, for how well the actual predictions match the expected
|
||||
/// predictions.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub score: Vec<ExampleScore>,
|
||||
|
||||
/// The application state used to process this example.
|
||||
#[serde(skip)]
|
||||
pub state: Option<ExampleState>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ExampleState {
|
||||
pub project: Entity<Project>,
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub cursor_position: Anchor,
|
||||
pub _open_buffers: OpenedBuffers,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleContext {
|
||||
pub files: Arc<[RelatedFile]>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleBuffer {
|
||||
pub content: String,
|
||||
pub cursor_row: u32,
|
||||
pub cursor_column: u32,
|
||||
pub cursor_offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExamplePrompt {
|
||||
pub input: String,
|
||||
pub expected_output: String,
|
||||
pub format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExamplePrediction {
|
||||
pub actual_patch: String,
|
||||
pub actual_output: String,
|
||||
pub provider: PredictionProvider,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleScore {
|
||||
pub delta_chr_f: f32,
|
||||
pub line_match: ClassificationMetrics,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
@@ -90,485 +134,244 @@ impl Example {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = self.repo_name()?;
|
||||
pub fn worktree_path(&self) -> PathBuf {
|
||||
WORKTREES_DIR
|
||||
.join(&self.name)
|
||||
.join(self.repo_name().unwrap().1.as_ref())
|
||||
}
|
||||
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
pub fn repo_path(&self) -> PathBuf {
|
||||
let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
|
||||
REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &self.repository_url],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
let mut examples = Vec::new();
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&["rev-parse", &format!("{}^{{commit}}", self.revision)],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
let stdin_path: PathBuf = PathBuf::from("-");
|
||||
|
||||
let inputs = if inputs.is_empty() {
|
||||
&[stdin_path]
|
||||
} else {
|
||||
inputs
|
||||
};
|
||||
|
||||
for path in inputs {
|
||||
let is_stdin = path.as_path() == Path::new("-");
|
||||
let content = if is_stdin {
|
||||
let mut buffer = String::new();
|
||||
std::io::stdin()
|
||||
.read_to_string(&mut buffer)
|
||||
.expect("Failed to read from stdin");
|
||||
buffer
|
||||
} else {
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &self.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
if revision != self.revision {
|
||||
run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
|
||||
}
|
||||
revision
|
||||
std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
|
||||
};
|
||||
let filename = path.file_stem().unwrap().to_string_lossy().to_string();
|
||||
let ext = if !is_stdin {
|
||||
path.extension()
|
||||
.map(|ext| ext.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| panic!("{} should have an extension", path.display()))
|
||||
} else {
|
||||
"jsonl".to_string()
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["worktree", "add", "-f", &worktree_path_string, &file_name],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !self.uncommitted_diff.is_empty() {
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await?;
|
||||
if !apply_result.status.success() {
|
||||
anyhow::bail!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
match ext.as_ref() {
|
||||
"json" => {
|
||||
let mut example =
|
||||
serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
|
||||
panic!("Failed to parse example file: {}\n{error}", path.display())
|
||||
});
|
||||
if example.name.is_empty() {
|
||||
example.name = filename;
|
||||
}
|
||||
examples.push(example);
|
||||
}
|
||||
"jsonl" => examples.extend(
|
||||
content
|
||||
.lines()
|
||||
.enumerate()
|
||||
.map(|(line_ix, line)| {
|
||||
let mut example =
|
||||
serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"Failed to parse example on {}:{}",
|
||||
path.display(),
|
||||
line_ix + 1
|
||||
)
|
||||
});
|
||||
if example.name.is_empty() {
|
||||
example.name = format!("{filename}-{line_ix}")
|
||||
}
|
||||
example
|
||||
})
|
||||
.collect::<Vec<Example>>(),
|
||||
),
|
||||
"md" => {
|
||||
examples.push(parse_markdown_example(filename, &content).unwrap());
|
||||
}
|
||||
ext => {
|
||||
panic!("{} has invalid example extension `{ext}`", path.display())
|
||||
}
|
||||
}
|
||||
|
||||
Ok(worktree_path)
|
||||
}
|
||||
examples
|
||||
}
|
||||
|
||||
pub fn unique_name(&self) -> String {
|
||||
let mut hasher = std::hash::DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
let disambiguator = hasher.finish();
|
||||
let hash = format!("{:04x}", disambiguator);
|
||||
format!("{}_{}", &self.revision[..8], &hash[..4])
|
||||
pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
|
||||
let mut content = String::new();
|
||||
for example in examples {
|
||||
let line = serde_json::to_string(example).unwrap();
|
||||
content.push_str(&line);
|
||||
content.push('\n');
|
||||
}
|
||||
if let Some(output_path) = output_path {
|
||||
std::fs::write(output_path, content).expect("Failed to write examples");
|
||||
} else {
|
||||
std::io::stdout().write_all(&content.as_bytes()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub type ActualExcerpt = Excerpt;
|
||||
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Excerpt {
|
||||
pub path: PathBuf,
|
||||
pub text: String,
|
||||
}
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
pub enum ExampleFormat {
|
||||
Json,
|
||||
Toml,
|
||||
Md,
|
||||
}
|
||||
let parser = Parser::new(input);
|
||||
|
||||
impl NamedExample {
|
||||
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let ext = path.extension();
|
||||
let mut example = Example {
|
||||
name: id,
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: PathBuf::new().into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
buffer: None,
|
||||
context: None,
|
||||
prompt: None,
|
||||
predictions: Vec::new(),
|
||||
score: Vec::new(),
|
||||
state: None,
|
||||
};
|
||||
|
||||
match ext.and_then(|s| s.to_str()) {
|
||||
Some("json") => Ok(Self {
|
||||
name: path.file_stem().unwrap_or_default().display().to_string(),
|
||||
example: serde_json::from_str(&content)?,
|
||||
}),
|
||||
Some("toml") => Ok(Self {
|
||||
name: path.file_stem().unwrap_or_default().display().to_string(),
|
||||
example: toml::from_str(&content)?,
|
||||
}),
|
||||
Some("md") => Self::parse_md(&content),
|
||||
Some(_) => {
|
||||
anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
|
||||
}
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"Failed to determine example type since the file does not have an extension."
|
||||
);
|
||||
}
|
||||
}
|
||||
let mut name = String::new();
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Section {
|
||||
UncommittedDiff,
|
||||
EditHistory,
|
||||
CursorPosition,
|
||||
ExpectedExcerpts,
|
||||
ExpectedPatch,
|
||||
Other,
|
||||
}
|
||||
|
||||
pub fn parse_md(input: &str) -> Result<Self> {
|
||||
use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
let mut current_section = Section::Other;
|
||||
|
||||
let parser = Parser::new(input);
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
let mut named = NamedExample {
|
||||
name: String::new(),
|
||||
example: Example {
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: PathBuf::new(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
},
|
||||
};
|
||||
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Section {
|
||||
UncommittedDiff,
|
||||
EditHistory,
|
||||
CursorPosition,
|
||||
ExpectedExcerpts,
|
||||
ExpectedPatch,
|
||||
Other,
|
||||
}
|
||||
|
||||
let mut current_section = Section::Other;
|
||||
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
if !named.name.is_empty()
|
||||
&& current_section == Section::Other
|
||||
// in h1 section
|
||||
&& let Some((field, value)) = line.split_once('=')
|
||||
{
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
named.example.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
named.example.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
if let Some((field, value)) = line.split_once('=') {
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
example.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
example.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
if !named.name.is_empty() {
|
||||
anyhow::bail!(
|
||||
"Found multiple H1 headings. There should only be one with the name of the example."
|
||||
);
|
||||
}
|
||||
named.name = mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
|
||||
Section::UncommittedDiff
|
||||
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
|
||||
Section::EditHistory
|
||||
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
|
||||
Section::CursorPosition
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
|
||||
Section::ExpectedPatch
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
|
||||
Section::ExpectedExcerpts
|
||||
} else {
|
||||
Section::Other
|
||||
};
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(level)) => {
|
||||
anyhow::bail!("Unexpected heading level: {level}");
|
||||
}
|
||||
Event::Start(Tag::CodeBlock(kind)) => {
|
||||
match kind {
|
||||
CodeBlockKind::Fenced(info) => {
|
||||
block_info = info;
|
||||
}
|
||||
CodeBlockKind::Indented => {
|
||||
anyhow::bail!("Unexpected indented codeblock");
|
||||
}
|
||||
};
|
||||
}
|
||||
Event::Start(_) => {
|
||||
text.clear();
|
||||
block_info = "".into();
|
||||
}
|
||||
Event::End(TagEnd::CodeBlock) => {
|
||||
let block_info = block_info.trim();
|
||||
match current_section {
|
||||
Section::UncommittedDiff => {
|
||||
named.example.uncommitted_diff = mem::take(&mut text);
|
||||
}
|
||||
Section::EditHistory => {
|
||||
named.example.edit_history.push_str(&mem::take(&mut text));
|
||||
}
|
||||
Section::CursorPosition => {
|
||||
named.example.cursor_path = block_info.into();
|
||||
named.example.cursor_position = mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedExcerpts => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
named.example.expected_patch = mem::take(&mut text);
|
||||
}
|
||||
Section::Other => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if named.example.cursor_path.as_path() == Path::new("")
|
||||
|| named.example.cursor_position.is_empty()
|
||||
{
|
||||
anyhow::bail!("Missing cursor position codeblock");
|
||||
}
|
||||
|
||||
Ok(named)
|
||||
}
|
||||
|
||||
pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
|
||||
match format {
|
||||
ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
|
||||
ExampleFormat::Toml => {
|
||||
Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
if !name.is_empty() {
|
||||
anyhow::bail!(
|
||||
"Found multiple H1 headings. There should only be one with the name of the example."
|
||||
);
|
||||
}
|
||||
name = mem::take(&mut text);
|
||||
}
|
||||
ExampleFormat::Md => Ok(write!(out, "{}", self)?),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn setup_project(
|
||||
&self,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Project>> {
|
||||
let worktree_path = self.setup_worktree().await?;
|
||||
|
||||
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
|
||||
|
||||
AUTHENTICATED
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
.clone()
|
||||
.await;
|
||||
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})?
|
||||
.await;
|
||||
|
||||
anyhow::Ok(project)
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self) -> Result<PathBuf> {
|
||||
self.example.setup_worktree(self.file_name()).await
|
||||
}
|
||||
|
||||
pub fn file_name(&self) -> String {
|
||||
self.name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_whitespace() {
|
||||
'-'
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
|
||||
Section::UncommittedDiff
|
||||
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
|
||||
Section::EditHistory
|
||||
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
|
||||
Section::CursorPosition
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
|
||||
Section::ExpectedPatch
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
|
||||
Section::ExpectedExcerpts
|
||||
} else {
|
||||
c.to_ascii_lowercase()
|
||||
Section::Other
|
||||
};
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(level)) => {
|
||||
anyhow::bail!("Unexpected heading level: {level}");
|
||||
}
|
||||
Event::Start(Tag::CodeBlock(kind)) => {
|
||||
match kind {
|
||||
CodeBlockKind::Fenced(info) => {
|
||||
block_info = info;
|
||||
}
|
||||
CodeBlockKind::Indented => {
|
||||
anyhow::bail!("Unexpected indented codeblock");
|
||||
}
|
||||
};
|
||||
}
|
||||
Event::Start(_) => {
|
||||
text.clear();
|
||||
block_info = "".into();
|
||||
}
|
||||
Event::End(TagEnd::CodeBlock) => {
|
||||
let block_info = block_info.trim();
|
||||
match current_section {
|
||||
Section::UncommittedDiff => {
|
||||
example.uncommitted_diff = mem::take(&mut text);
|
||||
}
|
||||
Section::EditHistory => {
|
||||
example.edit_history.push_str(&mem::take(&mut text));
|
||||
}
|
||||
Section::CursorPosition => {
|
||||
example.cursor_path = Path::new(block_info).into();
|
||||
example.cursor_position = mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedExcerpts => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
example.expected_patch = mem::take(&mut text);
|
||||
}
|
||||
Section::Other => {}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn cursor_position(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})?;
|
||||
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = self
|
||||
.example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))?;
|
||||
let mut cursor_excerpt = self.example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let Some((excerpt_offset, _)) = matches.next() else {
|
||||
anyhow::bail!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
|
||||
);
|
||||
};
|
||||
assert!(matches.next().is_none());
|
||||
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn apply_edit_history(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers<'_>> {
|
||||
edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
|
||||
impl Display for NamedExample {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "# {}\n\n", self.name)?;
|
||||
write!(
|
||||
f,
|
||||
"{REPOSITORY_URL_FIELD} = {}\n",
|
||||
self.example.repository_url
|
||||
)?;
|
||||
write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
|
||||
|
||||
write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
|
||||
write!(f, "`````diff\n")?;
|
||||
write!(f, "{}", self.example.uncommitted_diff)?;
|
||||
write!(f, "`````\n")?;
|
||||
|
||||
if !self.example.edit_history.is_empty() {
|
||||
write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
write!(
|
||||
f,
|
||||
"## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
|
||||
self.example.cursor_path.display(),
|
||||
self.example.cursor_position
|
||||
)?;
|
||||
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
|
||||
|
||||
if !self.example.expected_patch.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
|
||||
self.example.expected_patch
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
|
||||
anyhow::bail!("Missing cursor position codeblock");
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
Ok(example)
|
||||
}
|
||||
|
||||
280
crates/edit_prediction_cli/src/format_prompt.rs
Normal file
280
crates/edit_prediction_cli/src/format_prompt.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
use crate::{
|
||||
PromptFormat,
|
||||
example::{Example, ExamplePrompt},
|
||||
headless::EpAppState,
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
|
||||
pub async fn run_format_prompt(
|
||||
example: &mut Example,
|
||||
prompt_format: PromptFormat,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
run_context_retrieval(example, app_state, cx.clone()).await;
|
||||
|
||||
let prompt = match prompt_format {
|
||||
PromptFormat::Teacher => TeacherPrompt::format(example),
|
||||
PromptFormat::Zeta2 => {
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let snapshot = state
|
||||
.buffer
|
||||
.read_with(&cx, |buffer, _| buffer.snapshot())
|
||||
.unwrap();
|
||||
let project = state.project.clone();
|
||||
let (_, input) = ep_store
|
||||
.update(&mut cx, |ep_store, _cx| {
|
||||
zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example.context.as_ref().unwrap().files.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example.buffer.as_ref().unwrap().cursor_offset,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
format_zeta_prompt(&input)
|
||||
}
|
||||
};
|
||||
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: example.expected_patch.clone(), // TODO
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
|
||||
pub trait PromptFormatter {
|
||||
fn format(example: &Example) -> String;
|
||||
}
|
||||
|
||||
pub trait PromptParser {
|
||||
/// Return unified diff patch of prediction given raw LLM response
|
||||
fn parse(example: &Example, response: &str) -> String;
|
||||
}
|
||||
|
||||
pub struct TeacherPrompt;
|
||||
|
||||
impl PromptFormatter for TeacherPrompt {
|
||||
fn format(example: &Example) -> String {
|
||||
let edit_history = Self::format_edit_history(&example.edit_history);
|
||||
let context = Self::format_context(example);
|
||||
let editable_region = Self::format_editable_region(example);
|
||||
|
||||
let prompt = Self::PROMPT
|
||||
.replace("{{context}}", &context)
|
||||
.replace("{{edit_history}}", &edit_history)
|
||||
.replace("{{editable_region}}", &editable_region);
|
||||
|
||||
prompt
|
||||
}
|
||||
}
|
||||
|
||||
impl TeacherPrompt {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
// Strip comments ("garbage lines") from edit history
|
||||
let lines = edit_history
|
||||
.lines()
|
||||
.filter(|&s| Self::is_udiff_content_line(s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
|
||||
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
|
||||
if history_lines.is_empty() {
|
||||
return "(No edit history)".to_string();
|
||||
}
|
||||
|
||||
history_lines.join("\n")
|
||||
}
|
||||
|
||||
fn format_context(example: &Example) -> String {
|
||||
if example.context.is_none() {
|
||||
panic!("Missing context retriever step");
|
||||
}
|
||||
|
||||
let mut prompt = String::new();
|
||||
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
fn format_editable_region(example: &Example) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
let path_str = example.cursor_path.to_string_lossy();
|
||||
result.push_str(&format!("`````path=\"{path_str}\"\n"));
|
||||
result.push_str(Self::EDITABLE_REGION_START);
|
||||
|
||||
// TODO: control number of lines around cursor
|
||||
result.push_str(&example.cursor_position);
|
||||
if !example.cursor_position.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
|
||||
result.push_str("`````");
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn extract_editable_region(text: &str) -> String {
|
||||
let start = text
|
||||
.find(Self::EDITABLE_REGION_START)
|
||||
.map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
|
||||
let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
|
||||
|
||||
let region = &text[start..end];
|
||||
|
||||
region.replace("<|user_cursor|>", "")
|
||||
}
|
||||
|
||||
fn is_udiff_content_line(s: &str) -> bool {
|
||||
s.starts_with("-")
|
||||
|| s.starts_with("+")
|
||||
|| s.starts_with(" ")
|
||||
|| s.starts_with("---")
|
||||
|| s.starts_with("+++")
|
||||
|| s.starts_with("@@")
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptParser for TeacherPrompt {
|
||||
fn parse(example: &Example, response: &str) -> String {
|
||||
// Ideally, we should always be able to find cursor position in the retrieved context.
|
||||
// In reality, sometimes we don't find it for these reasons:
|
||||
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
|
||||
// (can be fixed by getting cursor coordinates at the load_example stage)
|
||||
// 2. Context retriever just didn't include cursor line.
|
||||
//
|
||||
// In that case, fallback to using `cursor_position` as excerpt.
|
||||
let cursor_file = &example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.expect("`buffer` should be filled in in the context collection step")
|
||||
.content;
|
||||
|
||||
// Extract updated (new) editable region from the model response
|
||||
let new_editable_region = extract_last_codeblock(response);
|
||||
|
||||
// Reconstruct old editable region we sent to the model
|
||||
let old_editable_region = Self::format_editable_region(example);
|
||||
let old_editable_region = Self::extract_editable_region(&old_editable_region);
|
||||
if !cursor_file.contains(&old_editable_region) {
|
||||
panic!("Something's wrong: editable_region is not found in the cursor file")
|
||||
}
|
||||
|
||||
// Apply editable region to a larger context and compute diff.
|
||||
// This is needed to get a better context lines around the editable region
|
||||
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
|
||||
let diff = language::unified_diff(&cursor_file, &edited_file);
|
||||
|
||||
let diff = indoc::formatdoc! {"
|
||||
--- a/{path}
|
||||
+++ b/{path}
|
||||
{diff}
|
||||
",
|
||||
path = example.cursor_path.to_string_lossy(),
|
||||
diff = diff,
|
||||
};
|
||||
|
||||
diff
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_last_codeblock(text: &str) -> String {
|
||||
let mut last_block = None;
|
||||
let mut search_start = 0;
|
||||
|
||||
while let Some(start) = text[search_start..].find("```") {
|
||||
let start = start + search_start;
|
||||
let bytes = text.as_bytes();
|
||||
let mut backtick_end = start;
|
||||
|
||||
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
|
||||
backtick_end += 1;
|
||||
}
|
||||
|
||||
let backtick_count = backtick_end - start;
|
||||
let closing_backticks = "`".repeat(backtick_count);
|
||||
|
||||
while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
|
||||
backtick_end += 1;
|
||||
}
|
||||
|
||||
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
|
||||
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
|
||||
last_block = Some(code_block.to_string());
|
||||
search_start = backtick_end + end_pos + backtick_count;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
last_block.unwrap_or_else(|| text.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_last_code_block() {
|
||||
let text = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
```
|
||||
first block
|
||||
```
|
||||
|
||||
`````path='something' lines=1:2
|
||||
last block
|
||||
`````
|
||||
"};
|
||||
let last_block = extract_last_codeblock(text);
|
||||
assert_eq!(last_block, "last block");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_editable_region() {
|
||||
let text = indoc::indoc! {"
|
||||
some lines
|
||||
are
|
||||
here
|
||||
<|editable_region_start|>
|
||||
one
|
||||
two three
|
||||
|
||||
<|editable_region_end|>
|
||||
more
|
||||
lines here
|
||||
"};
|
||||
let parsed = TeacherPrompt::extract_editable_region(text);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
indoc::indoc! {"
|
||||
one
|
||||
two three
|
||||
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,7 @@ use std::sync::Arc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
/// Headless subset of `workspace::AppState`.
|
||||
pub struct ZetaCliAppState {
|
||||
pub struct EpAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
@@ -25,7 +25,7 @@ pub struct ZetaCliAppState {
|
||||
}
|
||||
|
||||
// TODO: dedupe with crates/eval/src/eval.rs
|
||||
pub fn init(cx: &mut App) -> ZetaCliAppState {
|
||||
pub fn init(cx: &mut App) -> EpAppState {
|
||||
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
|
||||
|
||||
let app_version = AppVersion::load(
|
||||
@@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
|
||||
prompt_store::init(cx);
|
||||
terminal_view::init(cx);
|
||||
|
||||
ZetaCliAppState {
|
||||
EpAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
|
||||
320
crates/edit_prediction_cli/src/load_project.rs
Normal file
320
crates/edit_prediction_cli/src/load_project.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use crate::{
|
||||
example::{Example, ExampleBuffer, ExampleState},
|
||||
headless::EpAppState,
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, ToOffset, ToPoint};
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
|
||||
pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
|
||||
if example.state.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let project = setup_project(example, &app_state, &mut cx).await;
|
||||
let buffer_store = project
|
||||
.read_with(&cx, |project, _| project.buffer_store().clone())
|
||||
.unwrap();
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
.detach();
|
||||
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
|
||||
example.buffer = buffer
|
||||
.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
Some(ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
example.state = Some(ExampleState {
|
||||
buffer,
|
||||
project,
|
||||
cursor_position,
|
||||
_open_buffers,
|
||||
});
|
||||
}
|
||||
|
||||
async fn cursor_position(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (Entity<Buffer>, Anchor) {
|
||||
let worktree = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
|
||||
.unwrap()
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
let cursor_offset_within_excerpt = example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))
|
||||
.unwrap();
|
||||
let mut cursor_excerpt = example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
|
||||
panic!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
|
||||
);
|
||||
});
|
||||
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
excerpt_offset
|
||||
}).unwrap();
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor = cursor_buffer
|
||||
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
|
||||
.unwrap();
|
||||
|
||||
(cursor_buffer, cursor_anchor)
|
||||
}
|
||||
|
||||
async fn setup_project(
|
||||
example: &mut Example,
|
||||
app_state: &Arc<EpAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Entity<Project> {
|
||||
setup_worktree(example).await;
|
||||
|
||||
let project = cx
|
||||
.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&example.worktree_path(), true, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})
|
||||
.unwrap()
|
||||
.await;
|
||||
project
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(example: &Example) {
|
||||
let repo_dir = example.repo_path();
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
fs::create_dir_all(&repo_dir).unwrap();
|
||||
run_git(&repo_dir, &["init"]).await.unwrap();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.repository_url],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&["rev-parse", &format!("{}^{{commit}}", example.revision)],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &example.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
if revision != example.revision {
|
||||
run_git(&repo_dir, &["tag", &example.revision, &revision])
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = example.worktree_path();
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()])
|
||||
.await
|
||||
.unwrap();
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.name, revision.as_str()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"worktree",
|
||||
"add",
|
||||
"-f",
|
||||
&worktree_path_string,
|
||||
&example.name,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !example.uncommitted_diff.is_empty() {
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.unwrap();
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin
|
||||
.write_all(example.uncommitted_diff.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
stdin.close().await.unwrap();
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await.unwrap();
|
||||
if !apply_result.status.success() {
|
||||
panic!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_edit_history(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers> {
|
||||
edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
@@ -1,522 +1,196 @@
|
||||
mod evaluate;
|
||||
mod anthropic_client;
|
||||
mod example;
|
||||
mod format_prompt;
|
||||
mod headless;
|
||||
mod load_project;
|
||||
mod metrics;
|
||||
mod paths;
|
||||
mod predict;
|
||||
mod source_location;
|
||||
mod training;
|
||||
mod util;
|
||||
mod retrieve_context;
|
||||
mod score;
|
||||
|
||||
use crate::{
|
||||
evaluate::run_evaluate,
|
||||
example::{ExampleFormat, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
predict::run_predict,
|
||||
source_location::SourceLocation,
|
||||
training::{context::ContextType, distill::run_distill},
|
||||
util::{open_buffer, open_buffer_with_language_server},
|
||||
};
|
||||
use ::util::{ResultExt, paths::PathStyle};
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Parser, Subcommand, ValueEnum};
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use edit_prediction_context::EditPredictionExcerptOptions;
|
||||
use gpui::{Application, AsyncApp, Entity, prelude::*};
|
||||
use language::{Bias, Buffer, BufferSnapshot, Point};
|
||||
use metrics::delta_chr_f;
|
||||
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
|
||||
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use gpui::Application;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use std::io::{self};
|
||||
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use crate::example::{read_examples, write_examples};
|
||||
use crate::format_prompt::run_format_prompt;
|
||||
use crate::load_project::run_load_project;
|
||||
use crate::predict::run_prediction;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
use crate::score::run_scoring;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zeta")]
|
||||
struct ZetaCliArgs {
|
||||
#[command(name = "ep")]
|
||||
struct EpArgs {
|
||||
#[arg(long, default_value_t = false)]
|
||||
printenv: bool,
|
||||
#[clap(long, default_value_t = 10)]
|
||||
max_parallelism: usize,
|
||||
#[command(subcommand)]
|
||||
command: Option<Command>,
|
||||
#[clap(global = true)]
|
||||
inputs: Vec<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
output: Option<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
in_place: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Command {
|
||||
Context(ContextArgs),
|
||||
Predict(PredictArguments),
|
||||
Eval(EvaluateArguments),
|
||||
Distill(DistillArguments),
|
||||
ConvertExample {
|
||||
path: PathBuf,
|
||||
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
|
||||
output_format: ExampleFormat,
|
||||
},
|
||||
Score {
|
||||
golden_patch: PathBuf,
|
||||
actual_patch: PathBuf,
|
||||
},
|
||||
/// Parse markdown examples and output a combined .jsonl file
|
||||
ParseExample,
|
||||
/// Create git worktrees for each example and load file contents
|
||||
LoadBuffer,
|
||||
/// Retrieve context for input examples.
|
||||
Context,
|
||||
/// Generate a prompt string for a specific model
|
||||
FormatPrompt(FormatPromptArgs),
|
||||
/// Runs edit prediction
|
||||
Predict(PredictArgs),
|
||||
/// Computes a score based on actual and expected patches
|
||||
Score(PredictArgs),
|
||||
/// Print aggregated scores
|
||||
Eval(PredictArgs),
|
||||
/// Remove git repositories and worktrees
|
||||
Clean,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct ContextArgs {
|
||||
#[arg(long)]
|
||||
provider: ContextProvider,
|
||||
#[arg(long)]
|
||||
worktree: PathBuf,
|
||||
#[arg(long)]
|
||||
cursor: SourceLocation,
|
||||
#[arg(long)]
|
||||
use_language_server: bool,
|
||||
#[arg(long)]
|
||||
edit_history: Option<FileOrStdin>,
|
||||
#[clap(flatten)]
|
||||
zeta2_args: Zeta2Args,
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long)]
|
||||
prompt_format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
|
||||
enum ContextProvider {
|
||||
Zeta1,
|
||||
#[default]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
enum PromptFormat {
|
||||
Teacher,
|
||||
Zeta2,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
struct Zeta2Args {
|
||||
#[arg(long, default_value_t = 8192)]
|
||||
max_prompt_bytes: usize,
|
||||
#[arg(long, default_value_t = 2048)]
|
||||
max_excerpt_bytes: usize,
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
min_excerpt_bytes: usize,
|
||||
#[arg(long, default_value_t = 0.66)]
|
||||
target_before_cursor_over_total_bytes: f32,
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
max_diagnostic_bytes: usize,
|
||||
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
|
||||
prompt_format: PromptFormat,
|
||||
#[arg(long, value_enum, default_value_t = Default::default())]
|
||||
output_format: OutputFormat,
|
||||
#[arg(long, default_value_t = 42)]
|
||||
file_indexing_parallelism: usize,
|
||||
#[arg(long, default_value_t = false)]
|
||||
disable_imports_gathering: bool,
|
||||
#[arg(long, default_value_t = u8::MAX)]
|
||||
max_retrieved_definitions: u8,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct PredictArguments {
|
||||
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
|
||||
format: PredictionsOutputFormat,
|
||||
example_path: PathBuf,
|
||||
#[clap(flatten)]
|
||||
options: PredictionOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct DistillArguments {
|
||||
split_commit_dataset: PathBuf,
|
||||
#[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
|
||||
context_type: ContextType,
|
||||
#[clap(long)]
|
||||
batch: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
pub struct PredictionOptions {
|
||||
#[clap(flatten)]
|
||||
zeta2: Zeta2Args,
|
||||
struct PredictArgs {
|
||||
#[clap(long)]
|
||||
provider: PredictionProvider,
|
||||
#[clap(long, value_enum, default_value_t = CacheMode::default())]
|
||||
cache: CacheMode,
|
||||
#[clap(long, default_value_t = 1)]
|
||||
repetitions: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
|
||||
pub enum CacheMode {
|
||||
/// Use cached LLM requests and responses, except when multiple repetitions are requested
|
||||
#[default]
|
||||
Auto,
|
||||
/// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
|
||||
#[value(alias = "request")]
|
||||
Requests,
|
||||
/// Ignore existing cache entries for both LLM and search.
|
||||
Skip,
|
||||
/// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
|
||||
/// Useful for reproducing results and fixing bugs outside of search queries
|
||||
Force,
|
||||
}
|
||||
|
||||
impl CacheMode {
|
||||
fn use_cached_llm_responses(&self) -> bool {
|
||||
self.assert_not_auto();
|
||||
matches!(self, CacheMode::Requests | CacheMode::Force)
|
||||
}
|
||||
|
||||
fn use_cached_search_results(&self) -> bool {
|
||||
self.assert_not_auto();
|
||||
matches!(self, CacheMode::Force)
|
||||
}
|
||||
|
||||
fn assert_not_auto(&self) {
|
||||
assert_ne!(
|
||||
*self,
|
||||
CacheMode::Auto,
|
||||
"Cache mode should not be auto at this point!"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone)]
|
||||
pub enum PredictionsOutputFormat {
|
||||
Json,
|
||||
Md,
|
||||
Diff,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct EvaluateArguments {
|
||||
example_paths: Vec<PathBuf>,
|
||||
#[clap(flatten)]
|
||||
options: PredictionOptions,
|
||||
#[clap(short, long, default_value_t = 1, alias = "repeat")]
|
||||
repetitions: u16,
|
||||
#[arg(long)]
|
||||
skip_prediction: bool,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
enum PredictionProvider {
|
||||
Zeta1,
|
||||
#[default]
|
||||
Zeta2,
|
||||
Sweep,
|
||||
Mercury,
|
||||
Zeta1,
|
||||
Zeta2,
|
||||
Teacher,
|
||||
}
|
||||
|
||||
fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
|
||||
edit_prediction::ZetaOptions {
|
||||
context: EditPredictionExcerptOptions {
|
||||
max_bytes: args.max_excerpt_bytes,
|
||||
min_bytes: args.min_excerpt_bytes,
|
||||
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
|
||||
},
|
||||
max_prompt_bytes: args.max_prompt_bytes,
|
||||
prompt_format: args.prompt_format.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
|
||||
enum PromptFormat {
|
||||
OnlySnippets,
|
||||
#[default]
|
||||
OldTextNewText,
|
||||
Minimal,
|
||||
MinimalQwen,
|
||||
SeedCoder1120,
|
||||
}
|
||||
|
||||
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
|
||||
fn into(self) -> predict_edits_v3::PromptFormat {
|
||||
match self {
|
||||
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
|
||||
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
|
||||
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
|
||||
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
|
||||
Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone)]
|
||||
enum OutputFormat {
|
||||
#[default]
|
||||
Prompt,
|
||||
Request,
|
||||
Full,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum FileOrStdin {
|
||||
File(PathBuf),
|
||||
Stdin,
|
||||
}
|
||||
|
||||
impl FileOrStdin {
|
||||
async fn read_to_string(&self) -> Result<String, std::io::Error> {
|
||||
match self {
|
||||
FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
|
||||
FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for FileOrStdin {
|
||||
type Err = <PathBuf as FromStr>::Err;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"-" => Ok(Self::Stdin),
|
||||
_ => Ok(Self::File(PathBuf::from_str(s)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct LoadedContext {
|
||||
full_path_str: String,
|
||||
snapshot: BufferSnapshot,
|
||||
clipped_cursor: Point,
|
||||
worktree: Entity<Worktree>,
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
lsp_open_handle: Option<OpenLspBufferHandle>,
|
||||
}
|
||||
|
||||
async fn load_context(
|
||||
args: &ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<LoadedContext> {
|
||||
let ContextArgs {
|
||||
worktree: worktree_path,
|
||||
cursor,
|
||||
use_language_server,
|
||||
..
|
||||
} = args;
|
||||
|
||||
let worktree_path = worktree_path.canonicalize()?;
|
||||
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let mut ready_languages = HashSet::default();
|
||||
let (lsp_open_handle, buffer) = if *use_language_server {
|
||||
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
|
||||
project.clone(),
|
||||
worktree.clone(),
|
||||
cursor.path.clone(),
|
||||
&mut ready_languages,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
(Some(lsp_open_handle), buffer)
|
||||
} else {
|
||||
let buffer =
|
||||
open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
|
||||
(None, buffer)
|
||||
};
|
||||
|
||||
let full_path_str = worktree
|
||||
.read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
|
||||
.display(PathStyle::local())
|
||||
.to_string();
|
||||
|
||||
let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
|
||||
let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
|
||||
if clipped_cursor != cursor.point {
|
||||
let max_row = snapshot.max_point().row;
|
||||
if cursor.point.row < max_row {
|
||||
return Err(anyhow!(
|
||||
"Cursor position {:?} is out of bounds (line length is {})",
|
||||
cursor.point,
|
||||
snapshot.line_len(cursor.point.row)
|
||||
));
|
||||
impl EpArgs {
|
||||
fn output_path(&self) -> Option<PathBuf> {
|
||||
if self.in_place {
|
||||
if self.inputs.len() == 1 {
|
||||
self.inputs.first().cloned()
|
||||
} else {
|
||||
panic!("--in-place requires exactly one input file")
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!(
|
||||
"Cursor position {:?} is out of bounds (max row is {})",
|
||||
cursor.point,
|
||||
max_row
|
||||
));
|
||||
self.output.clone()
|
||||
}
|
||||
}
|
||||
|
||||
Ok(LoadedContext {
|
||||
full_path_str,
|
||||
snapshot,
|
||||
clipped_cursor,
|
||||
worktree,
|
||||
project,
|
||||
buffer,
|
||||
lsp_open_handle,
|
||||
})
|
||||
}
|
||||
|
||||
async fn zeta2_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<String> {
|
||||
let LoadedContext {
|
||||
worktree,
|
||||
project,
|
||||
buffer,
|
||||
clipped_cursor,
|
||||
lsp_open_handle: _handle,
|
||||
..
|
||||
} = load_context(&args, app_state, cx).await?;
|
||||
|
||||
// wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
|
||||
// the whole worktree.
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})?
|
||||
.await;
|
||||
let output = cx
|
||||
.update(|cx| {
|
||||
let store = cx.new(|cx| {
|
||||
edit_prediction::EditPredictionStore::new(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
store.update(cx, |store, cx| {
|
||||
store.set_options(zeta2_args_to_options(&args.zeta2_args));
|
||||
store.register_buffer(&buffer, &project, cx);
|
||||
});
|
||||
cx.spawn(async move |cx| {
|
||||
let updates_rx = store.update(cx, |store, cx| {
|
||||
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &buffer, cursor, cx);
|
||||
store.project_context_updates(&project).unwrap()
|
||||
})?;
|
||||
|
||||
updates_rx.recv().await.ok();
|
||||
|
||||
let context = store.update(cx, |store, cx| {
|
||||
store.context_for_project(&project, cx).to_vec()
|
||||
})?;
|
||||
|
||||
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn zeta1_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
|
||||
let LoadedContext {
|
||||
full_path_str,
|
||||
snapshot,
|
||||
clipped_cursor,
|
||||
..
|
||||
} = load_context(&args, app_state, cx).await?;
|
||||
|
||||
let events = match args.edit_history {
|
||||
Some(events) => events.read_to_string().await?,
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let prompt_for_events = move || (events, 0);
|
||||
cx.update(|cx| {
|
||||
edit_prediction::zeta1::gather_context(
|
||||
full_path_str,
|
||||
&snapshot,
|
||||
clipped_cursor,
|
||||
prompt_for_events,
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
}
|
||||
|
||||
fn main() {
|
||||
zlog::init();
|
||||
zlog::init_output_stderr();
|
||||
let args = ZetaCliArgs::parse();
|
||||
let args = EpArgs::parse();
|
||||
|
||||
if args.printenv {
|
||||
::util::shell_env::print_env();
|
||||
return;
|
||||
}
|
||||
|
||||
let output = args.output_path();
|
||||
let command = match args.command {
|
||||
Some(cmd) => cmd,
|
||||
None => {
|
||||
EpArgs::command().print_help().unwrap();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match &command {
|
||||
Command::Clean => {
|
||||
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
|
||||
return;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let mut examples = read_examples(&args.inputs);
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client);
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = Arc::new(headless::init(cx));
|
||||
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
match args.command {
|
||||
None => {
|
||||
if args.printenv {
|
||||
::util::shell_env::print_env();
|
||||
} else {
|
||||
panic!("Expected a command");
|
||||
}
|
||||
}
|
||||
Some(Command::Context(context_args)) => {
|
||||
let result = match context_args.provider {
|
||||
ContextProvider::Zeta1 => {
|
||||
let context =
|
||||
zeta1_context(context_args, &app_state, cx).await.unwrap();
|
||||
serde_json::to_string_pretty(&context.body).unwrap()
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
_ => (),
|
||||
};
|
||||
|
||||
for data in examples.chunks_mut(args.max_parallelism) {
|
||||
let mut futures = Vec::new();
|
||||
for example in data.iter_mut() {
|
||||
let cx = cx.clone();
|
||||
let app_state = app_state.clone();
|
||||
futures.push(async {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadBuffer => {
|
||||
run_load_project(example, app_state.clone(), cx).await;
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(example, app_state, cx).await;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(example, args.prompt_format, app_state, cx).await;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(example, &args, app_state, cx).await;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
ContextProvider::Zeta2 => {
|
||||
zeta2_context(context_args, &app_state, cx).await.unwrap()
|
||||
}
|
||||
};
|
||||
println!("{}", result);
|
||||
});
|
||||
}
|
||||
Some(Command::Predict(arguments)) => {
|
||||
run_predict(arguments, &app_state, cx).await;
|
||||
}
|
||||
Some(Command::Eval(arguments)) => {
|
||||
run_evaluate(arguments, &app_state, cx).await;
|
||||
}
|
||||
Some(Command::Distill(arguments)) => {
|
||||
let _guard = cx
|
||||
.update(|cx| gpui_tokio::Tokio::handle(cx))
|
||||
.unwrap()
|
||||
.enter();
|
||||
run_distill(arguments).await.log_err();
|
||||
}
|
||||
Some(Command::ConvertExample {
|
||||
path,
|
||||
output_format,
|
||||
}) => {
|
||||
let example = NamedExample::load(path).unwrap();
|
||||
example.write(output_format, io::stdout()).unwrap();
|
||||
}
|
||||
Some(Command::Score {
|
||||
golden_patch,
|
||||
actual_patch,
|
||||
}) => {
|
||||
let golden_content = std::fs::read_to_string(golden_patch).unwrap();
|
||||
let actual_content = std::fs::read_to_string(actual_patch).unwrap();
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
|
||||
let golden_diff: Vec<DiffLine> = golden_content
|
||||
.lines()
|
||||
.map(|line| DiffLine::parse(line))
|
||||
.collect();
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
|
||||
let actual_diff: Vec<DiffLine> = actual_content
|
||||
.lines()
|
||||
.map(|line| DiffLine::parse(line))
|
||||
.collect();
|
||||
|
||||
let score = delta_chr_f(&golden_diff, &actual_diff);
|
||||
println!("{:.2}", score);
|
||||
}
|
||||
Some(Command::Clean) => {
|
||||
std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
|
||||
}
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = cx.update(|cx| cx.quit());
|
||||
|
||||
@@ -1,30 +1,34 @@
|
||||
use collections::{HashMap, HashSet};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
type Counts = HashMap<String, usize>;
|
||||
type CountsDelta = HashMap<String, isize>;
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Scores {
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClassificationMetrics {
|
||||
pub true_positives: usize,
|
||||
pub false_positives: usize,
|
||||
pub false_negatives: usize,
|
||||
}
|
||||
|
||||
impl Scores {
|
||||
pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
|
||||
impl ClassificationMetrics {
|
||||
pub fn from_sets(
|
||||
expected: &HashSet<String>,
|
||||
actual: &HashSet<String>,
|
||||
) -> ClassificationMetrics {
|
||||
let true_positives = expected.intersection(actual).count();
|
||||
let false_positives = actual.difference(expected).count();
|
||||
let false_negatives = expected.difference(actual).count();
|
||||
|
||||
Scores {
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
|
||||
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
@@ -45,32 +49,16 @@ impl Scores {
|
||||
}
|
||||
}
|
||||
|
||||
Scores {
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
format!(
|
||||
"
|
||||
Precision : {:.4}
|
||||
Recall : {:.4}
|
||||
F1 Score : {:.4}
|
||||
True Positives : {}
|
||||
False Positives : {}
|
||||
False Negatives : {}",
|
||||
self.precision(),
|
||||
self.recall(),
|
||||
self.f1_score(),
|
||||
self.true_positives,
|
||||
self.false_positives,
|
||||
self.false_negatives
|
||||
)
|
||||
}
|
||||
|
||||
pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
|
||||
pub fn aggregate<'a>(
|
||||
scores: impl Iterator<Item = &'a ClassificationMetrics>,
|
||||
) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
@@ -81,7 +69,7 @@ False Negatives : {}",
|
||||
false_negatives += score.false_negatives;
|
||||
}
|
||||
|
||||
Scores {
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
@@ -115,7 +103,10 @@ False Negatives : {}",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
|
||||
pub fn line_match_score(
|
||||
expected_patch: &[DiffLine],
|
||||
actual_patch: &[DiffLine],
|
||||
) -> ClassificationMetrics {
|
||||
let expected_change_lines = expected_patch
|
||||
.iter()
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
|
||||
Scores::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
}
|
||||
|
||||
enum ChrfWhitespace {
|
||||
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
|
||||
let expected_counts = ngram_delta_to_counts(&expected_delta);
|
||||
let actual_counts = ngram_delta_to_counts(&actual_delta);
|
||||
|
||||
let score = Scores::from_counts(&expected_counts, &actual_counts);
|
||||
let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
|
||||
total_precision += score.precision();
|
||||
total_recall += score.recall();
|
||||
}
|
||||
|
||||
@@ -1,57 +1,25 @@
|
||||
use std::{env, path::PathBuf, sync::LazyLock};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::LazyLock,
|
||||
};
|
||||
|
||||
pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
|
||||
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
|
||||
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
|
||||
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
|
||||
pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
let dir = dirs::home_dir().unwrap().join(".zed_ep");
|
||||
ensure_dir(&dir)
|
||||
});
|
||||
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
|
||||
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
|
||||
pub static WORKTREES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
|
||||
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
TARGET_ZETA_DIR
|
||||
DATA_DIR
|
||||
.join("runs")
|
||||
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
|
||||
});
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
|
||||
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
|
||||
|
||||
pub fn print_run_data_dir(deep: bool, use_color: bool) {
|
||||
println!("\n## Run Data\n");
|
||||
let mut files = Vec::new();
|
||||
|
||||
let current_dir = std::env::current_dir().unwrap();
|
||||
for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
|
||||
let file = file.unwrap();
|
||||
if file.file_type().unwrap().is_dir() && deep {
|
||||
for file in std::fs::read_dir(file.path()).unwrap() {
|
||||
let path = file.unwrap().path();
|
||||
let path = path.strip_prefix(¤t_dir).unwrap_or(&path);
|
||||
files.push(format!(
|
||||
"- {}/{}{}{}",
|
||||
path.parent().unwrap().display(),
|
||||
if use_color { "\x1b[34m" } else { "" },
|
||||
path.file_name().unwrap().display(),
|
||||
if use_color { "\x1b[0m" } else { "" },
|
||||
));
|
||||
}
|
||||
} else {
|
||||
let path = file.path();
|
||||
let path = path.strip_prefix(¤t_dir).unwrap_or(&path);
|
||||
files.push(format!(
|
||||
"- {}/{}{}{}",
|
||||
path.parent().unwrap().display(),
|
||||
if use_color { "\x1b[34m" } else { "" },
|
||||
path.file_name().unwrap().display(),
|
||||
if use_color { "\x1b[0m" } else { "" }
|
||||
));
|
||||
}
|
||||
}
|
||||
files.sort();
|
||||
|
||||
for file in files {
|
||||
println!("{}", file);
|
||||
}
|
||||
|
||||
println!(
|
||||
"\n💡 Tip of the day: {} always points to the latest run\n",
|
||||
LATEST_EXAMPLE_RUN_DIR.display()
|
||||
);
|
||||
fn ensure_dir(path: &Path) -> PathBuf {
|
||||
std::fs::create_dir_all(path).expect("Failed to create directory");
|
||||
path.to_path_buf()
|
||||
}
|
||||
|
||||
@@ -1,374 +1,271 @@
|
||||
use crate::example::{ActualExcerpt, NamedExample};
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
|
||||
use crate::{
|
||||
CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
|
||||
PredictionProvider, PromptFormat,
|
||||
anthropic_client::AnthropicClient,
|
||||
example::{Example, ExamplePrediction},
|
||||
format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, future::Shared};
|
||||
use gpui::{AppContext as _, AsyncApp, Task};
|
||||
use std::{
|
||||
fs,
|
||||
sync::{
|
||||
Arc, Mutex, OnceLock,
|
||||
atomic::{AtomicUsize, Ordering::SeqCst},
|
||||
},
|
||||
};
|
||||
use ::serde::Serialize;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
|
||||
use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{AppContext, AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use serde::Deserialize;
|
||||
use std::fs;
|
||||
use std::io::{IsTerminal, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
pub async fn run_predict(
|
||||
args: PredictArguments,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
pub async fn run_prediction(
|
||||
example: &mut Example,
|
||||
provider: Option<PredictionProvider>,
|
||||
repetition_count: usize,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
let example = NamedExample::load(args.example_path).unwrap();
|
||||
let project = example.setup_project(app_state, cx).await.unwrap();
|
||||
let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
let result = perform_predict(example, project, store, None, args.options, cx)
|
||||
.await
|
||||
if !example.predictions.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
|
||||
|
||||
let provider = provider.unwrap();
|
||||
|
||||
if matches!(provider, PredictionProvider::Teacher) {
|
||||
if example.prompt.is_none() {
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
|
||||
}
|
||||
|
||||
let batched = true;
|
||||
return predict_anthropic(example, repetition_count, batched).await;
|
||||
}
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
|
||||
) {
|
||||
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
|
||||
AUTHENTICATED
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
.clone()
|
||||
.await;
|
||||
}
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
result.write(args.format, std::io::stdout()).unwrap();
|
||||
|
||||
print_run_data_dir(true, std::io::stdout().is_terminal());
|
||||
}
|
||||
ep_store
|
||||
.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher => unreachable!(),
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})
|
||||
.unwrap();
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let run_dir = RUN_DIR.join(&example.name);
|
||||
|
||||
pub fn setup_store(
|
||||
provider: PredictionProvider,
|
||||
project: &Entity<Project>,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<EditPredictionStore>> {
|
||||
let store = cx.new(|cx| {
|
||||
edit_prediction::EditPredictionStore::new(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
let updated_example = Arc::new(Mutex::new(example.clone()));
|
||||
let current_run_ix = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
store.update(cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})?;
|
||||
let mut debug_rx = ep_store
|
||||
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
|
||||
.unwrap();
|
||||
let debug_task = cx.background_spawn({
|
||||
let updated_example = updated_example.clone();
|
||||
let current_run_ix = current_run_ix.clone();
|
||||
let run_dir = run_dir.clone();
|
||||
async move {
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
let run_ix = current_run_ix.load(SeqCst);
|
||||
let mut updated_example = updated_example.lock().unwrap();
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
let run_dir = if repetition_count > 1 {
|
||||
run_dir.join(format!("{:03}", run_ix))
|
||||
} else {
|
||||
run_dir.clone()
|
||||
};
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
let store = store.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})?
|
||||
.detach();
|
||||
match event {
|
||||
DebugEvent::EditPredictionStarted(request) => {
|
||||
assert_eq!(updated_example.predictions.len(), run_ix + 1);
|
||||
|
||||
anyhow::Ok(store)
|
||||
}
|
||||
|
||||
pub async fn perform_predict(
|
||||
example: NamedExample,
|
||||
project: Entity<Project>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
repetition_ix: Option<u16>,
|
||||
options: PredictionOptions,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<PredictionDetails> {
|
||||
let mut cache_mode = options.cache;
|
||||
if repetition_ix.is_some() {
|
||||
if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
|
||||
panic!("Repetitions are not supported in Auto cache mode");
|
||||
} else {
|
||||
cache_mode = CacheMode::Skip;
|
||||
}
|
||||
} else if cache_mode == CacheMode::Auto {
|
||||
cache_mode = CacheMode::Requests;
|
||||
}
|
||||
|
||||
let mut example_run_dir = RUN_DIR.join(&example.file_name());
|
||||
if let Some(repetition_ix) = repetition_ix {
|
||||
example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
|
||||
}
|
||||
fs::create_dir_all(&example_run_dir)?;
|
||||
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
|
||||
.context("creating latest link")?;
|
||||
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
|
||||
.context("creating latest link")?;
|
||||
|
||||
store.update(cx, |store, _cx| {
|
||||
store.with_eval_cache(Arc::new(RunCache {
|
||||
example_run_dir: example_run_dir.clone(),
|
||||
cache_mode,
|
||||
}));
|
||||
})?;
|
||||
|
||||
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
|
||||
|
||||
let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
|
||||
|
||||
let prompt_format = options.zeta2.prompt_format;
|
||||
|
||||
store.update(cx, |store, _cx| {
|
||||
let mut options = store.options().clone();
|
||||
options.prompt_format = prompt_format.into();
|
||||
store.set_options(options);
|
||||
})?;
|
||||
|
||||
let mut debug_task = gpui::Task::ready(Ok(()));
|
||||
|
||||
if options.provider == crate::PredictionProvider::Zeta2 {
|
||||
let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
|
||||
|
||||
debug_task = cx.background_spawn({
|
||||
let result = result.clone();
|
||||
async move {
|
||||
let mut start_time = None;
|
||||
let mut retrieval_finished_at = None;
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
match event {
|
||||
edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
|
||||
start_time = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_prompt.md"),
|
||||
&info.search_prompt,
|
||||
)?;
|
||||
if let Some(prompt) = request.prompt {
|
||||
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
|
||||
}
|
||||
edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
|
||||
retrieval_finished_at = Some(info.timestamp);
|
||||
for (key, value) in &info.metadata {
|
||||
if *key == "search_queries" {
|
||||
fs::write(
|
||||
example_run_dir.join("search_queries.json"),
|
||||
value.as_bytes(),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
DebugEvent::EditPredictionFinished(request) => {
|
||||
assert_eq!(updated_example.predictions.len(), run_ix + 1);
|
||||
|
||||
if let Some(output) = request.model_output {
|
||||
fs::write(run_dir.join("prediction_response.md"), &output)?;
|
||||
updated_example
|
||||
.predictions
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.actual_output = output;
|
||||
}
|
||||
edit_prediction::DebugEvent::EditPredictionRequested(request) => {
|
||||
let prediction_started_at = Instant::now();
|
||||
start_time.get_or_insert(prediction_started_at);
|
||||
let prompt = request.local_prompt.unwrap_or_default();
|
||||
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
|
||||
|
||||
{
|
||||
let mut result = result.lock().unwrap();
|
||||
result.prompt_len = prompt.chars().count();
|
||||
|
||||
for included_file in request.inputs.included_files {
|
||||
let insertions =
|
||||
vec![(request.inputs.cursor_point, CURSOR_MARKER)];
|
||||
result.excerpts.extend(included_file.excerpts.iter().map(
|
||||
|excerpt| ActualExcerpt {
|
||||
path: included_file.path.components().skip(1).collect(),
|
||||
text: String::from(excerpt.text.as_ref()),
|
||||
},
|
||||
));
|
||||
write_codeblock(
|
||||
&included_file.path,
|
||||
included_file.excerpts.iter(),
|
||||
if included_file.path == request.inputs.cursor_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
included_file.max_row,
|
||||
false,
|
||||
&mut result.excerpts_text,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let response =
|
||||
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
|
||||
let response =
|
||||
edit_prediction::open_ai_response::text_from_response(response)
|
||||
.unwrap_or_default();
|
||||
let prediction_finished_at = Instant::now();
|
||||
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
|
||||
|
||||
let mut result = result.lock().unwrap();
|
||||
result.generated_len = response.chars().count();
|
||||
result.retrieval_time =
|
||||
retrieval_finished_at.unwrap() - start_time.unwrap();
|
||||
result.prediction_time = prediction_finished_at - prediction_started_at;
|
||||
result.total_time = prediction_finished_at - start_time.unwrap();
|
||||
|
||||
if run_ix >= repetition_count {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
store.update(cx, |store, cx| {
|
||||
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
|
||||
})?;
|
||||
}
|
||||
|
||||
let prediction = store
|
||||
.update(cx, |store, cx| {
|
||||
store.request_prediction(
|
||||
&project,
|
||||
&cursor_buffer,
|
||||
cursor_anchor,
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
debug_task.await?;
|
||||
|
||||
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
|
||||
|
||||
result.diff = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
anyhow::Ok(result)
|
||||
}
|
||||
|
||||
struct RunCache {
|
||||
cache_mode: CacheMode,
|
||||
example_run_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl RunCache {
|
||||
fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
|
||||
CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
|
||||
}
|
||||
|
||||
fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
|
||||
CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
|
||||
}
|
||||
|
||||
fn link_to_run(&self, key: &EvalCacheKey) {
|
||||
let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
|
||||
fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
|
||||
|
||||
let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
|
||||
fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl EvalCache for RunCache {
|
||||
fn read(&self, key: EvalCacheKey) -> Option<String> {
|
||||
let path = RunCache::output_cache_path(&key);
|
||||
|
||||
if path.exists() {
|
||||
let use_cache = match key.0 {
|
||||
EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
|
||||
EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
|
||||
self.cache_mode.use_cached_llm_responses()
|
||||
}
|
||||
};
|
||||
if use_cache {
|
||||
log::info!("Using cache entry: {}", path.display());
|
||||
self.link_to_run(&key);
|
||||
Some(fs::read_to_string(path).unwrap())
|
||||
} else {
|
||||
log::trace!("Skipping cached entry: {}", path.display());
|
||||
None
|
||||
}
|
||||
} else if matches!(self.cache_mode, CacheMode::Force) {
|
||||
panic!(
|
||||
"No cached entry found for {:?}. Run without `--cache force` at least once.",
|
||||
key.0
|
||||
);
|
||||
for ix in 0..repetition_count {
|
||||
current_run_ix.store(ix, SeqCst);
|
||||
let run_dir = if repetition_count > 1 {
|
||||
run_dir.join(format!("{:03}", ix))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
|
||||
fs::create_dir_all(&*CACHE_DIR).unwrap();
|
||||
|
||||
let input_path = RunCache::input_cache_path(&key);
|
||||
fs::write(&input_path, input).unwrap();
|
||||
|
||||
let output_path = RunCache::output_cache_path(&key);
|
||||
log::trace!("Writing cache entry: {}", output_path.display());
|
||||
fs::write(&output_path, output).unwrap();
|
||||
|
||||
self.link_to_run(&key);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PredictionDetails {
|
||||
pub diff: String,
|
||||
pub excerpts: Vec<ActualExcerpt>,
|
||||
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
|
||||
pub retrieval_time: Duration,
|
||||
pub prediction_time: Duration,
|
||||
pub total_time: Duration,
|
||||
pub run_example_dir: PathBuf,
|
||||
pub prompt_len: usize,
|
||||
pub generated_len: usize,
|
||||
}
|
||||
|
||||
impl PredictionDetails {
|
||||
pub fn new(run_example_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
diff: Default::default(),
|
||||
excerpts: Default::default(),
|
||||
excerpts_text: Default::default(),
|
||||
retrieval_time: Default::default(),
|
||||
prediction_time: Default::default(),
|
||||
total_time: Default::default(),
|
||||
run_example_dir,
|
||||
prompt_len: 0,
|
||||
generated_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
|
||||
let formatted = match format {
|
||||
PredictionsOutputFormat::Md => self.to_markdown(),
|
||||
PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
|
||||
PredictionsOutputFormat::Diff => self.diff.clone(),
|
||||
run_dir.clone()
|
||||
};
|
||||
|
||||
Ok(out.write_all(formatted.as_bytes())?)
|
||||
fs::create_dir_all(&run_dir).unwrap();
|
||||
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
}
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predictions
|
||||
.push(ExamplePrediction {
|
||||
actual_patch: String::new(),
|
||||
actual_output: String::new(),
|
||||
provider,
|
||||
});
|
||||
|
||||
let prediction = ep_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.request_prediction(
|
||||
&state.project,
|
||||
&state.buffer,
|
||||
state.cursor_position,
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predictions
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
format!(
|
||||
"## Excerpts\n\n\
|
||||
{}\n\n\
|
||||
## Prediction\n\n\
|
||||
{}\n\n\
|
||||
## Time\n\n\
|
||||
Retrieval: {}ms\n\
|
||||
Prediction: {}ms\n\n\
|
||||
Total: {}ms\n",
|
||||
self.excerpts_text,
|
||||
self.diff,
|
||||
self.retrieval_time.as_millis(),
|
||||
self.prediction_time.as_millis(),
|
||||
self.total_time.as_millis(),
|
||||
)
|
||||
ep_store
|
||||
.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})
|
||||
.unwrap();
|
||||
debug_task.await.unwrap();
|
||||
|
||||
*example = Arc::into_inner(updated_example)
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
|
||||
let llm_model_name = "claude-sonnet-4-5";
|
||||
let max_tokens = 16384;
|
||||
let llm_client = if batched {
|
||||
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
|
||||
} else {
|
||||
AnthropicClient::plain()
|
||||
};
|
||||
let llm_client = llm_client.expect("Failed to create LLM client");
|
||||
|
||||
let prompt = example
|
||||
.prompt
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
|
||||
|
||||
let messages = vec![anthropic::Message {
|
||||
role: anthropic::Role::User,
|
||||
content: vec![anthropic::RequestContent::Text {
|
||||
text: prompt.input.clone(),
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
|
||||
let Some(response) = llm_client
|
||||
.generate(llm_model_name, max_tokens, messages)
|
||||
.await
|
||||
.unwrap()
|
||||
else {
|
||||
// Request stashed for batched processing
|
||||
return;
|
||||
};
|
||||
|
||||
let actual_output = response
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
anthropic::ResponseContent::Text { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output);
|
||||
|
||||
let prediction = ExamplePrediction {
|
||||
actual_patch,
|
||||
actual_output,
|
||||
provider: PredictionProvider::Teacher,
|
||||
};
|
||||
|
||||
example.predictions.push(prediction);
|
||||
}
|
||||
|
||||
pub async fn sync_batches(provider: &PredictionProvider) {
|
||||
match provider {
|
||||
PredictionProvider::Teacher => {
|
||||
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
|
||||
let llm_client =
|
||||
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
|
||||
llm_client
|
||||
.sync_batches()
|
||||
.await
|
||||
.expect("Failed to sync batches");
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,106 +1,136 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
use crate::{
|
||||
example::{Example, ExampleContext},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
|
||||
use project::lsp_store::OpenLspBufferHandle;
|
||||
use project::{Project, ProjectPath, Worktree};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use util::rel_path::RelPath;
|
||||
use language::{Buffer, LanguageNotFound};
|
||||
use project::Project;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
pub fn open_buffer(
|
||||
project: Entity<Project>,
|
||||
worktree: Entity<Worktree>,
|
||||
path: Arc<RelPath>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Entity<Buffer>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path,
|
||||
})?;
|
||||
pub async fn run_context_retrieval(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
if example.context.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?;
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
|
||||
let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
|
||||
while *parse_status.borrow() != ParseStatus::Idle {
|
||||
parse_status.changed().await?;
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let project = state.project.clone();
|
||||
|
||||
let _lsp_handle = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let mut events = ep_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
DebugEvent::ContextRetrievalFinished(_) => {
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(buffer)
|
||||
})
|
||||
let context_files = ep_store
|
||||
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
|
||||
.unwrap();
|
||||
|
||||
example.context = Some(ExampleContext {
|
||||
files: context_files,
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn open_buffer_with_language_server(
|
||||
project: Entity<Project>,
|
||||
worktree: Entity<Worktree>,
|
||||
path: Arc<RelPath>,
|
||||
ready_languages: &mut HashSet<LanguageId>,
|
||||
async fn wait_for_language_server_to_start(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
|
||||
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
|
||||
|
||||
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
|
||||
(
|
||||
project.register_buffer_with_language_servers(&buffer, cx),
|
||||
project.path_style(cx),
|
||||
)
|
||||
})?;
|
||||
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
|
||||
) {
|
||||
let language_registry = project
|
||||
.read_with(cx, |project, _| project.languages().clone())
|
||||
.unwrap();
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(path.as_std_path())
|
||||
.load_language_for_file_path(&example.cursor_path)
|
||||
.await;
|
||||
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
anyhow::bail!(error);
|
||||
panic!("Failed to load language for file path: {}", error);
|
||||
}
|
||||
|
||||
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
|
||||
buffer.language().map(|language| language.id())
|
||||
})?
|
||||
let Some(language_id) = buffer
|
||||
.read_with(cx, |buffer, _cx| {
|
||||
buffer.language().map(|language| language.id())
|
||||
})
|
||||
.unwrap()
|
||||
else {
|
||||
return Err(anyhow!("No language for {}", path.display(path_style)));
|
||||
panic!("No language for {:?}", example.cursor_path);
|
||||
};
|
||||
|
||||
let log_prefix = format!("{} | ", path.display(path_style));
|
||||
let mut ready_languages = HashSet::default();
|
||||
let log_prefix = format!("{} | ", example.name);
|
||||
if !ready_languages.contains(&language_id) {
|
||||
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
|
||||
wait_for_lang_server(&project, &buffer, log_prefix, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
ready_languages.insert(language_id);
|
||||
}
|
||||
|
||||
let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
|
||||
let lsp_store = project
|
||||
.read_with(cx, |project, _cx| project.lsp_store())
|
||||
.unwrap();
|
||||
|
||||
// hacky wait for buffer to be registered with the language server
|
||||
for _ in 0..100 {
|
||||
let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(&buffer, cx)
|
||||
.next()
|
||||
.map(|(_, language_server)| language_server.server_id())
|
||||
if lsp_store
|
||||
.update(cx, |lsp_store, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(&buffer, cx)
|
||||
.next()
|
||||
.map(|(_, language_server)| language_server.server_id())
|
||||
})
|
||||
})
|
||||
})?
|
||||
else {
|
||||
.unwrap()
|
||||
.is_some()
|
||||
{
|
||||
return;
|
||||
} else {
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(10))
|
||||
.await;
|
||||
continue;
|
||||
};
|
||||
|
||||
return Ok((lsp_open_handle, language_server_id, buffer));
|
||||
}
|
||||
}
|
||||
|
||||
return Err(anyhow!("No language server found for buffer"));
|
||||
panic!("No language server found for buffer");
|
||||
}
|
||||
|
||||
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
|
||||
pub fn wait_for_lang_server(
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
119
crates/edit_prediction_cli/src/score.rs
Normal file
119
crates/edit_prediction_cli/src/score.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use crate::{
|
||||
PredictArgs,
|
||||
example::{Example, ExampleScore},
|
||||
headless::EpAppState,
|
||||
metrics::{self, ClassificationMetrics},
|
||||
predict::run_prediction,
|
||||
};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn run_scoring(
|
||||
example: &mut Example,
|
||||
args: &PredictArgs,
|
||||
app_state: Arc<EpAppState>,
|
||||
cx: AsyncApp,
|
||||
) {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state,
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected_patch = parse_patch(&example.expected_patch);
|
||||
|
||||
let mut scores = vec![];
|
||||
|
||||
for pred in &example.predictions {
|
||||
let actual_patch = parse_patch(&pred.actual_patch);
|
||||
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
|
||||
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
|
||||
|
||||
scores.push(ExampleScore {
|
||||
delta_chr_f,
|
||||
line_match,
|
||||
});
|
||||
}
|
||||
|
||||
example.score = scores;
|
||||
}
|
||||
|
||||
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
|
||||
patch.lines().map(DiffLine::parse).collect()
|
||||
}
|
||||
|
||||
pub fn print_report(examples: &[Example]) {
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
|
||||
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
|
||||
);
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
let mut all_line_match_scores = Vec::new();
|
||||
let mut all_delta_chr_f_scores = Vec::new();
|
||||
|
||||
for example in examples {
|
||||
for score in example.score.iter() {
|
||||
let line_match = &score.line_match;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
truncate_name(&example.name, 30),
|
||||
line_match.true_positives,
|
||||
line_match.false_positives,
|
||||
line_match.false_negatives,
|
||||
line_match.precision() * 100.0,
|
||||
line_match.recall() * 100.0,
|
||||
line_match.f1_score() * 100.0,
|
||||
score.delta_chr_f
|
||||
);
|
||||
|
||||
all_line_match_scores.push(line_match.clone());
|
||||
all_delta_chr_f_scores.push(score.delta_chr_f);
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
if !all_line_match_scores.is_empty() {
|
||||
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
|
||||
let avg_delta_chr_f: f32 =
|
||||
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
"TOTAL",
|
||||
total_line_match.true_positives,
|
||||
total_line_match.false_positives,
|
||||
total_line_match.false_negatives,
|
||||
total_line_match.precision() * 100.0,
|
||||
total_line_match.recall() * 100.0,
|
||||
total_line_match.f1_score() * 100.0,
|
||||
avg_delta_chr_f
|
||||
);
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
}
|
||||
|
||||
eprintln!("\n");
|
||||
}
|
||||
|
||||
fn truncate_name(name: &str, max_len: usize) -> String {
|
||||
if name.len() <= max_len {
|
||||
name.to_string()
|
||||
} else {
|
||||
format!("{}...", &name[..max_len - 3])
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
|
||||
|
||||
use ::util::{paths::PathStyle, rel_path::RelPath};
|
||||
use anyhow::{Result, anyhow};
|
||||
use language::Point;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct SourceLocation {
|
||||
pub path: Arc<RelPath>,
|
||||
pub point: Point,
|
||||
}
|
||||
|
||||
impl Serialize for SourceLocation {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for SourceLocation {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
s.parse().map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for SourceLocation {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}:{}:{}",
|
||||
self.path.display(PathStyle::Posix),
|
||||
self.point.row + 1,
|
||||
self.point.column + 1
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SourceLocation {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self> {
|
||||
let parts: Vec<&str> = s.split(':').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err(anyhow!(
|
||||
"Invalid source location. Expected 'file.rs:line:column', got '{}'",
|
||||
s
|
||||
));
|
||||
}
|
||||
|
||||
let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
|
||||
let line: u32 = parts[1]
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
|
||||
let column: u32 = parts[2]
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
|
||||
|
||||
// Convert from 1-based to 0-based indexing
|
||||
let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
|
||||
|
||||
Ok(SourceLocation { path, point })
|
||||
}
|
||||
}
|
||||
@@ -46,3 +46,7 @@ Output example:
|
||||
## Code Context
|
||||
|
||||
{{context}}
|
||||
|
||||
## Editable region
|
||||
|
||||
{{editable_region}}
|
||||
@@ -1,89 +0,0 @@
|
||||
use std::path::Path;
|
||||
|
||||
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
|
||||
|
||||
#[derive(Debug, Clone, Default, clap::ValueEnum)]
|
||||
pub enum ContextType {
|
||||
#[default]
|
||||
CurrentFile,
|
||||
}
|
||||
|
||||
const MAX_CONTEXT_SIZE: usize = 32768;
|
||||
|
||||
pub fn collect_context(
|
||||
context_type: &ContextType,
|
||||
worktree_dir: &Path,
|
||||
cursor: SourceLocation,
|
||||
) -> String {
|
||||
let context = match context_type {
|
||||
ContextType::CurrentFile => {
|
||||
let file_path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let context = std::fs::read_to_string(&file_path).unwrap_or_default();
|
||||
|
||||
let context = add_special_tags(&context, worktree_dir, cursor);
|
||||
context
|
||||
}
|
||||
};
|
||||
|
||||
let region_end_offset = context.find(TeacherModel::REGION_END);
|
||||
|
||||
if context.len() <= MAX_CONTEXT_SIZE {
|
||||
return context;
|
||||
}
|
||||
|
||||
if let Some(region_end_offset) = region_end_offset
|
||||
&& region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
|
||||
{
|
||||
let to_truncate = context.len() - MAX_CONTEXT_SIZE;
|
||||
format!(
|
||||
"[...{} bytes truncated]\n{}\n",
|
||||
to_truncate,
|
||||
&context[to_truncate..]
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}\n[...{} bytes truncated]\n",
|
||||
&context[..MAX_CONTEXT_SIZE],
|
||||
context.len() - MAX_CONTEXT_SIZE
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Add <|editable_region_start/end|> tags
|
||||
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
|
||||
let path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let file = std::fs::read_to_string(&path).unwrap_or_default();
|
||||
let lines = file.lines().collect::<Vec<_>>();
|
||||
let cursor_row = cursor.point.row as usize;
|
||||
let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
|
||||
let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
|
||||
|
||||
let snippet = lines[start_line..end_line].join("\n");
|
||||
|
||||
if context.contains(&snippet) {
|
||||
let mut cursor_line = lines[cursor_row].to_string();
|
||||
cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
|
||||
|
||||
let mut snippet_with_tags_lines = vec![];
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_START);
|
||||
snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
|
||||
snippet_with_tags_lines.push(&cursor_line);
|
||||
snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_END);
|
||||
let snippet_with_tags = snippet_with_tags_lines.join("\n");
|
||||
|
||||
context.replace(&snippet, &snippet_with_tags)
|
||||
} else {
|
||||
log::warn!(
|
||||
"Can't find area around the cursor in the context; proceeding without special tags"
|
||||
);
|
||||
context.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn strip_special_tags(context: &str) -> String {
|
||||
context
|
||||
.replace(TeacherModel::REGION_START, "")
|
||||
.replace(TeacherModel::REGION_END, "")
|
||||
.replace(TeacherModel::USER_CURSOR, "")
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
DistillArguments,
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::{
|
||||
context::ContextType,
|
||||
llm_client::LlmClient,
|
||||
teacher::{TeacherModel, TeacherOutput},
|
||||
},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use reqwest_client::ReqwestClient;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SplitCommit {
|
||||
repo_url: String,
|
||||
commit_sha: String,
|
||||
edit_history: String,
|
||||
expected_patch: String,
|
||||
cursor_position: String,
|
||||
}
|
||||
|
||||
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
|
||||
let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
|
||||
.expect("Failed to read split commit dataset")
|
||||
.lines()
|
||||
.map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
|
||||
.collect();
|
||||
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
|
||||
let llm_client = if let Some(cache_path) = arguments.batch {
|
||||
LlmClient::batch(&cache_path, http_client)?
|
||||
} else {
|
||||
LlmClient::plain(http_client)?
|
||||
};
|
||||
|
||||
let mut teacher = TeacherModel::new(
|
||||
"claude-sonnet-4-5".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
llm_client,
|
||||
);
|
||||
|
||||
let mut num_marked_for_batching = 0;
|
||||
|
||||
for commit in split_commits {
|
||||
if let Some(distilled) = distill_one(&mut teacher, commit).await? {
|
||||
println!("{}", serde_json::to_string(&distilled)?);
|
||||
} else {
|
||||
if num_marked_for_batching == 0 {
|
||||
log::warn!("Marked for batching");
|
||||
}
|
||||
num_marked_for_batching += 1;
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"{} requests are marked for batching",
|
||||
num_marked_for_batching
|
||||
);
|
||||
let llm_client = teacher.client;
|
||||
llm_client.sync_batches().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn distill_one(
|
||||
teacher: &mut TeacherModel,
|
||||
commit: SplitCommit,
|
||||
) -> Result<Option<TeacherOutput>> {
|
||||
let cursor: SourceLocation = commit
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let path = cursor.path.to_rel_path_buf();
|
||||
|
||||
let example = Example {
|
||||
repository_url: commit.repo_url,
|
||||
revision: commit.commit_sha,
|
||||
uncommitted_diff: commit.edit_history.clone(),
|
||||
cursor_path: path.as_std_path().to_path_buf(),
|
||||
cursor_position: commit.cursor_position,
|
||||
edit_history: commit.edit_history, // todo: trim
|
||||
expected_patch: commit.expected_patch,
|
||||
};
|
||||
|
||||
let prediction = teacher.predict(example).await;
|
||||
|
||||
prediction
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
pub mod context;
|
||||
pub mod distill;
|
||||
pub mod llm_client;
|
||||
pub mod teacher;
|
||||
@@ -1,266 +0,0 @@
|
||||
use crate::{
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::{
|
||||
context::{ContextType, collect_context, strip_special_tags},
|
||||
llm_client::LlmClient,
|
||||
},
|
||||
};
|
||||
use anthropic::{Message, RequestContent, ResponseContent, Role};
|
||||
use anyhow::Result;
|
||||
|
||||
pub struct TeacherModel {
|
||||
pub llm_name: String,
|
||||
pub context: ContextType,
|
||||
pub client: LlmClient,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct TeacherOutput {
|
||||
parsed_output: String,
|
||||
prompt: String,
|
||||
raw_llm_response: String,
|
||||
context: String,
|
||||
diff: String,
|
||||
}
|
||||
|
||||
impl TeacherModel {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
|
||||
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
|
||||
|
||||
/// Number of lines to include before the cursor position
|
||||
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Number of lines to include after the cursor position
|
||||
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
|
||||
TeacherModel {
|
||||
llm_name,
|
||||
context,
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
|
||||
let name = input.unique_name();
|
||||
let worktree_dir = input.setup_worktree(name).await?;
|
||||
let cursor: SourceLocation = input
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let context = collect_context(&self.context, &worktree_dir, cursor.clone());
|
||||
let edit_history = Self::format_edit_history(&input.edit_history);
|
||||
|
||||
let prompt = Self::PROMPT
|
||||
.replace("{{context}}", &context)
|
||||
.replace("{{edit_history}}", &edit_history);
|
||||
|
||||
let messages = vec![Message {
|
||||
role: Role::User,
|
||||
content: vec![RequestContent::Text {
|
||||
text: prompt.clone(),
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
|
||||
let Some(response) = self
|
||||
.client
|
||||
.generate(self.llm_name.clone(), 16384, messages)
|
||||
.await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let response_text = response
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
ResponseContent::Text { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let parsed_output = self.parse_response(&response_text);
|
||||
|
||||
let original_editable_region = Self::extract_editable_region(&context);
|
||||
let context_after_edit = context.replace(&original_editable_region, &parsed_output);
|
||||
let context_after_edit = strip_special_tags(&context_after_edit);
|
||||
let context_before_edit = strip_special_tags(&context);
|
||||
let diff = language::unified_diff(&context_before_edit, &context_after_edit);
|
||||
|
||||
// zeta distill --batch batch_results.txt
|
||||
// zeta distill
|
||||
// 1. Run `zeta distill <2000 examples <- all examples>` for the first time
|
||||
// - store LLM requests in a batch, don't actual send the request
|
||||
// - send the batch (2000 requests) after all inputs are processed
|
||||
// 2. `zeta send-batches`
|
||||
// - upload the batch to Anthropic
|
||||
|
||||
// https://platform.claude.com/docs/en/build-with-claude/batch-processing
|
||||
// https://crates.io/crates/anthropic-sdk-rust
|
||||
|
||||
// - poll for results
|
||||
// - when ready, store results in cache (a database)
|
||||
// 3. `zeta distill` again
|
||||
// - use the cached results this time
|
||||
|
||||
Ok(Some(TeacherOutput {
|
||||
parsed_output,
|
||||
prompt,
|
||||
raw_llm_response: response_text,
|
||||
context,
|
||||
diff,
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_response(&self, content: &str) -> String {
|
||||
let codeblock = Self::extract_last_codeblock(content);
|
||||
let editable_region = Self::extract_editable_region(&codeblock);
|
||||
|
||||
editable_region
|
||||
}
|
||||
|
||||
/// Extract content from the last code-fenced block if any, or else return content as is
|
||||
fn extract_last_codeblock(text: &str) -> String {
|
||||
let mut last_block = None;
|
||||
let mut search_start = 0;
|
||||
|
||||
while let Some(start) = text[search_start..].find("```") {
|
||||
let start = start + search_start;
|
||||
let bytes = text.as_bytes();
|
||||
let mut backtick_end = start;
|
||||
|
||||
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
|
||||
backtick_end += 1;
|
||||
}
|
||||
|
||||
let backtick_count = backtick_end - start;
|
||||
let closing_backticks = "`".repeat(backtick_count);
|
||||
|
||||
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
|
||||
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
|
||||
last_block = Some(code_block.to_string());
|
||||
search_start = backtick_end + end_pos + backtick_count;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
last_block.unwrap_or_else(|| text.to_string())
|
||||
}
|
||||
|
||||
fn extract_editable_region(text: &str) -> String {
|
||||
let start = text
|
||||
.find(Self::REGION_START)
|
||||
.map_or(0, |pos| pos + Self::REGION_START.len());
|
||||
let end = text.find(Self::REGION_END).unwrap_or(text.len());
|
||||
|
||||
text[start..end].to_string()
|
||||
}
|
||||
|
||||
/// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
let lines = edit_history
|
||||
.lines()
|
||||
.filter(|&s| Self::is_content_line(s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
|
||||
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
history_lines.join("\n")
|
||||
}
|
||||
|
||||
fn is_content_line(s: &str) -> bool {
|
||||
s.starts_with("-")
|
||||
|| s.starts_with("+")
|
||||
|| s.starts_with(" ")
|
||||
|| s.starts_with("---")
|
||||
|| s.starts_with("+++")
|
||||
|| s.starts_with("@@")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let teacher = TeacherModel::new(
|
||||
"test".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
LlmClient::dummy(),
|
||||
);
|
||||
let response = "This is a test response.";
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, response.to_string());
|
||||
|
||||
let response = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
`````
|
||||
actual response
|
||||
`````
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, "actual response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_last_code_block() {
|
||||
let text = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
```
|
||||
first block
|
||||
```
|
||||
|
||||
`````
|
||||
last block
|
||||
`````
|
||||
"};
|
||||
let last_block = TeacherModel::extract_last_codeblock(text);
|
||||
assert_eq!(last_block, "last block");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_editable_region() {
|
||||
let teacher = TeacherModel::new(
|
||||
"test".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
LlmClient::dummy(),
|
||||
);
|
||||
let response = indoc::indoc! {"
|
||||
some lines
|
||||
are
|
||||
here
|
||||
<|editable_region_start|>
|
||||
one
|
||||
two three
|
||||
|
||||
<|editable_region_end|>
|
||||
more
|
||||
lines here
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
indoc::indoc! {"
|
||||
one
|
||||
two three
|
||||
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ serde.workspace = true
|
||||
smallvec.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::RelatedExcerpt;
|
||||
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
|
||||
use std::ops::Range;
|
||||
use zeta_prompt::RelatedExcerpt;
|
||||
|
||||
#[cfg(not(test))]
|
||||
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
|
||||
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
|
||||
|
||||
input_ranges
|
||||
.into_iter()
|
||||
.map(|range| {
|
||||
let offset_range = range.to_offset(buffer);
|
||||
RelatedExcerpt {
|
||||
point_range: range,
|
||||
anchor_range: buffer.anchor_before(offset_range.start)
|
||||
..buffer.anchor_after(offset_range.end),
|
||||
text: buffer.as_rope().slice(offset_range),
|
||||
}
|
||||
.map(|range| RelatedExcerpt {
|
||||
row_range: range.start.row..range.end.row,
|
||||
text: buffer.text_for_range(range).collect(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -3,13 +3,13 @@ use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
|
||||
use project::{LocationLink, Project, ProjectPath};
|
||||
use serde::{Serialize, Serializer};
|
||||
use smallvec::SmallVec;
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
|
||||
|
||||
pub use cloud_llm_client::predict_edits_v3::Line;
|
||||
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
|
||||
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
|
||||
|
||||
const IDENTIFIER_LINE_COUNT: u32 = 3;
|
||||
|
||||
pub struct RelatedExcerptStore {
|
||||
project: WeakEntity<Project>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
related_files: Arc<[RelatedFile]>,
|
||||
related_file_buffers: Vec<Entity<Buffer>>,
|
||||
cache: HashMap<Identifier, Arc<CacheEntry>>,
|
||||
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
|
||||
identifier_line_count: u32,
|
||||
@@ -68,82 +70,6 @@ struct CachedDefinition {
|
||||
anchor_range: Range<Anchor>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct RelatedFile {
|
||||
#[serde(serialize_with = "serialize_project_path")]
|
||||
pub path: ProjectPath,
|
||||
#[serde(skip)]
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub excerpts: Vec<RelatedExcerpt>,
|
||||
pub max_row: u32,
|
||||
}
|
||||
|
||||
impl RelatedFile {
|
||||
pub fn merge_excerpts(&mut self) {
|
||||
self.excerpts.sort_unstable_by(|a, b| {
|
||||
a.point_range
|
||||
.start
|
||||
.cmp(&b.point_range.start)
|
||||
.then(b.point_range.end.cmp(&a.point_range.end))
|
||||
});
|
||||
|
||||
let mut index = 1;
|
||||
while index < self.excerpts.len() {
|
||||
if self.excerpts[index - 1]
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index].point_range.start)
|
||||
.is_ge()
|
||||
{
|
||||
let removed = self.excerpts.remove(index);
|
||||
if removed
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index - 1].point_range.end)
|
||||
.is_gt()
|
||||
{
|
||||
self.excerpts[index - 1].point_range.end = removed.point_range.end;
|
||||
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
|
||||
}
|
||||
} else {
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct RelatedExcerpt {
|
||||
#[serde(skip)]
|
||||
pub anchor_range: Range<Anchor>,
|
||||
#[serde(serialize_with = "serialize_point_range")]
|
||||
pub point_range: Range<Point>,
|
||||
#[serde(serialize_with = "serialize_rope")]
|
||||
pub text: Rope,
|
||||
}
|
||||
|
||||
fn serialize_project_path<S: Serializer>(
|
||||
project_path: &ProjectPath,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
project_path.path.serialize(serializer)
|
||||
}
|
||||
|
||||
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
rope.to_string().serialize(serializer)
|
||||
}
|
||||
|
||||
fn serialize_point_range<S: Serializer>(
|
||||
range: &Range<Point>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
[
|
||||
[range.start.row, range.start.column],
|
||||
[range.end.row, range.end.column],
|
||||
]
|
||||
.serialize(serializer)
|
||||
}
|
||||
|
||||
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
|
||||
|
||||
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
|
||||
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
|
||||
RelatedExcerptStore {
|
||||
project: project.downgrade(),
|
||||
update_tx,
|
||||
related_files: Vec::new(),
|
||||
related_files: Vec::new().into(),
|
||||
related_file_buffers: Vec::new(),
|
||||
cache: Default::default(),
|
||||
identifier_line_count: IDENTIFIER_LINE_COUNT,
|
||||
}
|
||||
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
|
||||
self.update_tx.unbounded_send((buffer, position)).ok();
|
||||
}
|
||||
|
||||
pub fn related_files(&self) -> &[RelatedFile] {
|
||||
&self.related_files
|
||||
pub fn related_files(&self) -> Arc<[RelatedFile]> {
|
||||
self.related_files.clone()
|
||||
}
|
||||
|
||||
pub fn related_files_with_buffers(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
|
||||
self.related_files
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(self.related_file_buffers.iter().cloned())
|
||||
}
|
||||
|
||||
pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
|
||||
self.related_files = files.into();
|
||||
}
|
||||
|
||||
async fn fetch_excerpts(
|
||||
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
|
||||
}
|
||||
mean_definition_latency /= cache_miss_count.max(1) as u32;
|
||||
|
||||
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
|
||||
let (new_cache, related_files, related_file_buffers) =
|
||||
rebuild_related_files(&project, new_cache, cx).await?;
|
||||
|
||||
if let Some(file) = &file {
|
||||
log::debug!(
|
||||
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.cache = new_cache;
|
||||
this.related_files = related_files;
|
||||
this.related_files = related_files.into();
|
||||
this.related_file_buffers = related_file_buffers;
|
||||
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
|
||||
}
|
||||
|
||||
async fn rebuild_related_files(
|
||||
project: &Entity<Project>,
|
||||
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
|
||||
) -> Result<(
|
||||
HashMap<Identifier, Arc<CacheEntry>>,
|
||||
Vec<RelatedFile>,
|
||||
Vec<Entity<Buffer>>,
|
||||
)> {
|
||||
let mut snapshots = HashMap::default();
|
||||
let mut worktree_root_names = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
for definition in &entry.definitions {
|
||||
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
|
||||
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
|
||||
.read_with(cx, |buffer, _| buffer.snapshot())?,
|
||||
);
|
||||
}
|
||||
let worktree_id = definition.path.worktree_id;
|
||||
if let hash_map::Entry::Vacant(e) =
|
||||
worktree_root_names.entry(definition.path.worktree_id)
|
||||
{
|
||||
project.read_with(cx, |project, cx| {
|
||||
if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
|
||||
e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(cx
|
||||
.background_spawn(async move {
|
||||
let mut files = Vec::<RelatedFile>::new();
|
||||
let mut files = Vec::new();
|
||||
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
|
||||
let mut paths_by_buffer = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
|
||||
continue;
|
||||
};
|
||||
let excerpts = assemble_excerpts(snapshot, ranges);
|
||||
files.push(RelatedFile {
|
||||
path: project_path.clone(),
|
||||
buffer: buffer.downgrade(),
|
||||
excerpts,
|
||||
max_row: snapshot.max_point().row,
|
||||
});
|
||||
let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let path = Path::new(&format!(
|
||||
"{}/{}",
|
||||
root_name,
|
||||
project_path.path.as_unix_str()
|
||||
))
|
||||
.into();
|
||||
|
||||
files.push((
|
||||
buffer,
|
||||
RelatedFile {
|
||||
path,
|
||||
excerpts,
|
||||
max_row: snapshot.max_point().row,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
files.sort_by_key(|file| file.path.clone());
|
||||
(new_entries, files)
|
||||
files.sort_by_key(|(_, file)| file.path.clone());
|
||||
let (related_buffers, related_files) = files.into_iter().unzip();
|
||||
|
||||
(new_entries, related_files, related_buffers)
|
||||
})
|
||||
.await)
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
&excerpts,
|
||||
&[
|
||||
(
|
||||
"src/company.rs",
|
||||
"root/src/company.rs",
|
||||
&[indoc! {"
|
||||
pub struct Company {
|
||||
owner: Arc<Person>,
|
||||
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
}"}],
|
||||
),
|
||||
(
|
||||
"src/main.rs",
|
||||
"root/src/main.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
pub struct Session {
|
||||
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
],
|
||||
),
|
||||
(
|
||||
"src/person.rs",
|
||||
"root/src/person.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
impl Person {
|
||||
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.text.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
(file.path.path.as_unix_str(), excerpts)
|
||||
(file.path.to_str().unwrap(), excerpts)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let expected_excerpts = expected_files
|
||||
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
|
||||
if excerpt.text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if current_row < excerpt.point_range.start.row {
|
||||
if current_row < excerpt.row_range.start {
|
||||
writeln!(&mut output, "…").unwrap();
|
||||
}
|
||||
current_row = excerpt.point_range.start.row;
|
||||
current_row = excerpt.row_range.start;
|
||||
|
||||
for line in excerpt.text.to_string().lines() {
|
||||
output.push_str(line);
|
||||
|
||||
@@ -17,7 +17,6 @@ anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
codestral.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
copilot.workspace = true
|
||||
@@ -46,6 +45,7 @@ ui_input.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
copilot = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -17,7 +17,7 @@ use gpui::{
|
||||
};
|
||||
use multi_buffer::MultiBuffer;
|
||||
use project::Project;
|
||||
use text::OffsetRangeExt;
|
||||
use text::Point;
|
||||
use ui::{
|
||||
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
|
||||
StyledTypography as _, h_flex, v_flex,
|
||||
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
|
||||
) -> Self {
|
||||
let store = EditPredictionStore::global(client, user_store, cx);
|
||||
|
||||
let mut debug_rx = store.update(cx, |store, _| store.debug_info());
|
||||
let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
|
||||
let _update_task = cx.spawn_in(window, async move |this, cx| {
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
|
||||
self.handle_context_retrieval_finished(info, window, cx);
|
||||
}
|
||||
}
|
||||
DebugEvent::EditPredictionRequested(_) => {}
|
||||
DebugEvent::EditPredictionStarted(_) => {}
|
||||
DebugEvent::EditPredictionFinished(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
|
||||
run.finished_at = Some(info.timestamp);
|
||||
run.metadata = info.metadata;
|
||||
|
||||
let project = self.project.clone();
|
||||
let related_files = self
|
||||
.store
|
||||
.read(cx)
|
||||
.context_for_project(&self.project, cx)
|
||||
.to_vec();
|
||||
.context_for_project_with_buffers(&self.project, cx)
|
||||
.map_or(Vec::new(), |files| files.collect());
|
||||
|
||||
let editor = run.editor.clone();
|
||||
let multibuffer = run.editor.read(cx).buffer().clone();
|
||||
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let mut paths = Vec::new();
|
||||
for related_file in related_files {
|
||||
let (buffer, point_ranges): (_, Vec<_>) =
|
||||
if let Some(buffer) = related_file.buffer.upgrade() {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
|
||||
(
|
||||
buffer,
|
||||
related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(related_file.path.clone(), cx)
|
||||
})?
|
||||
.await?,
|
||||
related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.point_range.clone())
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
for (related_file, buffer) in related_files {
|
||||
let point_ranges = related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| {
|
||||
Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
cx.update(|_, cx| {
|
||||
let path = PathKey::for_buffer(&buffer, cx);
|
||||
paths.push((path, buffer, point_ranges));
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use cloud_zeta2_prompt::write_codeblock;
|
||||
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
|
||||
use editor::{Editor, ExcerptRange, MultiBuffer};
|
||||
use feature_flags::FeatureFlag;
|
||||
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
|
||||
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
|
||||
|
||||
for event in &prediction.inputs.events {
|
||||
write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
|
||||
formatted_inputs.push_str("```diff\n");
|
||||
zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
|
||||
formatted_inputs.push_str("```\n\n");
|
||||
}
|
||||
|
||||
write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
|
||||
|
||||
for included_file in &prediction.inputs.included_files {
|
||||
let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
|
||||
write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
|
||||
|
||||
for included_file in prediction.inputs.related_files.as_ref() {
|
||||
write!(
|
||||
&mut formatted_inputs,
|
||||
"### {}\n\n",
|
||||
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
write_codeblock(
|
||||
&included_file.path,
|
||||
&included_file.excerpts,
|
||||
if included_file.path == prediction.inputs.cursor_path {
|
||||
cursor_insertions.as_slice()
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
included_file.max_row,
|
||||
false,
|
||||
&mut formatted_inputs,
|
||||
);
|
||||
for excerpt in included_file.excerpts.iter() {
|
||||
write!(
|
||||
&mut formatted_inputs,
|
||||
"```{}\n{}\n```\n",
|
||||
included_file.path.display(),
|
||||
excerpt.text
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
|
||||
|
||||
writeln!(
|
||||
&mut formatted_inputs,
|
||||
"```{}\n{}<CURSOR>{}\n```\n",
|
||||
prediction.inputs.cursor_path.display(),
|
||||
&prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
|
||||
&prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
self.active_prediction = Some(ActivePrediction {
|
||||
prediction,
|
||||
feedback_editor: cx.new(|cx| {
|
||||
|
||||
@@ -280,7 +280,11 @@ pub fn deploy_context_menu(
|
||||
"Copy Permalink",
|
||||
Box::new(CopyPermalinkToLine),
|
||||
)
|
||||
.action_disabled_when(!has_git_repo, "File History", Box::new(git::FileHistory));
|
||||
.action_disabled_when(
|
||||
!has_git_repo,
|
||||
"View File History",
|
||||
Box::new(git::FileHistory),
|
||||
);
|
||||
match focus {
|
||||
Some(focus) => builder.context(focus),
|
||||
None => builder,
|
||||
|
||||
@@ -29,6 +29,7 @@ pub struct ExtensionHostProxy {
|
||||
slash_command_proxy: RwLock<Option<Arc<dyn ExtensionSlashCommandProxy>>>,
|
||||
context_server_proxy: RwLock<Option<Arc<dyn ExtensionContextServerProxy>>>,
|
||||
debug_adapter_provider_proxy: RwLock<Option<Arc<dyn ExtensionDebugAdapterProviderProxy>>>,
|
||||
language_model_provider_proxy: RwLock<Option<Arc<dyn ExtensionLanguageModelProviderProxy>>>,
|
||||
}
|
||||
|
||||
impl ExtensionHostProxy {
|
||||
@@ -54,6 +55,7 @@ impl ExtensionHostProxy {
|
||||
slash_command_proxy: RwLock::default(),
|
||||
context_server_proxy: RwLock::default(),
|
||||
debug_adapter_provider_proxy: RwLock::default(),
|
||||
language_model_provider_proxy: RwLock::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +92,15 @@ impl ExtensionHostProxy {
|
||||
.write()
|
||||
.replace(Arc::new(proxy));
|
||||
}
|
||||
|
||||
pub fn register_language_model_provider_proxy(
|
||||
&self,
|
||||
proxy: impl ExtensionLanguageModelProviderProxy,
|
||||
) {
|
||||
self.language_model_provider_proxy
|
||||
.write()
|
||||
.replace(Arc::new(proxy));
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ExtensionThemeProxy: Send + Sync + 'static {
|
||||
@@ -375,6 +386,49 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
|
||||
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
|
||||
}
|
||||
|
||||
/// A function that registers a language model provider with the registry.
|
||||
/// This allows extension_host to create the provider (which requires WasmExtension)
|
||||
/// and pass a registration closure to the language_models crate.
|
||||
pub type LanguageModelProviderRegistration = Box<dyn FnOnce(&mut App) + Send + Sync + 'static>;
|
||||
|
||||
pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static {
|
||||
/// Register an LLM provider from an extension.
|
||||
/// The `register_fn` closure will be called with the App context and should
|
||||
/// register the provider with the LanguageModelRegistry.
|
||||
fn register_language_model_provider(
|
||||
&self,
|
||||
provider_id: Arc<str>,
|
||||
register_fn: LanguageModelProviderRegistration,
|
||||
cx: &mut App,
|
||||
);
|
||||
|
||||
/// Unregister an LLM provider when an extension is unloaded.
|
||||
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App);
|
||||
}
|
||||
|
||||
impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
|
||||
fn register_language_model_provider(
|
||||
&self,
|
||||
provider_id: Arc<str>,
|
||||
register_fn: LanguageModelProviderRegistration,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
proxy.register_language_model_provider(provider_id, register_fn, cx)
|
||||
}
|
||||
|
||||
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
|
||||
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
proxy.unregister_language_model_provider(provider_id, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ExtensionHostProxy {
|
||||
fn register_context_server(
|
||||
&self,
|
||||
|
||||
@@ -93,6 +93,8 @@ pub struct ExtensionManifest {
|
||||
pub debug_adapters: BTreeMap<Arc<str>, DebugAdapterManifestEntry>,
|
||||
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
|
||||
pub debug_locators: BTreeMap<Arc<str>, DebugLocatorManifestEntry>,
|
||||
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
|
||||
pub language_model_providers: BTreeMap<Arc<str>, LanguageModelProviderManifestEntry>,
|
||||
}
|
||||
|
||||
impl ExtensionManifest {
|
||||
@@ -288,6 +290,71 @@ pub struct DebugAdapterManifestEntry {
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct DebugLocatorManifestEntry {}
|
||||
|
||||
/// Manifest entry for a language model provider.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelProviderManifestEntry {
|
||||
/// Display name for the provider.
|
||||
pub name: String,
|
||||
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
|
||||
#[serde(default)]
|
||||
pub icon: Option<String>,
|
||||
/// Default models to show even before API connection.
|
||||
#[serde(default)]
|
||||
pub models: Vec<LanguageModelManifestEntry>,
|
||||
/// Authentication configuration.
|
||||
#[serde(default)]
|
||||
pub auth: Option<LanguageModelAuthConfig>,
|
||||
}
|
||||
|
||||
/// Manifest entry for a language model.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelManifestEntry {
|
||||
/// Unique identifier for the model.
|
||||
pub id: String,
|
||||
/// Display name for the model.
|
||||
pub name: String,
|
||||
/// Maximum input token count.
|
||||
#[serde(default)]
|
||||
pub max_token_count: u64,
|
||||
/// Maximum output tokens (optional).
|
||||
#[serde(default)]
|
||||
pub max_output_tokens: Option<u64>,
|
||||
/// Whether the model supports image inputs.
|
||||
#[serde(default)]
|
||||
pub supports_images: bool,
|
||||
/// Whether the model supports tool/function calling.
|
||||
#[serde(default)]
|
||||
pub supports_tools: bool,
|
||||
/// Whether the model supports extended thinking/reasoning.
|
||||
#[serde(default)]
|
||||
pub supports_thinking: bool,
|
||||
}
|
||||
|
||||
/// Authentication configuration for a language model provider.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelAuthConfig {
|
||||
/// Environment variable name for the API key.
|
||||
#[serde(default)]
|
||||
pub env_var: Option<String>,
|
||||
/// Human-readable name for the credential shown in the UI input field (e.g., "API Key", "Access Token").
|
||||
#[serde(default)]
|
||||
pub credential_label: Option<String>,
|
||||
/// OAuth configuration for web-based authentication flows.
|
||||
#[serde(default)]
|
||||
pub oauth: Option<OAuthConfig>,
|
||||
}
|
||||
|
||||
/// OAuth configuration for web-based authentication.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct OAuthConfig {
|
||||
/// The text to display on the sign-in button (e.g., "Sign in with GitHub").
|
||||
#[serde(default)]
|
||||
pub sign_in_button_label: Option<String>,
|
||||
/// The icon to display on the sign-in button (e.g., "github").
|
||||
#[serde(default)]
|
||||
pub sign_in_button_icon: Option<String>,
|
||||
}
|
||||
|
||||
impl ExtensionManifest {
|
||||
pub async fn load(fs: Arc<dyn Fs>, extension_dir: &Path) -> Result<Self> {
|
||||
let extension_name = extension_dir
|
||||
@@ -358,6 +425,7 @@ fn manifest_from_old_manifest(
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,6 +459,7 @@ mod tests {
|
||||
capabilities: vec![],
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,27 @@ pub use wit::{
|
||||
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
|
||||
latest_github_release,
|
||||
},
|
||||
zed::extension::llm_provider::{
|
||||
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
|
||||
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
|
||||
ImageData as LlmImageData, MessageContent as LlmMessageContent,
|
||||
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
|
||||
ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
|
||||
OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
|
||||
OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
|
||||
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
|
||||
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
|
||||
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
|
||||
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
|
||||
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
|
||||
ToolUseJsonParseError as LlmToolUseJsonParseError,
|
||||
delete_credential as llm_delete_credential, get_credential as llm_get_credential,
|
||||
get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser,
|
||||
oauth_start_web_auth as llm_oauth_start_web_auth,
|
||||
request_credential as llm_request_credential,
|
||||
send_oauth_http_request as llm_oauth_http_request,
|
||||
store_credential as llm_store_credential,
|
||||
},
|
||||
zed::extension::nodejs::{
|
||||
node_binary_path, npm_install_package, npm_package_installed_version,
|
||||
npm_package_latest_version,
|
||||
@@ -259,6 +280,94 @@ pub trait Extension: Send + Sync {
|
||||
) -> Result<DebugRequest, String> {
|
||||
Err("`run_dap_locator` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Returns information about language model providers offered by this extension.
|
||||
fn llm_providers(&self) -> Vec<LlmProviderInfo> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Returns the models available for a provider.
|
||||
fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Returns markdown content to display in the provider's settings UI.
|
||||
/// This can include setup instructions, links to documentation, etc.
|
||||
fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if the provider is authenticated.
|
||||
fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Start an OAuth device flow sign-in.
|
||||
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
|
||||
/// Opens the browser to the verification URL and returns the user code that should
|
||||
/// be displayed to the user.
|
||||
fn llm_provider_start_device_flow_sign_in(
|
||||
&mut self,
|
||||
_provider_id: &str,
|
||||
) -> Result<String, String> {
|
||||
Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Poll for device flow sign-in completion.
|
||||
/// This is called after llm_provider_start_device_flow_sign_in returns the user code.
|
||||
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
|
||||
fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
|
||||
Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Reset credentials for the provider.
|
||||
fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
|
||||
Err("`llm_provider_reset_credentials` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Count tokens for a request.
|
||||
fn llm_count_tokens(
|
||||
&self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
_request: &LlmCompletionRequest,
|
||||
) -> Result<u64, String> {
|
||||
Err("`llm_count_tokens` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Start streaming a completion from the model.
|
||||
/// Returns a stream ID that can be used with `llm_stream_completion_next` and `llm_stream_completion_close`.
|
||||
fn llm_stream_completion_start(
|
||||
&mut self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
_request: &LlmCompletionRequest,
|
||||
) -> Result<String, String> {
|
||||
Err("`llm_stream_completion_start` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Get the next event from a completion stream.
|
||||
/// Returns `Ok(None)` when the stream is complete.
|
||||
fn llm_stream_completion_next(
|
||||
&mut self,
|
||||
_stream_id: &str,
|
||||
) -> Result<Option<LlmCompletionEvent>, String> {
|
||||
Err("`llm_stream_completion_next` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Close a completion stream and release its resources.
|
||||
fn llm_stream_completion_close(&mut self, _stream_id: &str) {
|
||||
// Default implementation does nothing
|
||||
}
|
||||
|
||||
/// Get cache configuration for a model (if prompt caching is supported).
|
||||
fn llm_cache_configuration(
|
||||
&self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
) -> Option<LlmCacheConfiguration> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers the provided type as a Zed extension.
|
||||
@@ -518,6 +627,65 @@ impl wit::Guest for Component {
|
||||
) -> Result<DebugRequest, String> {
|
||||
extension().run_dap_locator(locator_name, build_task)
|
||||
}
|
||||
|
||||
fn llm_providers() -> Vec<LlmProviderInfo> {
|
||||
extension().llm_providers()
|
||||
}
|
||||
|
||||
fn llm_provider_models(provider_id: String) -> Result<Vec<LlmModelInfo>, String> {
|
||||
extension().llm_provider_models(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_settings_markdown(provider_id: String) -> Option<String> {
|
||||
extension().llm_provider_settings_markdown(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_is_authenticated(provider_id: String) -> bool {
|
||||
extension().llm_provider_is_authenticated(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result<String, String> {
|
||||
extension().llm_provider_start_device_flow_sign_in(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> {
|
||||
extension().llm_provider_poll_device_flow_sign_in(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> {
|
||||
extension().llm_provider_reset_credentials(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_count_tokens(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
request: LlmCompletionRequest,
|
||||
) -> Result<u64, String> {
|
||||
extension().llm_count_tokens(&provider_id, &model_id, &request)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_start(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
request: LlmCompletionRequest,
|
||||
) -> Result<String, String> {
|
||||
extension().llm_stream_completion_start(&provider_id, &model_id, &request)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_next(stream_id: String) -> Result<Option<LlmCompletionEvent>, String> {
|
||||
extension().llm_stream_completion_next(&stream_id)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_close(stream_id: String) {
|
||||
extension().llm_stream_completion_close(&stream_id)
|
||||
}
|
||||
|
||||
fn llm_cache_configuration(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
) -> Option<LlmCacheConfiguration> {
|
||||
extension().llm_cache_configuration(&provider_id, &model_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// The ID of a language server.
|
||||
|
||||
@@ -8,6 +8,7 @@ world extension {
|
||||
import platform;
|
||||
import process;
|
||||
import nodejs;
|
||||
import llm-provider;
|
||||
|
||||
use common.{env-vars, range};
|
||||
use context-server.{context-server-configuration};
|
||||
@@ -15,6 +16,10 @@ world extension {
|
||||
use lsp.{completion, symbol};
|
||||
use process.{command};
|
||||
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
|
||||
use llm-provider.{
|
||||
provider-info, model-info, completion-request,
|
||||
credential-type, cache-configuration, completion-event, token-usage
|
||||
};
|
||||
|
||||
/// Initializes the extension.
|
||||
export init-extension: func();
|
||||
@@ -164,4 +169,74 @@ world extension {
|
||||
export dap-config-to-scenario: func(config: debug-config) -> result<debug-scenario, string>;
|
||||
export dap-locator-create-scenario: func(locator-name: string, build-config-template: build-task-template, resolved-label: string, debug-adapter-name: string) -> option<debug-scenario>;
|
||||
export run-dap-locator: func(locator-name: string, config: resolved-task) -> result<debug-request, string>;
|
||||
|
||||
/// Returns information about language model providers offered by this extension.
|
||||
export llm-providers: func() -> list<provider-info>;
|
||||
|
||||
/// Returns the models available for a provider.
|
||||
export llm-provider-models: func(provider-id: string) -> result<list<model-info>, string>;
|
||||
|
||||
/// Returns markdown content to display in the provider's settings UI.
|
||||
/// This can include setup instructions, links to documentation, etc.
|
||||
export llm-provider-settings-markdown: func(provider-id: string) -> option<string>;
|
||||
|
||||
/// Check if the provider is authenticated.
|
||||
export llm-provider-is-authenticated: func(provider-id: string) -> bool;
|
||||
|
||||
/// Start an OAuth device flow sign-in.
|
||||
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
|
||||
///
|
||||
/// The device flow works as follows:
|
||||
/// 1. Extension requests a device code from the OAuth provider
|
||||
/// 2. Extension opens the verification URL in the browser
|
||||
/// 3. Extension returns the user code to display to the user
|
||||
/// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in
|
||||
/// 5. Extension polls for the access token while user authorizes in browser
|
||||
/// 6. Once authorized, extension stores the credential and returns success
|
||||
///
|
||||
/// Returns the user code that should be displayed to the user while they
|
||||
/// complete authorization in the browser.
|
||||
export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<string, string>;
|
||||
|
||||
/// Poll for device flow sign-in completion.
|
||||
/// This is called after llm-provider-start-device-flow-sign-in returns the user code.
|
||||
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
|
||||
/// Returns Ok(()) on successful authentication, or an error message on failure.
|
||||
export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Reset credentials for the provider.
|
||||
export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Count tokens for a request.
|
||||
export llm-count-tokens: func(
|
||||
provider-id: string,
|
||||
model-id: string,
|
||||
request: completion-request
|
||||
) -> result<u64, string>;
|
||||
|
||||
/// Start streaming a completion from the model.
|
||||
/// Returns a stream ID that can be used with llm-stream-next and llm-stream-close.
|
||||
export llm-stream-completion-start: func(
|
||||
provider-id: string,
|
||||
model-id: string,
|
||||
request: completion-request
|
||||
) -> result<string, string>;
|
||||
|
||||
/// Get the next event from a completion stream.
|
||||
/// Returns None when the stream is complete.
|
||||
export llm-stream-completion-next: func(
|
||||
stream-id: string
|
||||
) -> result<option<completion-event>, string>;
|
||||
|
||||
/// Close a completion stream and release its resources.
|
||||
export llm-stream-completion-close: func(
|
||||
stream-id: string
|
||||
);
|
||||
|
||||
/// Get cache configuration for a model (if prompt caching is supported).
|
||||
export llm-cache-configuration: func(
|
||||
provider-id: string,
|
||||
model-id: string
|
||||
) -> option<cache-configuration>;
|
||||
|
||||
}
|
||||
|
||||
348
crates/extension_api/wit/since_v0.8.0/llm-provider.wit
Normal file
348
crates/extension_api/wit/since_v0.8.0/llm-provider.wit
Normal file
@@ -0,0 +1,348 @@
|
||||
interface llm-provider {
|
||||
/// Information about a language model provider.
|
||||
record provider-info {
|
||||
/// Unique identifier for the provider (e.g., "my-extension.my-provider").
|
||||
id: string,
|
||||
/// Display name for the provider.
|
||||
name: string,
|
||||
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
|
||||
icon: option<string>,
|
||||
}
|
||||
|
||||
/// Capabilities of a language model.
|
||||
record model-capabilities {
|
||||
/// Whether the model supports image inputs.
|
||||
supports-images: bool,
|
||||
/// Whether the model supports tool/function calling.
|
||||
supports-tools: bool,
|
||||
/// Whether the model supports the "auto" tool choice.
|
||||
supports-tool-choice-auto: bool,
|
||||
/// Whether the model supports the "any" tool choice.
|
||||
supports-tool-choice-any: bool,
|
||||
/// Whether the model supports the "none" tool choice.
|
||||
supports-tool-choice-none: bool,
|
||||
/// Whether the model supports extended thinking/reasoning.
|
||||
supports-thinking: bool,
|
||||
/// The format for tool input schemas.
|
||||
tool-input-format: tool-input-format,
|
||||
}
|
||||
|
||||
/// Format for tool input schemas.
|
||||
enum tool-input-format {
|
||||
/// Standard JSON Schema format.
|
||||
json-schema,
|
||||
/// Simplified schema format for certain providers.
|
||||
simplified,
|
||||
}
|
||||
|
||||
/// Information about a specific model.
|
||||
record model-info {
|
||||
/// Unique identifier for the model.
|
||||
id: string,
|
||||
/// Display name for the model.
|
||||
name: string,
|
||||
/// Maximum input token count.
|
||||
max-token-count: u64,
|
||||
/// Maximum output tokens (optional).
|
||||
max-output-tokens: option<u64>,
|
||||
/// Model capabilities.
|
||||
capabilities: model-capabilities,
|
||||
/// Whether this is the default model for the provider.
|
||||
is-default: bool,
|
||||
/// Whether this is the default fast model.
|
||||
is-default-fast: bool,
|
||||
}
|
||||
|
||||
/// The role of a message participant.
|
||||
enum message-role {
|
||||
/// User message.
|
||||
user,
|
||||
/// Assistant message.
|
||||
assistant,
|
||||
/// System message.
|
||||
system,
|
||||
}
|
||||
|
||||
/// A message in a completion request.
|
||||
record request-message {
|
||||
/// The role of the message sender.
|
||||
role: message-role,
|
||||
/// The content of the message.
|
||||
content: list<message-content>,
|
||||
/// Whether to cache this message for prompt caching.
|
||||
cache: bool,
|
||||
}
|
||||
|
||||
/// Content within a message.
|
||||
variant message-content {
|
||||
/// Plain text content.
|
||||
text(string),
|
||||
/// Image content.
|
||||
image(image-data),
|
||||
/// A tool use request from the assistant.
|
||||
tool-use(tool-use),
|
||||
/// A tool result from the user.
|
||||
tool-result(tool-result),
|
||||
/// Thinking/reasoning content.
|
||||
thinking(thinking-content),
|
||||
/// Redacted/encrypted thinking content.
|
||||
redacted-thinking(string),
|
||||
}
|
||||
|
||||
/// Image data for vision models.
|
||||
record image-data {
|
||||
/// Base64-encoded image data.
|
||||
source: string,
|
||||
/// Image width in pixels (optional).
|
||||
width: option<u32>,
|
||||
/// Image height in pixels (optional).
|
||||
height: option<u32>,
|
||||
}
|
||||
|
||||
/// A tool use request from the model.
|
||||
record tool-use {
|
||||
/// Unique identifier for this tool use.
|
||||
id: string,
|
||||
/// The name of the tool being used.
|
||||
name: string,
|
||||
/// JSON string of the tool input arguments.
|
||||
input: string,
|
||||
/// Thought signature for providers that support it (e.g., Anthropic).
|
||||
thought-signature: option<string>,
|
||||
}
|
||||
|
||||
/// A tool result to send back to the model.
|
||||
record tool-result {
|
||||
/// The ID of the tool use this is a result for.
|
||||
tool-use-id: string,
|
||||
/// The name of the tool.
|
||||
tool-name: string,
|
||||
/// Whether this result represents an error.
|
||||
is-error: bool,
|
||||
/// The content of the result.
|
||||
content: tool-result-content,
|
||||
}
|
||||
|
||||
/// Content of a tool result.
|
||||
variant tool-result-content {
|
||||
/// Text result.
|
||||
text(string),
|
||||
/// Image result.
|
||||
image(image-data),
|
||||
}
|
||||
|
||||
/// Thinking/reasoning content from models that support extended thinking.
|
||||
record thinking-content {
|
||||
/// The thinking text.
|
||||
text: string,
|
||||
/// Signature for the thinking block (provider-specific).
|
||||
signature: option<string>,
|
||||
}
|
||||
|
||||
/// A tool definition for function calling.
|
||||
record tool-definition {
|
||||
/// The name of the tool.
|
||||
name: string,
|
||||
/// Description of what the tool does.
|
||||
description: string,
|
||||
/// JSON Schema for input parameters.
|
||||
input-schema: string,
|
||||
}
|
||||
|
||||
/// Tool choice preference for the model.
|
||||
enum tool-choice {
|
||||
/// Let the model decide whether to use tools.
|
||||
auto,
|
||||
/// Force the model to use at least one tool.
|
||||
any,
|
||||
/// Prevent the model from using tools.
|
||||
none,
|
||||
}
|
||||
|
||||
/// A completion request to send to the model.
|
||||
record completion-request {
|
||||
/// The messages in the conversation.
|
||||
messages: list<request-message>,
|
||||
/// Available tools for the model to use.
|
||||
tools: list<tool-definition>,
|
||||
/// Tool choice preference.
|
||||
tool-choice: option<tool-choice>,
|
||||
/// Stop sequences to end generation.
|
||||
stop-sequences: list<string>,
|
||||
/// Temperature for sampling (0.0-1.0).
|
||||
temperature: option<f32>,
|
||||
/// Whether thinking/reasoning is allowed.
|
||||
thinking-allowed: bool,
|
||||
/// Maximum tokens to generate.
|
||||
max-tokens: option<u64>,
|
||||
}
|
||||
|
||||
/// Events emitted during completion streaming.
|
||||
variant completion-event {
|
||||
/// Completion has started.
|
||||
started,
|
||||
/// Text content chunk.
|
||||
text(string),
|
||||
/// Thinking/reasoning content chunk.
|
||||
thinking(thinking-content),
|
||||
/// Redacted thinking (encrypted) chunk.
|
||||
redacted-thinking(string),
|
||||
/// Tool use request from the model.
|
||||
tool-use(tool-use),
|
||||
/// JSON parse error when parsing tool input.
|
||||
tool-use-json-parse-error(tool-use-json-parse-error),
|
||||
/// Completion stopped.
|
||||
stop(stop-reason),
|
||||
/// Token usage update.
|
||||
usage(token-usage),
|
||||
/// Reasoning details (provider-specific JSON).
|
||||
reasoning-details(string),
|
||||
}
|
||||
|
||||
/// Error information when tool use JSON parsing fails.
|
||||
record tool-use-json-parse-error {
|
||||
/// The tool use ID.
|
||||
id: string,
|
||||
/// The tool name.
|
||||
tool-name: string,
|
||||
/// The raw input that failed to parse.
|
||||
raw-input: string,
|
||||
/// The parse error message.
|
||||
error: string,
|
||||
}
|
||||
|
||||
/// Reason the completion stopped.
|
||||
enum stop-reason {
|
||||
/// The model finished generating.
|
||||
end-turn,
|
||||
/// Maximum tokens reached.
|
||||
max-tokens,
|
||||
/// The model wants to use a tool.
|
||||
tool-use,
|
||||
/// The model refused to respond.
|
||||
refusal,
|
||||
}
|
||||
|
||||
/// Token usage statistics.
|
||||
record token-usage {
|
||||
/// Number of input tokens used.
|
||||
input-tokens: u64,
|
||||
/// Number of output tokens generated.
|
||||
output-tokens: u64,
|
||||
/// Tokens used for cache creation (if supported).
|
||||
cache-creation-input-tokens: option<u64>,
|
||||
/// Tokens read from cache (if supported).
|
||||
cache-read-input-tokens: option<u64>,
|
||||
}
|
||||
|
||||
/// Credential types that can be requested.
|
||||
enum credential-type {
|
||||
/// An API key.
|
||||
api-key,
|
||||
/// An OAuth token.
|
||||
oauth-token,
|
||||
}
|
||||
|
||||
/// Cache configuration for prompt caching.
|
||||
record cache-configuration {
|
||||
/// Maximum number of cache anchors.
|
||||
max-cache-anchors: u32,
|
||||
/// Whether caching should be applied to tool definitions.
|
||||
should-cache-tool-definitions: bool,
|
||||
/// Minimum token count for a message to be cached.
|
||||
min-total-token-count: u64,
|
||||
}
|
||||
|
||||
/// Configuration for starting an OAuth web authentication flow.
|
||||
record oauth-web-auth-config {
|
||||
/// The URL to open in the user's browser to start authentication.
|
||||
/// This should include client_id, redirect_uri, scope, state, etc.
|
||||
/// Use `{port}` as a placeholder in the URL - it will be replaced with
|
||||
/// the actual localhost port before opening the browser.
|
||||
/// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback"
|
||||
auth-url: string,
|
||||
/// The path to listen on for the OAuth callback (e.g., "/callback").
|
||||
/// A localhost server will be started to receive the redirect.
|
||||
callback-path: string,
|
||||
/// Timeout in seconds to wait for the callback (default: 300 = 5 minutes).
|
||||
timeout-secs: option<u32>,
|
||||
}
|
||||
|
||||
/// Result of an OAuth web authentication flow.
|
||||
record oauth-web-auth-result {
|
||||
/// The full callback URL that was received, including query parameters.
|
||||
/// The extension is responsible for parsing the code, state, etc.
|
||||
callback-url: string,
|
||||
/// The port that was used for the localhost callback server.
|
||||
port: u32,
|
||||
}
|
||||
|
||||
/// A generic HTTP request for OAuth token exchange.
|
||||
record oauth-http-request {
|
||||
/// The URL to request.
|
||||
url: string,
|
||||
/// HTTP method (e.g., "POST", "GET").
|
||||
method: string,
|
||||
/// Request headers as key-value pairs.
|
||||
headers: list<tuple<string, string>>,
|
||||
/// Request body as a string (for form-encoded or JSON bodies).
|
||||
body: string,
|
||||
}
|
||||
|
||||
/// Response from an OAuth HTTP request.
|
||||
record oauth-http-response {
|
||||
/// HTTP status code.
|
||||
status: u16,
|
||||
/// Response headers as key-value pairs.
|
||||
headers: list<tuple<string, string>>,
|
||||
/// Response body as a string.
|
||||
body: string,
|
||||
}
|
||||
|
||||
/// Request a credential from the user.
|
||||
/// Returns true if the credential was provided, false if the user cancelled.
|
||||
request-credential: func(
|
||||
provider-id: string,
|
||||
credential-type: credential-type,
|
||||
label: string,
|
||||
placeholder: string
|
||||
) -> result<bool, string>;
|
||||
|
||||
/// Get a stored credential for this provider.
|
||||
get-credential: func(provider-id: string) -> option<string>;
|
||||
|
||||
/// Store a credential for this provider.
|
||||
store-credential: func(provider-id: string, value: string) -> result<_, string>;
|
||||
|
||||
/// Delete a stored credential for this provider.
|
||||
delete-credential: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Read an environment variable.
|
||||
get-env-var: func(name: string) -> option<string>;
|
||||
|
||||
/// Start an OAuth web authentication flow.
|
||||
///
|
||||
/// This will:
|
||||
/// 1. Start a localhost server to receive the OAuth callback
|
||||
/// 2. Open the auth URL in the user's default browser
|
||||
/// 3. Wait for the callback (up to the timeout)
|
||||
/// 4. Return the callback URL with query parameters
|
||||
///
|
||||
/// The extension is responsible for:
|
||||
/// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc.
|
||||
/// - Parsing the callback URL to extract the authorization code
|
||||
/// - Exchanging the code for tokens using oauth-http-request
|
||||
oauth-start-web-auth: func(config: oauth-web-auth-config) -> result<oauth-web-auth-result, string>;
|
||||
|
||||
/// Make an HTTP request for OAuth token exchange.
|
||||
///
|
||||
/// This is a simple HTTP client for OAuth flows, allowing the extension
|
||||
/// to handle token exchange with full control over serialization.
|
||||
send-oauth-http-request: func(request: oauth-http-request) -> result<oauth-http-response, string>;
|
||||
|
||||
/// Open a URL in the user's default browser.
|
||||
///
|
||||
/// Useful for OAuth flows that need to open a browser but handle the
|
||||
/// callback differently (e.g., polling-based flows).
|
||||
oauth-open-browser: func(url: string) -> result<_, string>;
|
||||
}
|
||||
@@ -255,6 +255,21 @@ async fn copy_extension_resources(
|
||||
}
|
||||
}
|
||||
|
||||
for (_, provider_entry) in &manifest.language_model_providers {
|
||||
if let Some(icon_path) = &provider_entry.icon {
|
||||
let source_icon = extension_path.join(icon_path);
|
||||
let dest_icon = output_dir.join(icon_path);
|
||||
|
||||
// Create parent directory if needed
|
||||
if let Some(parent) = dest_icon.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
fs::copy(&source_icon, &dest_icon)
|
||||
.with_context(|| format!("failed to copy LLM provider icon '{}'", icon_path))?;
|
||||
}
|
||||
}
|
||||
|
||||
if !manifest.languages.is_empty() {
|
||||
let output_languages_dir = output_dir.join("languages");
|
||||
fs::create_dir_all(&output_languages_dir)?;
|
||||
|
||||
@@ -22,7 +22,10 @@ async-tar.workspace = true
|
||||
async-trait.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
dap.workspace = true
|
||||
dirs.workspace = true
|
||||
editor.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -30,8 +33,11 @@ gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
http_client.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
markdown.workspace = true
|
||||
lsp.workspace = true
|
||||
menu.workspace = true
|
||||
moka.workspace = true
|
||||
node_runtime.workspace = true
|
||||
paths.workspace = true
|
||||
@@ -43,10 +49,13 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
task.workspace = true
|
||||
telemetry.workspace = true
|
||||
tempfile.workspace = true
|
||||
theme.workspace = true
|
||||
toml.workspace = true
|
||||
ui.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
wasmparser.workspace = true
|
||||
|
||||
@@ -148,6 +148,7 @@ fn manifest() -> ExtensionManifest {
|
||||
)],
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
124
crates/extension_host/src/anthropic_migration.rs
Normal file
124
crates/extension_host/src/anthropic_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const ANTHROPIC_EXTENSION_ID: &str = "anthropic";
|
||||
const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
|
||||
const ANTHROPIC_DEFAULT_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
/// Migrates Anthropic API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_anthropic_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != ANTHROPIC_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
ANTHROPIC_EXTENSION_ID, ANTHROPIC_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(ANTHROPIC_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing Anthropic API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode Anthropic API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing Anthropic API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Anthropic API key to Anthropic extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Anthropic API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Anthropic API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-ant-test-key-12345";
|
||||
|
||||
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-ant-test-key";
|
||||
|
||||
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -113,6 +113,7 @@ mod tests {
|
||||
capabilities: vec![],
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
216
crates/extension_host/src/copilot_migration.rs
Normal file
216
crates/extension_host/src/copilot_migration.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
use std::path::PathBuf;
|
||||
|
||||
const COPILOT_CHAT_EXTENSION_ID: &str = "copilot-chat";
|
||||
const COPILOT_CHAT_PROVIDER_ID: &str = "copilot-chat";
|
||||
|
||||
/// Migrates Copilot OAuth credentials from the GitHub Copilot config files
|
||||
/// to the new extension-based credential location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != COPILOT_CHAT_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
// Read from copilot config files
|
||||
let oauth_token = match read_copilot_oauth_token().await {
|
||||
Some(token) if !token.is_empty() => token,
|
||||
_ => {
|
||||
log::debug!("No existing Copilot OAuth token found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &_cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Copilot OAuth token: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
async fn read_copilot_oauth_token() -> Option<String> {
|
||||
let config_paths = copilot_config_paths();
|
||||
|
||||
for path in config_paths {
|
||||
if let Some(token) = read_oauth_token_from_file(&path).await {
|
||||
return Some(token);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn copilot_config_paths() -> Vec<PathBuf> {
|
||||
let config_dir = if cfg!(target_os = "windows") {
|
||||
dirs::data_local_dir()
|
||||
} else {
|
||||
std::env::var("XDG_CONFIG_HOME")
|
||||
.map(PathBuf::from)
|
||||
.ok()
|
||||
.or_else(|| dirs::home_dir().map(|h| h.join(".config")))
|
||||
};
|
||||
|
||||
let Some(config_dir) = config_dir else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let copilot_dir = config_dir.join("github-copilot");
|
||||
|
||||
vec![
|
||||
copilot_dir.join("hosts.json"),
|
||||
copilot_dir.join("apps.json"),
|
||||
]
|
||||
}
|
||||
|
||||
async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
|
||||
let contents = match smol::fs::read_to_string(path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
extract_oauth_token(&contents, "github.com")
|
||||
}
|
||||
|
||||
fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
|
||||
let value: serde_json::Value = serde_json::from_str(contents).ok()?;
|
||||
let obj = value.as_object()?;
|
||||
|
||||
for (key, value) in obj.iter() {
|
||||
if key.starts_with(domain) {
|
||||
if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
|
||||
return Some(token.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_from_hosts_json() {
|
||||
let contents = r#"{
|
||||
"github.com": {
|
||||
"oauth_token": "ghu_test_token_12345"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("ghu_test_token_12345".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_with_user_suffix() {
|
||||
let contents = r#"{
|
||||
"github.com:user": {
|
||||
"oauth_token": "ghu_another_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("ghu_another_token".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_wrong_domain() {
|
||||
let contents = r#"{
|
||||
"gitlab.com": {
|
||||
"oauth_token": "some_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_invalid_json() {
|
||||
let contents = "not valid json";
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_missing_oauth_token_field() {
|
||||
let contents = r#"{
|
||||
"github.com": {
|
||||
"user": "testuser"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_multiple_entries_picks_first_match() {
|
||||
let contents = r#"{
|
||||
"gitlab.com": {
|
||||
"oauth_token": "gitlab_token"
|
||||
},
|
||||
"github.com": {
|
||||
"oauth_token": "github_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("github_token".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_copilot_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials for other extensions"
|
||||
);
|
||||
}
|
||||
|
||||
// Note: Unlike the other migrations, copilot migration reads from the filesystem
|
||||
// (copilot config files), not from the credentials provider. In tests, these files
|
||||
// don't exist, so no migration occurs.
|
||||
#[gpui::test]
|
||||
async fn test_no_credentials_when_no_copilot_config_exists(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_copilot_credentials_if_needed(COPILOT_CHAT_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"No credentials should be written when copilot config doesn't exist"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
mod anthropic_migration;
|
||||
mod capability_granter;
|
||||
mod copilot_migration;
|
||||
pub mod extension_settings;
|
||||
mod google_ai_migration;
|
||||
pub mod headless_host;
|
||||
mod open_router_migration;
|
||||
mod openai_migration;
|
||||
pub mod wasm_host;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -12,13 +17,14 @@ use async_tar::Archive;
|
||||
use client::ExtensionProvides;
|
||||
use client::{Client, ExtensionMetadata, GetExtensionsResponse, proto, telemetry::Telemetry};
|
||||
use collections::{BTreeMap, BTreeSet, HashSet, btree_map};
|
||||
|
||||
pub use extension::ExtensionManifest;
|
||||
use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder};
|
||||
use extension::{
|
||||
ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents,
|
||||
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy,
|
||||
ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy,
|
||||
ExtensionThemeProxy,
|
||||
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageModelProviderProxy,
|
||||
ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy,
|
||||
ExtensionSnippetProxy, ExtensionThemeProxy,
|
||||
};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::future::join_all;
|
||||
@@ -32,8 +38,8 @@ use futures::{
|
||||
select_biased,
|
||||
};
|
||||
use gpui::{
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task, WeakEntity,
|
||||
actions,
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, SharedString, Task,
|
||||
WeakEntity, actions,
|
||||
};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use language::{
|
||||
@@ -53,15 +59,24 @@ use std::{
|
||||
cmp::Ordering,
|
||||
path::{self, Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
time::Duration,
|
||||
};
|
||||
use url::Url;
|
||||
use util::{ResultExt, paths::RemotePathBuf};
|
||||
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
|
||||
use wasm_host::{
|
||||
WasmExtension, WasmHost,
|
||||
wit::{is_supported_wasm_api_version, wasm_api_version_range},
|
||||
wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
|
||||
};
|
||||
|
||||
struct LlmProviderWithModels {
|
||||
provider_info: LlmProviderInfo,
|
||||
models: Vec<LlmModelInfo>,
|
||||
is_authenticated: bool,
|
||||
icon_path: Option<SharedString>,
|
||||
auth_config: Option<extension::LanguageModelAuthConfig>,
|
||||
}
|
||||
|
||||
pub use extension::{
|
||||
ExtensionLibraryKind, GrammarManifestEntry, OldExtensionManifest, SchemaVersion,
|
||||
};
|
||||
@@ -70,6 +85,79 @@ pub use extension_settings::ExtensionSettings;
|
||||
pub const RELOAD_DEBOUNCE_DURATION: Duration = Duration::from_millis(200);
|
||||
const FS_WATCH_LATENCY: Duration = Duration::from_millis(100);
|
||||
|
||||
/// Extension IDs that are being migrated from hardcoded LLM providers.
|
||||
/// For backwards compatibility, if the user has the corresponding env var set,
|
||||
/// we automatically enable env var reading for these extensions on first install.
|
||||
const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
|
||||
"anthropic",
|
||||
"copilot-chat",
|
||||
"google-ai",
|
||||
"openrouter",
|
||||
"openai",
|
||||
];
|
||||
|
||||
/// Migrates legacy LLM provider extensions by auto-enabling env var reading
|
||||
/// if the env var is currently present in the environment.
|
||||
///
|
||||
/// This is idempotent: if the provider is already in `allowed_env_var_providers`,
|
||||
/// we skip. This means if a user explicitly removes it, it will be re-added on
|
||||
/// next launch if the env var is still set - but that's predictable behavior.
|
||||
fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) {
|
||||
// Only apply migration to known legacy LLM extensions
|
||||
if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check each provider in the manifest
|
||||
for (provider_id, provider_entry) in &manifest.language_model_providers {
|
||||
let Some(auth_config) = &provider_entry.auth else {
|
||||
continue;
|
||||
};
|
||||
let Some(env_var_name) = &auth_config.env_var else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", manifest.id, provider_id).into();
|
||||
|
||||
// Check if the env var is present and non-empty
|
||||
let env_var_is_set = std::env::var(env_var_name)
|
||||
.map(|v| !v.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
// If env var isn't set, no need to do anything
|
||||
if !env_var_is_set {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if already enabled in settings
|
||||
let already_enabled = ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(full_provider_id.as_ref());
|
||||
|
||||
if already_enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Enable env var reading since the env var is set
|
||||
settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
|
||||
let full_provider_id = full_provider_id.clone();
|
||||
move |settings, _| {
|
||||
let providers = settings
|
||||
.extension
|
||||
.allowed_env_var_providers
|
||||
.get_or_insert_with(Vec::new);
|
||||
|
||||
if !providers
|
||||
.iter()
|
||||
.any(|id| id.as_ref() == full_provider_id.as_ref())
|
||||
{
|
||||
providers.push(full_provider_id);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// The current extension [`SchemaVersion`] supported by Zed.
|
||||
const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1);
|
||||
|
||||
@@ -131,6 +219,8 @@ pub struct ExtensionStore {
|
||||
pub enum ExtensionOperation {
|
||||
Upgrade,
|
||||
Install,
|
||||
/// Auto-install from settings - triggers legacy LLM provider migrations
|
||||
AutoInstall,
|
||||
Remove,
|
||||
}
|
||||
|
||||
@@ -606,15 +696,68 @@ impl ExtensionStore {
|
||||
.extension_index
|
||||
.extensions
|
||||
.contains_key(extension_id.as_ref());
|
||||
!is_already_installed && !SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref())
|
||||
let dominated = SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref());
|
||||
!is_already_installed && !dominated
|
||||
})
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
for extension_id in extensions_to_install {
|
||||
// When enabled, this checks if an extension exists locally in the repo's extensions/
|
||||
// directory and installs it as a dev extension instead of fetching from the registry.
|
||||
// This is useful for testing auto-installed extensions before they've been published.
|
||||
// Set to `true` only during local development/testing of new auto-install extensions.
|
||||
#[cfg(debug_assertions)]
|
||||
const DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS: bool = false;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
if DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS {
|
||||
let local_extension_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.parent()
|
||||
.unwrap()
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("extensions")
|
||||
.join(extension_id.as_ref());
|
||||
|
||||
if local_extension_path.exists() {
|
||||
// Force-remove existing extension directory if it exists and isn't a symlink
|
||||
// This handles the case where the extension was previously installed from the registry
|
||||
if let Some(installed_dir) = this
|
||||
.update(cx, |this, _cx| this.installed_dir.clone())
|
||||
.ok()
|
||||
{
|
||||
let existing_path = installed_dir.join(extension_id.as_ref());
|
||||
if existing_path.exists() {
|
||||
let metadata = std::fs::symlink_metadata(&existing_path);
|
||||
let is_symlink = metadata.map(|m| m.is_symlink()).unwrap_or(false);
|
||||
if !is_symlink {
|
||||
if let Err(e) = std::fs::remove_dir_all(&existing_path) {
|
||||
log::error!(
|
||||
"Failed to remove existing extension directory {:?}: {}",
|
||||
existing_path,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(task) = this
|
||||
.update(cx, |this, cx| {
|
||||
this.install_dev_extension(local_extension_path, cx)
|
||||
})
|
||||
.ok()
|
||||
{
|
||||
task.await.log_err();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.install_latest_extension(extension_id.clone(), cx);
|
||||
this.auto_install_latest_extension(extension_id.clone(), cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -769,7 +912,10 @@ impl ExtensionStore {
|
||||
this.update(cx, |this, cx| this.reload(Some(extension_id.clone()), cx))?
|
||||
.await;
|
||||
|
||||
if let ExtensionOperation::Install = operation {
|
||||
if matches!(
|
||||
operation,
|
||||
ExtensionOperation::Install | ExtensionOperation::AutoInstall
|
||||
) {
|
||||
this.update(cx, |this, cx| {
|
||||
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
|
||||
if let Some(events) = ExtensionEvents::try_global(cx)
|
||||
@@ -779,6 +925,27 @@ impl ExtensionStore {
|
||||
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
|
||||
});
|
||||
}
|
||||
|
||||
// Run legacy LLM provider migrations only for auto-installed extensions
|
||||
if matches!(operation, ExtensionOperation::AutoInstall) {
|
||||
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
|
||||
migrate_legacy_llm_provider_env_var(&manifest, cx);
|
||||
}
|
||||
copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
|
||||
anthropic_migration::migrate_anthropic_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
google_ai_migration::migrate_google_ai_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
openai_migration::migrate_openai_credentials_if_needed(&extension_id, cx);
|
||||
open_router_migration::migrate_open_router_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -788,8 +955,24 @@ impl ExtensionStore {
|
||||
}
|
||||
|
||||
pub fn install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
|
||||
log::info!("installing extension {extension_id} latest version");
|
||||
self.install_latest_extension_with_operation(extension_id, ExtensionOperation::Install, cx);
|
||||
}
|
||||
|
||||
/// Auto-install an extension, triggering legacy LLM provider migrations.
|
||||
fn auto_install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
|
||||
self.install_latest_extension_with_operation(
|
||||
extension_id,
|
||||
ExtensionOperation::AutoInstall,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
fn install_latest_extension_with_operation(
|
||||
&mut self,
|
||||
extension_id: Arc<str>,
|
||||
operation: ExtensionOperation,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let schema_versions = schema_version_range();
|
||||
let wasm_api_versions = wasm_api_version_range(ReleaseChannel::global(cx));
|
||||
|
||||
@@ -812,13 +995,8 @@ impl ExtensionStore {
|
||||
return;
|
||||
};
|
||||
|
||||
self.install_or_upgrade_extension_at_endpoint(
|
||||
extension_id,
|
||||
url,
|
||||
ExtensionOperation::Install,
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
self.install_or_upgrade_extension_at_endpoint(extension_id, url, operation, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn upgrade_extension(
|
||||
@@ -837,7 +1015,6 @@ impl ExtensionStore {
|
||||
operation: ExtensionOperation,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!("installing extension {extension_id} {version}");
|
||||
let Some(url) = self
|
||||
.http_client
|
||||
.build_zed_api_url(
|
||||
@@ -1134,18 +1311,6 @@ impl ExtensionStore {
|
||||
return Task::ready(());
|
||||
}
|
||||
|
||||
let reload_count = extensions_to_unload
|
||||
.iter()
|
||||
.filter(|id| extensions_to_load.contains(id))
|
||||
.count();
|
||||
|
||||
log::info!(
|
||||
"extensions updated. loading {}, reloading {}, unloading {}",
|
||||
extensions_to_load.len() - reload_count,
|
||||
reload_count,
|
||||
extensions_to_unload.len() - reload_count
|
||||
);
|
||||
|
||||
let extension_ids = extensions_to_load
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
@@ -1220,6 +1385,11 @@ impl ExtensionStore {
|
||||
for command_name in extension.manifest.slash_commands.keys() {
|
||||
self.proxy.unregister_slash_command(command_name.clone());
|
||||
}
|
||||
for provider_id in extension.manifest.language_model_providers.keys() {
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
|
||||
self.proxy
|
||||
.unregister_language_model_provider(full_provider_id, cx);
|
||||
}
|
||||
}
|
||||
|
||||
self.wasm_extensions
|
||||
@@ -1358,7 +1528,11 @@ impl ExtensionStore {
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut wasm_extensions = Vec::new();
|
||||
let mut wasm_extensions: Vec<(
|
||||
Arc<ExtensionManifest>,
|
||||
WasmExtension,
|
||||
Vec<LlmProviderWithModels>,
|
||||
)> = Vec::new();
|
||||
for extension in extension_entries {
|
||||
if extension.manifest.lib.kind.is_none() {
|
||||
continue;
|
||||
@@ -1376,7 +1550,122 @@ impl ExtensionStore {
|
||||
|
||||
match wasm_extension {
|
||||
Ok(wasm_extension) => {
|
||||
wasm_extensions.push((extension.manifest.clone(), wasm_extension))
|
||||
// Query for LLM providers if the manifest declares any
|
||||
let mut llm_providers_with_models = Vec::new();
|
||||
if !extension.manifest.language_model_providers.is_empty() {
|
||||
let providers_result = wasm_extension
|
||||
.call(|ext, store| {
|
||||
async move { ext.call_llm_providers(store).await }.boxed()
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Ok(Ok(providers)) = providers_result {
|
||||
for provider_info in providers {
|
||||
let models_result = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_provider_models(store, &provider_id)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let models: Vec<LlmModelInfo> = match models_result {
|
||||
Ok(Ok(Ok(models))) => models,
|
||||
Ok(Ok(Err(e))) => {
|
||||
log::error!(
|
||||
"Failed to get models for LLM provider {} in extension {}: {}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
log::error!(
|
||||
"Wasm error calling llm_provider_models for {} in extension {}: {:?}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"Extension call failed for llm_provider_models {} in extension {}: {:?}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Query initial authentication state
|
||||
let is_authenticated = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_provider_is_authenticated(
|
||||
store,
|
||||
&provider_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(Ok(false))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Resolve icon path if provided
|
||||
let icon_path = provider_info.icon.as_ref().map(|icon| {
|
||||
let icon_file_path = extension_path.join(icon);
|
||||
// Canonicalize to resolve symlinks (dev extensions are symlinked)
|
||||
let absolute_icon_path = icon_file_path
|
||||
.canonicalize()
|
||||
.unwrap_or(icon_file_path)
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
SharedString::from(absolute_icon_path)
|
||||
});
|
||||
|
||||
let provider_id_arc: Arc<str> =
|
||||
provider_info.id.as_str().into();
|
||||
let auth_config = extension
|
||||
.manifest
|
||||
.language_model_providers
|
||||
.get(&provider_id_arc)
|
||||
.and_then(|entry| entry.auth.clone());
|
||||
|
||||
llm_providers_with_models.push(LlmProviderWithModels {
|
||||
provider_info,
|
||||
models,
|
||||
is_authenticated,
|
||||
icon_path,
|
||||
auth_config,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
log::error!(
|
||||
"Failed to get LLM providers from extension {}: {:?}",
|
||||
extension.manifest.id,
|
||||
providers_result
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
wasm_extensions.push((
|
||||
extension.manifest.clone(),
|
||||
wasm_extension,
|
||||
llm_providers_with_models,
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
@@ -1395,7 +1684,7 @@ impl ExtensionStore {
|
||||
this.update(cx, |this, cx| {
|
||||
this.reload_complete_senders.clear();
|
||||
|
||||
for (manifest, wasm_extension) in &wasm_extensions {
|
||||
for (manifest, wasm_extension, llm_providers_with_models) in &wasm_extensions {
|
||||
let extension = Arc::new(wasm_extension.clone());
|
||||
|
||||
for (language_server_id, language_server_config) in &manifest.language_servers {
|
||||
@@ -1449,9 +1738,41 @@ impl ExtensionStore {
|
||||
this.proxy
|
||||
.register_debug_locator(extension.clone(), debug_adapter.clone());
|
||||
}
|
||||
|
||||
// Register LLM providers
|
||||
for llm_provider in llm_providers_with_models {
|
||||
let provider_id: Arc<str> =
|
||||
format!("{}:{}", manifest.id, llm_provider.provider_info.id).into();
|
||||
let wasm_ext = extension.as_ref().clone();
|
||||
let pinfo = llm_provider.provider_info.clone();
|
||||
let mods = llm_provider.models.clone();
|
||||
let auth = llm_provider.is_authenticated;
|
||||
let icon = llm_provider.icon_path.clone();
|
||||
let auth_config = llm_provider.auth_config.clone();
|
||||
|
||||
this.proxy.register_language_model_provider(
|
||||
provider_id.clone(),
|
||||
Box::new(move |cx: &mut App| {
|
||||
let provider = Arc::new(ExtensionLanguageModelProvider::new(
|
||||
wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
|
||||
));
|
||||
language_model::LanguageModelRegistry::global(cx).update(
|
||||
cx,
|
||||
|registry, cx| {
|
||||
registry.register_provider(provider, cx);
|
||||
},
|
||||
);
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
this.wasm_extensions.extend(wasm_extensions);
|
||||
let wasm_extensions_without_llm: Vec<_> = wasm_extensions
|
||||
.into_iter()
|
||||
.map(|(manifest, ext, _)| (manifest, ext))
|
||||
.collect();
|
||||
this.wasm_extensions.extend(wasm_extensions_without_llm);
|
||||
this.proxy.set_extensions_loaded();
|
||||
this.proxy.reload_current_theme(cx);
|
||||
this.proxy.reload_current_icon_theme(cx);
|
||||
@@ -1473,7 +1794,6 @@ impl ExtensionStore {
|
||||
let index_path = self.index_path.clone();
|
||||
let proxy = self.proxy.clone();
|
||||
cx.background_spawn(async move {
|
||||
let start_time = Instant::now();
|
||||
let mut index = ExtensionIndex::default();
|
||||
|
||||
fs.create_dir(&work_dir).await.log_err();
|
||||
@@ -1511,7 +1831,6 @@ impl ExtensionStore {
|
||||
.log_err();
|
||||
}
|
||||
|
||||
log::info!("rebuilt extension index in {:?}", start_time.elapsed());
|
||||
index
|
||||
})
|
||||
}
|
||||
@@ -1785,11 +2104,6 @@ impl ExtensionStore {
|
||||
})?,
|
||||
path_style,
|
||||
);
|
||||
log::info!(
|
||||
"Uploading extension {} to {:?}",
|
||||
missing_extension.clone().id,
|
||||
dest_dir
|
||||
);
|
||||
|
||||
client
|
||||
.update(cx, |client, cx| {
|
||||
@@ -1797,11 +2111,6 @@ impl ExtensionStore {
|
||||
})?
|
||||
.await?;
|
||||
|
||||
log::info!(
|
||||
"Finished uploading extension {}",
|
||||
missing_extension.clone().id
|
||||
);
|
||||
|
||||
let result = client
|
||||
.update(cx, |client, _cx| {
|
||||
client.proto_client().request(proto::InstallExtension {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use collections::HashMap;
|
||||
use collections::{HashMap, HashSet};
|
||||
use extension::{
|
||||
DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability,
|
||||
};
|
||||
@@ -16,6 +16,10 @@ pub struct ExtensionSettings {
|
||||
pub auto_install_extensions: HashMap<Arc<str>, bool>,
|
||||
pub auto_update_extensions: HashMap<Arc<str>, bool>,
|
||||
pub granted_capabilities: Vec<ExtensionCapability>,
|
||||
/// The extension language model providers that are allowed to read API keys
|
||||
/// from environment variables. Each entry is a provider ID in the format
|
||||
/// "extension_id:provider_id".
|
||||
pub allowed_env_var_providers: HashSet<Arc<str>>,
|
||||
}
|
||||
|
||||
impl ExtensionSettings {
|
||||
@@ -60,6 +64,13 @@ impl Settings for ExtensionSettings {
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
allowed_env_var_providers: content
|
||||
.extension
|
||||
.allowed_env_var_providers
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
@@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
@@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
|
||||
capabilities: Vec::new(),
|
||||
debug_adapters: Default::default(),
|
||||
debug_locators: Default::default(),
|
||||
language_model_providers: BTreeMap::default(),
|
||||
}),
|
||||
dev: false,
|
||||
},
|
||||
|
||||
124
crates/extension_host/src/google_ai_migration.rs
Normal file
124
crates/extension_host/src/google_ai_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const GOOGLE_AI_EXTENSION_ID: &str = "google-ai";
|
||||
const GOOGLE_AI_PROVIDER_ID: &str = "google-ai";
|
||||
const GOOGLE_AI_DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
/// Migrates Google AI API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_google_ai_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != GOOGLE_AI_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
GOOGLE_AI_EXTENSION_ID, GOOGLE_AI_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(GOOGLE_AI_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing Google AI API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode Google AI API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing Google AI API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Google AI API key to Google AI extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Google AI API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Google AI API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "AIzaSy-test-key-12345";
|
||||
|
||||
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "AIzaSy-test-key";
|
||||
|
||||
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
124
crates/extension_host/src/open_router_migration.rs
Normal file
124
crates/extension_host/src/open_router_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const OPEN_ROUTER_EXTENSION_ID: &str = "openrouter";
|
||||
const OPEN_ROUTER_PROVIDER_ID: &str = "openrouter";
|
||||
const OPEN_ROUTER_DEFAULT_API_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
/// Migrates OpenRouter API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_open_router_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != OPEN_ROUTER_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
OPEN_ROUTER_EXTENSION_ID, OPEN_ROUTER_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(OPEN_ROUTER_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing OpenRouter API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode OpenRouter API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing OpenRouter API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing OpenRouter API key to OpenRouter extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated OpenRouter API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate OpenRouter API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-or-test-key-12345";
|
||||
|
||||
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-or-test-key";
|
||||
|
||||
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
124
crates/extension_host/src/openai_migration.rs
Normal file
124
crates/extension_host/src/openai_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const OPENAI_EXTENSION_ID: &str = "openai";
|
||||
const OPENAI_PROVIDER_ID: &str = "openai";
|
||||
const OPENAI_DEFAULT_API_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
/// Migrates OpenAI API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_openai_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != OPENAI_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
OPENAI_EXTENSION_ID, OPENAI_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(OPENAI_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing OpenAI API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode OpenAI API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing OpenAI API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing OpenAI API key to OpenAI extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated OpenAI API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate OpenAI API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-test-key-12345";
|
||||
|
||||
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-test-key";
|
||||
|
||||
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
pub mod llm_provider;
|
||||
pub mod wit;
|
||||
|
||||
use crate::capability_granter::CapabilityGranter;
|
||||
use crate::{ExtensionManifest, ExtensionSettings};
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
|
||||
use extension::{
|
||||
CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
|
||||
@@ -64,7 +66,7 @@ pub struct WasmHost {
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WasmExtension {
|
||||
tx: UnboundedSender<ExtensionCall>,
|
||||
tx: Arc<UnboundedSender<ExtensionCall>>,
|
||||
pub manifest: Arc<ExtensionManifest>,
|
||||
pub work_dir: Arc<Path>,
|
||||
#[allow(unused)]
|
||||
@@ -74,7 +76,10 @@ pub struct WasmExtension {
|
||||
|
||||
impl Drop for WasmExtension {
|
||||
fn drop(&mut self) {
|
||||
self.tx.close_channel();
|
||||
// Only close the channel when this is the last clone holding the sender
|
||||
if Arc::strong_count(&self.tx) == 1 {
|
||||
self.tx.close_channel();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -671,7 +676,7 @@ impl WasmHost {
|
||||
Ok(WasmExtension {
|
||||
manifest,
|
||||
work_dir,
|
||||
tx,
|
||||
tx: Arc::new(tx),
|
||||
zed_api_version,
|
||||
_task: task,
|
||||
})
|
||||
|
||||
1464
crates/extension_host/src/wasm_host/llm_provider.rs
Normal file
1464
crates/extension_host/src/wasm_host/llm_provider.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ use lsp::LanguageServerName;
|
||||
use release_channel::ReleaseChannel;
|
||||
use task::{DebugScenario, SpawnInTerminal, TaskTemplate, ZedDebugConfig};
|
||||
|
||||
use crate::wasm_host::wit::since_v0_6_0::dap::StartDebuggingRequestArgumentsRequest;
|
||||
use crate::wasm_host::wit::since_v0_8_0::dap::StartDebuggingRequestArgumentsRequest;
|
||||
|
||||
use super::{WasmState, wasm_engine};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -33,6 +33,19 @@ pub use latest::CodeLabelSpanLiteral;
|
||||
pub use latest::{
|
||||
CodeLabel, CodeLabelSpan, Command, DebugAdapterBinary, ExtensionProject, Range, SlashCommand,
|
||||
zed::extension::context_server::ContextServerConfiguration,
|
||||
zed::extension::llm_provider::{
|
||||
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
|
||||
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
|
||||
ImageData as LlmImageData, MessageContent as LlmMessageContent,
|
||||
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
|
||||
ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
|
||||
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
|
||||
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
|
||||
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
|
||||
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
|
||||
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
|
||||
ToolUseJsonParseError as LlmToolUseJsonParseError,
|
||||
},
|
||||
zed::extension::lsp::{
|
||||
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
|
||||
},
|
||||
@@ -1007,6 +1020,20 @@ impl Extension {
|
||||
resource: Resource<Arc<dyn WorktreeDelegate>>,
|
||||
) -> Result<Result<DebugAdapterBinary, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let dap_binary = ext
|
||||
.call_get_dap_binary(
|
||||
store,
|
||||
&adapter_name,
|
||||
&task.try_into()?,
|
||||
user_installed_path.as_ref().and_then(|p| p.to_str()),
|
||||
resource,
|
||||
)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(dap_binary))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let dap_binary = ext
|
||||
.call_get_dap_binary(
|
||||
@@ -1032,6 +1059,16 @@ impl Extension {
|
||||
config: serde_json::Value,
|
||||
) -> Result<Result<StartDebuggingRequestArgumentsRequest, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let config =
|
||||
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
|
||||
let result = ext
|
||||
.call_dap_request_kind(store, &adapter_name, &config)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(result))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let config =
|
||||
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
|
||||
@@ -1052,6 +1089,15 @@ impl Extension {
|
||||
config: ZedDebugConfig,
|
||||
) -> Result<Result<DebugScenario, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let config = config.into();
|
||||
let result = ext
|
||||
.call_dap_config_to_scenario(store, &config)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(result.try_into()?))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let config = config.into();
|
||||
let dap_binary = ext
|
||||
@@ -1074,6 +1120,20 @@ impl Extension {
|
||||
debug_adapter_name: String,
|
||||
) -> Result<Option<DebugScenario>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let build_config_template = build_config_template.into();
|
||||
let result = ext
|
||||
.call_dap_locator_create_scenario(
|
||||
store,
|
||||
&locator_name,
|
||||
&build_config_template,
|
||||
&resolved_label,
|
||||
&debug_adapter_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(result.map(TryInto::try_into).transpose()?)
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let build_config_template = build_config_template.into();
|
||||
let dap_binary = ext
|
||||
@@ -1099,6 +1159,15 @@ impl Extension {
|
||||
resolved_build_task: SpawnInTerminal,
|
||||
) -> Result<Result<DebugRequest, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let build_config_template = resolved_build_task.try_into()?;
|
||||
let dap_request = ext
|
||||
.call_run_dap_locator(store, &locator_name, &build_config_template)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(dap_request.into()))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let build_config_template = resolved_build_task.try_into()?;
|
||||
let dap_request = ext
|
||||
@@ -1111,6 +1180,174 @@ impl Extension {
|
||||
_ => anyhow::bail!("`dap_locator_create_scenario` not available prior to v0.6.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_providers(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
) -> Result<Vec<latest::llm_provider::ProviderInfo>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_providers(store).await,
|
||||
_ => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_models(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<Vec<latest::llm_provider::ModelInfo>, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_provider_models(store, provider_id).await,
|
||||
_ => anyhow::bail!("`llm_provider_models` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_settings_markdown(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Option<String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_settings_markdown(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_is_authenticated(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<bool> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_is_authenticated(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_start_device_flow_sign_in(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<String, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_start_device_flow_sign_in(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!(
|
||||
"`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_poll_device_flow_sign_in(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<(), String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!(
|
||||
"`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_reset_credentials(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<(), String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_reset_credentials(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_provider_reset_credentials` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_count_tokens(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request: &latest::llm_provider::CompletionRequest,
|
||||
) -> Result<Result<u64, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_count_tokens(store, provider_id, model_id, request)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_count_tokens` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_start(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request: &latest::llm_provider::CompletionRequest,
|
||||
) -> Result<Result<String, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_stream_completion_start(store, provider_id, model_id, request)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_stream_completion_start` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_next(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
stream_id: &str,
|
||||
) -> Result<Result<Option<latest::llm_provider::CompletionEvent>, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_next(store, stream_id).await,
|
||||
_ => anyhow::bail!("`llm_stream_completion_next` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_close(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
stream_id: &str,
|
||||
) -> Result<()> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_close(store, stream_id).await,
|
||||
_ => anyhow::bail!("`llm_stream_completion_close` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_cache_configuration(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
) -> Result<Option<latest::llm_provider::CacheConfiguration>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_cache_configuration(store, provider_id, model_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait ToWasmtimeResult<T> {
|
||||
|
||||
@@ -32,8 +32,6 @@ wasmtime::component::bindgen!({
|
||||
},
|
||||
});
|
||||
|
||||
pub use self::zed::extension::*;
|
||||
|
||||
mod settings {
|
||||
#![allow(dead_code)]
|
||||
include!(concat!(env!("OUT_DIR"), "/since_v0.6.0/settings.rs"));
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use crate::wasm_host::wit::since_v0_6_0::{
|
||||
use crate::wasm_host::wit::since_v0_8_0::{
|
||||
dap::{
|
||||
AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
|
||||
StartDebuggingRequestArguments, TcpArguments, TcpArgumentsTemplate,
|
||||
},
|
||||
lsp::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind},
|
||||
slash_command::SlashCommandOutputSection,
|
||||
};
|
||||
use crate::wasm_host::wit::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind};
|
||||
use crate::wasm_host::{WasmState, wit::ToWasmtimeResult};
|
||||
use ::http_client::{AsyncBody, HttpRequestExt};
|
||||
use ::settings::{Settings, WorktreeId};
|
||||
@@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, bail};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use async_tar::Archive;
|
||||
use async_trait::async_trait;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use extension::{
|
||||
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
|
||||
};
|
||||
@@ -22,12 +23,14 @@ use gpui::{BackgroundExecutor, SharedString};
|
||||
use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings};
|
||||
use project::project_settings::ProjectSettings;
|
||||
use semver::Version;
|
||||
use smol::net::TcpListener;
|
||||
use std::{
|
||||
env,
|
||||
net::Ipv4Addr,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
sync::{Arc, OnceLock},
|
||||
time::Duration,
|
||||
};
|
||||
use task::{SpawnInTerminal, ZedDebugConfig};
|
||||
use url::Url;
|
||||
@@ -1107,3 +1110,361 @@ impl ExtensionImports for WasmState {
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
}
|
||||
|
||||
impl llm_provider::Host for WasmState {
|
||||
async fn request_credential(
|
||||
&mut self,
|
||||
_provider_id: String,
|
||||
_credential_type: llm_provider::CredentialType,
|
||||
_label: String,
|
||||
_placeholder: String,
|
||||
) -> wasmtime::Result<Result<bool, String>> {
|
||||
// For now, credential requests return false (not provided)
|
||||
// Extensions should use get_env_var to check for env vars first,
|
||||
// then store_credential/get_credential for manual storage
|
||||
// Full UI credential prompting will be added in a future phase
|
||||
Ok(Ok(false))
|
||||
}
|
||||
|
||||
async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
|
||||
// Check if this provider has an env var configured and if the user has allowed it
|
||||
let env_var_name = self
|
||||
.manifest
|
||||
.language_model_providers
|
||||
.get(&Arc::<str>::from(provider_id.as_str()))
|
||||
.and_then(|entry| entry.auth.as_ref())
|
||||
.and_then(|auth| auth.env_var.clone());
|
||||
|
||||
if let Some(env_var_name) = env_var_name {
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
|
||||
// Read settings dynamically to get current allowed_env_var_providers
|
||||
let is_allowed = self
|
||||
.on_main_thread({
|
||||
let full_provider_id = full_provider_id.clone();
|
||||
move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
crate::extension_settings::ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(&full_provider_id)
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_allowed {
|
||||
if let Ok(value) = env::var(&env_var_name) {
|
||||
if !value.is_empty() {
|
||||
return Ok(Some(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to credential store
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
let result = credentials_provider
|
||||
.read_credentials(&credential_key, cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
Ok(result.map(|(_, password)| String::from_utf8_lossy(&password).to_string()))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn store_credential(
|
||||
&mut self,
|
||||
provider_id: String,
|
||||
value: String,
|
||||
) -> wasmtime::Result<Result<(), String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
credentials_provider
|
||||
.write_credentials(&credential_key, "api_key", value.as_bytes(), cx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn delete_credential(
|
||||
&mut self,
|
||||
provider_id: String,
|
||||
) -> wasmtime::Result<Result<(), String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
credentials_provider
|
||||
.delete_credentials(&credential_key, cx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn get_env_var(&mut self, name: String) -> wasmtime::Result<Option<String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
|
||||
// Find which provider (if any) declares this env var in its auth config
|
||||
let mut allowed_provider_id: Option<Arc<str>> = None;
|
||||
for (provider_id, provider_entry) in &self.manifest.language_model_providers {
|
||||
if let Some(auth_config) = &provider_entry.auth {
|
||||
if auth_config.env_var.as_deref() == Some(&name) {
|
||||
allowed_provider_id = Some(provider_id.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no provider declares this env var, deny access
|
||||
let Some(provider_id) = allowed_provider_id else {
|
||||
log::warn!(
|
||||
"Extension {} attempted to read env var {} which is not declared in any provider auth config",
|
||||
extension_id,
|
||||
name
|
||||
);
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Check if the user has allowed this provider to read env vars
|
||||
// Read settings dynamically to get current allowed_env_var_providers
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
|
||||
let is_allowed = self
|
||||
.on_main_thread({
|
||||
let full_provider_id = full_provider_id.clone();
|
||||
move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
crate::extension_settings::ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(&full_provider_id)
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_allowed {
|
||||
log::debug!(
|
||||
"Extension {} provider {} is not allowed to read env var {}",
|
||||
extension_id,
|
||||
provider_id,
|
||||
name
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(env::var(&name).ok())
|
||||
}
|
||||
|
||||
async fn oauth_start_web_auth(
|
||||
&mut self,
|
||||
config: llm_provider::OauthWebAuthConfig,
|
||||
) -> wasmtime::Result<Result<llm_provider::OauthWebAuthResult, String>> {
|
||||
let auth_url = config.auth_url;
|
||||
let callback_path = config.callback_path;
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(300);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?;
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
|
||||
.port();
|
||||
|
||||
let auth_url_with_port = auth_url.replace("{port}", &port.to_string());
|
||||
cx.update(|cx| {
|
||||
cx.open_url(&auth_url_with_port);
|
||||
})?;
|
||||
|
||||
let accept_future = async {
|
||||
let (mut stream, _) = listener
|
||||
.accept()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
|
||||
|
||||
let mut request_line = String::new();
|
||||
{
|
||||
let mut reader = smol::io::BufReader::new(&mut stream);
|
||||
smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
|
||||
}
|
||||
|
||||
let callback_url = if let Some(path_start) = request_line.find(' ') {
|
||||
if let Some(path_end) = request_line[path_start + 1..].find(' ') {
|
||||
let path = &request_line[path_start + 1..path_start + 1 + path_end];
|
||||
if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) {
|
||||
format!("http://localhost:{}{}", port, path)
|
||||
} else {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unexpected callback path: {}",
|
||||
path
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Malformed HTTP request"));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Malformed HTTP request"));
|
||||
};
|
||||
|
||||
let response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/html\r\n\
|
||||
Connection: close\r\n\
|
||||
\r\n\
|
||||
<!DOCTYPE html>\
|
||||
<html><head><title>Authentication Complete</title></head>\
|
||||
<body style=\"font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;\">\
|
||||
<div style=\"text-align: center;\">\
|
||||
<h1>Authentication Complete</h1>\
|
||||
<p>You can close this window and return to Zed.</p>\
|
||||
</div></body></html>";
|
||||
|
||||
smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
|
||||
.await
|
||||
.ok();
|
||||
smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
|
||||
|
||||
Ok(callback_url)
|
||||
};
|
||||
|
||||
let timeout_duration = Duration::from_secs(timeout_secs as u64);
|
||||
let callback_url = smol::future::or(
|
||||
accept_future,
|
||||
async {
|
||||
smol::Timer::after(timeout_duration).await;
|
||||
Err(anyhow::anyhow!(
|
||||
"OAuth callback timed out after {} seconds",
|
||||
timeout_secs
|
||||
))
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(llm_provider::OauthWebAuthResult {
|
||||
callback_url,
|
||||
port: port as u32,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn send_oauth_http_request(
|
||||
&mut self,
|
||||
request: llm_provider::OauthHttpRequest,
|
||||
) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
|
||||
let http_client = self.host.http_client.clone();
|
||||
|
||||
self.on_main_thread(move |_cx| {
|
||||
async move {
|
||||
let method = match request.method.to_uppercase().as_str() {
|
||||
"GET" => ::http_client::Method::GET,
|
||||
"POST" => ::http_client::Method::POST,
|
||||
"PUT" => ::http_client::Method::PUT,
|
||||
"DELETE" => ::http_client::Method::DELETE,
|
||||
"PATCH" => ::http_client::Method::PATCH,
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Unsupported HTTP method: {}",
|
||||
request.method
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let mut builder = ::http_client::Request::builder()
|
||||
.method(method)
|
||||
.uri(&request.url);
|
||||
|
||||
for (key, value) in &request.headers {
|
||||
builder = builder.header(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let body = if request.body.is_empty() {
|
||||
AsyncBody::empty()
|
||||
} else {
|
||||
AsyncBody::from(request.body.into_bytes())
|
||||
};
|
||||
|
||||
let http_request = builder
|
||||
.body(body)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
|
||||
|
||||
let mut response = http_client
|
||||
.send(http_request)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let headers: Vec<(String, String)> = response
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let mut body_bytes = Vec::new();
|
||||
futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
|
||||
|
||||
let body = String::from_utf8_lossy(&body_bytes).to_string();
|
||||
|
||||
Ok(llm_provider::OauthHttpResponse {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn oauth_open_browser(&mut self, url: String) -> wasmtime::Result<Result<(), String>> {
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
cx.open_url(&url);
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,7 +442,9 @@ impl ExtensionsPage {
|
||||
let extension_store = ExtensionStore::global(cx).read(cx);
|
||||
|
||||
match extension_store.outstanding_operations().get(extension_id) {
|
||||
Some(ExtensionOperation::Install) => ExtensionStatus::Installing,
|
||||
Some(ExtensionOperation::Install) | Some(ExtensionOperation::AutoInstall) => {
|
||||
ExtensionStatus::Installing
|
||||
}
|
||||
Some(ExtensionOperation::Remove) => ExtensionStatus::Removing,
|
||||
Some(ExtensionOperation::Upgrade) => ExtensionStatus::Upgrading,
|
||||
None => match extension_store.installed_extensions().get(extension_id) {
|
||||
|
||||
@@ -232,14 +232,12 @@ impl From<Oid> for usize {
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum RunHook {
|
||||
PreCommit,
|
||||
PrePush,
|
||||
}
|
||||
|
||||
impl RunHook {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
Self::PreCommit => "pre-commit",
|
||||
Self::PrePush => "pre-push",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,7 +248,6 @@ impl RunHook {
|
||||
pub fn from_proto(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::PreCommit),
|
||||
1 => Some(Self::PrePush),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,6 +652,7 @@ pub struct RealGitRepository {
|
||||
pub repository: Arc<Mutex<git2::Repository>>,
|
||||
pub system_git_binary_path: Option<PathBuf>,
|
||||
pub any_git_binary_path: PathBuf,
|
||||
any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
|
||||
executor: BackgroundExecutor,
|
||||
}
|
||||
|
||||
@@ -670,6 +671,7 @@ impl RealGitRepository {
|
||||
system_git_binary_path,
|
||||
any_git_binary_path,
|
||||
executor,
|
||||
any_git_binary_help_output: Arc::new(Mutex::new(None)),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -680,6 +682,27 @@ impl RealGitRepository {
|
||||
.context("failed to read git work directory")
|
||||
.map(Path::to_path_buf)
|
||||
}
|
||||
|
||||
async fn any_git_binary_help_output(&self) -> SharedString {
|
||||
if let Some(output) = self.any_git_binary_help_output.lock().clone() {
|
||||
return output;
|
||||
}
|
||||
let git_binary_path = self.any_git_binary_path.clone();
|
||||
let executor = self.executor.clone();
|
||||
let working_directory = self.working_directory();
|
||||
let output: SharedString = self
|
||||
.executor
|
||||
.spawn(async move {
|
||||
GitBinary::new(git_binary_path, working_directory?, executor)
|
||||
.run(["help", "-a"])
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
.into();
|
||||
*self.any_git_binary_help_output.lock() = Some(output.clone());
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -2290,18 +2313,47 @@ impl GitRepository for RealGitRepository {
|
||||
env: Arc<HashMap<String, String>>,
|
||||
) -> BoxFuture<'_, Result<()>> {
|
||||
let working_directory = self.working_directory();
|
||||
let repository = self.repository.clone();
|
||||
let git_binary_path = self.any_git_binary_path.clone();
|
||||
let executor = self.executor.clone();
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let working_directory = working_directory?;
|
||||
let git = GitBinary::new(git_binary_path, working_directory, executor)
|
||||
.envs(HashMap::clone(&env));
|
||||
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
.boxed()
|
||||
let help_output = self.any_git_binary_help_output();
|
||||
|
||||
// Note: Do not spawn these commands on the background thread, as this causes some git hooks to hang.
|
||||
async move {
|
||||
let working_directory = working_directory?;
|
||||
if !help_output
|
||||
.await
|
||||
.lines()
|
||||
.any(|line| line.trim().starts_with("hook "))
|
||||
{
|
||||
let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str());
|
||||
if hook_abs_path.is_file() {
|
||||
let output = new_smol_command(&hook_abs_path)
|
||||
.envs(env.iter())
|
||||
.current_dir(&working_directory)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(GitBinaryCommandError {
|
||||
stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
|
||||
stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
|
||||
status: output.status,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let git = GitBinary::new(git_binary_path, working_directory, executor)
|
||||
.envs(HashMap::clone(&env));
|
||||
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,6 +24,7 @@ pub struct GitPanelSettings {
|
||||
pub fallback_branch_name: String,
|
||||
pub sort_by_path: bool,
|
||||
pub collapse_untracked_diff: bool,
|
||||
pub tree_view: bool,
|
||||
}
|
||||
|
||||
impl ScrollbarVisibility for GitPanelSettings {
|
||||
@@ -56,6 +57,7 @@ impl Settings for GitPanelSettings {
|
||||
fallback_branch_name: git_panel.fallback_branch_name.unwrap(),
|
||||
sort_by_path: git_panel.sort_by_path.unwrap(),
|
||||
collapse_untracked_diff: git_panel.collapse_untracked_diff.unwrap(),
|
||||
tree_view: git_panel.tree_view.unwrap(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user